├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_MODEL ├── README.md ├── app ├── Configuration │ └── Build.xcconfig ├── FastVLM App │ ├── Assets.xcassets │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ ├── AppIcon.appiconset │ │ │ ├── Contents.json │ │ │ ├── FastVLM - 150 Blue - Dark@2x.png │ │ │ ├── FastVLM - 150 Blue - Light@2x.png │ │ │ ├── FastVLM - 150 White - Tinted@2x.png │ │ │ ├── FastVLM - MacOS - Dark@1x.png │ │ │ └── FastVLM - MacOS - Dark@2x.png │ │ └── Contents.json │ ├── ContentView.swift │ ├── FastVLM.entitlements │ ├── FastVLMApp.swift │ ├── FastVLMModel.swift │ ├── Info.plist │ ├── InfoView.swift │ └── Preview Content │ │ └── Preview Assets.xcassets │ │ └── Contents.json ├── FastVLM.xcodeproj │ ├── project.pbxproj │ └── xcshareddata │ │ └── xcschemes │ │ └── FastVLM App.xcscheme ├── FastVLM │ ├── FastVLM.h │ ├── FastVLM.swift │ └── MediaProcessingExtensions.swift ├── README.md ├── Video │ ├── CameraController.swift │ ├── CameraControlsView.swift │ ├── CameraType.swift │ ├── Video.h │ └── VideoFrameView.swift └── get_pretrained_mlx_model.sh ├── docs ├── acc_vs_latency_qwen-2.png ├── fastvlm-counting.gif ├── fastvlm-emoji.gif ├── fastvlm-flexible_prompts.png └── fastvlm-handwriting.gif ├── get_models.sh ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mpt.py │ │ └── llava_qwen.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── mobileclip │ │ │ ├── __init__.py │ │ │ ├── configs │ │ │ │ └── mobileclip_l.json │ │ │ └── mci.py │ │ └── mobileclip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ ├── train_qwen.py │ └── train_xformers.py └── utils.py ├── model_export ├── README.md ├── export_vision_encoder.py └── fastvlm_mlx-vlm.patch ├── predict.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # macOS 132 | **/.DS_Store 133 | 134 | # PyCharm project settings 135 | .idea/ 136 | 137 | # Xcode 138 | *.xcworkspace 139 | 140 | # FastVLM models 141 | app/FastVLM/model -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 
40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH ML-FASTVLM: 43 | 44 | The ml-fastvlm software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | 47 | The ml-fastvlm model weights copyright and license terms can be 48 | found in LICENSE_MODEL file. 49 | ------------------------------------------------------------------------------- 50 | -------------------------------------------------------------------------------- /LICENSE_MODEL: -------------------------------------------------------------------------------- 1 | Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is 2 | specifically developed and released by Apple Inc. ("Apple") for the sole purpose 3 | of scientific research of artificial intelligence and machine-learning 4 | technology. “Apple Machine Learning Research Model” means the model, including 5 | but not limited to algorithms, formulas, trained model weights, parameters, 6 | configurations, checkpoints, and any related materials (including 7 | documentation). 8 | 9 | This Apple Machine Learning Research Model is provided to You by 10 | Apple in consideration of your agreement to the following terms, and your use, 11 | modification, creation of Model Derivatives, and or redistribution of the Apple 12 | Machine Learning Research Model constitutes acceptance of this Agreement. If You 13 | do not agree with these terms, please do not use, modify, create Model 14 | Derivatives of, or distribute this Apple Machine Learning Research Model or 15 | Model Derivatives. 16 | 17 | * License Scope: In consideration of your agreement to abide by the following 18 | terms, and subject to these terms, Apple hereby grants you a personal, 19 | non-exclusive, worldwide, non-transferable, royalty-free, revocable, and 20 | limited license, to use, copy, modify, distribute, and create Model 21 | Derivatives (defined below) of the Apple Machine Learning Research Model 22 | exclusively for Research Purposes. You agree that any Model Derivatives You 23 | may create or that may be created for You will be limited to Research Purposes 24 | as well. “Research Purposes” means non-commercial scientific research and 25 | academic development activities, such as experimentation, analysis, testing 26 | conducted by You with the sole intent to advance scientific knowledge and 27 | research. “Research Purposes” does not include any commercial exploitation, 28 | product development or use in any commercial product or service. 29 | 30 | * Distribution of Apple Machine Learning Research Model and Model Derivatives: 31 | If you choose to redistribute Apple Machine Learning Research Model or its 32 | Model Derivatives, you must provide a copy of this Agreement to such third 33 | party, and ensure that the following attribution notice be provided: “Apple 34 | Machine Learning Research Model is licensed under the Apple Machine Learning 35 | Research Model License Agreement.” Additionally, all Model Derivatives must 36 | clearly be identified as such, including disclosure of modifications and 37 | changes made to the Apple Machine Learning Research Model. The name, 38 | trademarks, service marks or logos of Apple may not be used to endorse or 39 | promote Model Derivatives or the relationship between You and Apple. 
“Model 40 | Derivatives” means any models or any other artifacts created by modifications, 41 | improvements, adaptations, alterations to the architecture, algorithm or 42 | training processes of the Apple Machine Learning Research Model, or by any 43 | retraining, fine-tuning of the Apple Machine Learning Research Model. 44 | 45 | * No Other License: Except as expressly stated in this notice, no other rights 46 | or licenses, express or implied, are granted by Apple herein, including but 47 | not limited to any patent, trademark, and similar intellectual property rights 48 | worldwide that may be infringed by the Apple Machine Learning Research Model, 49 | the Model Derivatives or by other works in which the Apple Machine Learning 50 | Research Model may be incorporated. 51 | 52 | * Compliance with Laws: Your use of Apple Machine Learning Research Model must 53 | be in compliance with all applicable laws and regulations. 54 | 55 | * Term and Termination: The term of this Agreement will begin upon your 56 | acceptance of this Agreement or use of the Apple Machine Learning Research 57 | Model and will continue until terminated in accordance with the following 58 | terms. Apple may terminate this Agreement at any time if You are in breach of 59 | any term or condition of this Agreement. Upon termination of this Agreement, 60 | You must cease to use all Apple Machine Learning Research Models and Model 61 | Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will 62 | survive termination. 63 | 64 | * Disclaimer and Limitation of Liability: This Apple Machine Learning Research 65 | Model and any outputs generated by the Apple Machine Learning Research Model 66 | are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR 67 | IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF 68 | NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, 69 | REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY 70 | THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for 71 | determining the appropriateness of using or redistributing the Apple Machine 72 | Learning Research Model and any outputs of the Apple Machine Learning Research 73 | Model and assume any risks associated with Your use of the Apple Machine 74 | Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE 75 | LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 76 | IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF 77 | THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE 78 | LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, 79 | TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS 80 | BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 81 | 82 | * Governing Law: This Agreement will be governed by and construed under the laws 83 | of the State of California without regard to its choice of law principles. The 84 | Convention on Contracts for the International Sale of Goods shall not apply to 85 | the Agreement except that the arbitration clause and any arbitration hereunder 86 | shall be governed by the Federal Arbitration Act, Chapters 1 and 2. 87 | 88 | Copyright (C) 2025 Apple Inc. All Rights Reserved. 
89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastVLM: Efficient Vision Encoding for Vision Language Models 2 | 3 | This is the official repository of 4 | **[FastVLM: Efficient Vision Encoding for Vision Language Models](https://www.arxiv.org/abs/2412.13303). (CVPR 2025)** 5 | 6 | [//]: # (![FastViTHD Performance](docs/acc_vs_latency_qwen-2.png)) 7 |

8 | ![Accuracy vs latency figure.](docs/acc_vs_latency_qwen-2.png) 9 |

10 | 11 | ### Highlights 12 | * We introduce FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images. 13 | * Our smallest variant outperforms LLaVA-OneVision-0.5B with 85x faster Time-to-First-Token (TTFT) and 3.4x smaller vision encoder. 14 | * Our larger variants using Qwen2-7B LLM outperform recent works like Cambrian-1-8B while using a single image encoder with a 7.9x faster TTFT. 15 | * Demo iOS app to demonstrate the performance of our model on a mobile device. 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
![FastVLM - Counting](docs/fastvlm-counting.gif) ![FastVLM - Handwriting](docs/fastvlm-handwriting.gif) ![FastVLM - Emoji](docs/fastvlm-emoji.gif)
24 | 25 | ## Getting Started 26 | We use the LLaVA codebase to train FastVLM variants. To train or fine-tune your own variants, 27 | please follow the instructions provided in the [LLaVA](https://github.com/haotian-liu/LLaVA) codebase. 28 | Below, we provide instructions for running inference with our models. 29 | 30 | ### Setup 31 | ```bash 32 | conda create -n fastvlm python=3.10 33 | conda activate fastvlm 34 | pip install -e . 35 | ``` 36 | 37 | ### Model Zoo 38 | For detailed information on the various evaluations, please refer to our [paper](https://www.arxiv.org/abs/2412.13303). 39 | 40 | | Model | Stage | PyTorch Checkpoint (url) | 41 | |:-------------|:-----:|:------------------------------------------------------------------------------------------------------:| 42 | | FastVLM-0.5B | 2 | [fastvlm_0.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip) | 43 | | | 3 | [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip) | 44 | | FastVLM-1.5B | 2 | [fastvlm_1.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip) | 45 | | | 3 | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip) | 46 | | FastVLM-7B | 2 | [fastvlm_7b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip) | 47 | | | 3 | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip) | 48 | 49 | To download all the pretrained checkpoints, run the command below (this might take some time depending on your connection, so it might be a good idea to grab a ☕️ while you wait). 50 | 51 | ```bash 52 | bash get_models.sh # Files will be downloaded to the `checkpoints` directory. 53 | ``` 54 | 55 | ### Usage Example 56 | To run inference with a PyTorch checkpoint, follow the instructions below: 57 | ```bash 58 | python predict.py --model-path /path/to/checkpoint-dir \ 59 | --image-file /path/to/image.png \ 60 | --prompt "Describe the image." 61 | ``` 62 | 63 | ### Inference on Apple Silicon 64 | To run inference on Apple Silicon, PyTorch checkpoints have to be exported to a format 65 | suitable for running on Apple Silicon. Detailed instructions and code can be found in the [`model_export`](model_export/) subfolder; 66 | please see the README there for more details. 67 | 68 | For convenience, we provide 3 models in an Apple Silicon-compatible format: [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3_llm.fp16.zip), 69 | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3_llm.int8.zip), 70 | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3_llm.int4.zip). 71 | We encourage developers to export the model of their choice at the appropriate quantization level by following 72 | the instructions in [`model_export`](model_export/). 73 | 74 | ### Inference on Apple Devices 75 | To run inference on Apple devices such as iPhone, iPad, or Mac, see the [`app`](app/) subfolder for more details.
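As a complement to the `predict.py` command line shown above, the snippet below sketches what a direct, programmatic inference call with the bundled `llava` package can look like. This is a minimal sketch rather than the repository's exact script: the checkpoint path, the `qwen_2` conversation template name, and the generation settings are assumptions and may need to be adjusted to match your setup.

```python
# Minimal sketch of programmatic inference (assumptions noted in the text above).
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

model_path = "checkpoints/llava-fastvithd_0.5b_stage3"  # assumed location from get_models.sh
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), device="cuda"
)

# Build the chat prompt with the image placeholder token.
conv = conv_templates["qwen_2"].copy()  # template name is an assumption
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe the image.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Preprocess the image and tokenize the prompt.
image = Image.open("image.png").convert("RGB")
image_tensor = process_images([image], image_processor, model.config)[0]
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.unsqueeze(0).half().to(model.device),
        image_sizes=[image.size],
        do_sample=False,
        max_new_tokens=256,
        use_cache=True,
    )

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```

If the template name differs in your checkout, the available keys are listed in the `conv_templates` dictionary in `llava/conversation.py`.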
76 | 77 | ## Citation 78 | If you found this code useful, please cite the following paper: 79 | ``` 80 | @InProceedings{fastvlm2025, 81 | author = {Pavan Kumar Anasosalu Vasu and Fartash Faghri and Chun-Liang Li and Cem Koc and Nate True and Albert Antony and Gokul Santhanam and James Gabriel and Peter Grasch and Oncel Tuzel and Hadi Pouransari}, 82 | title = {FastVLM: Efficient Vision Encoding for Vision Language Models}, 83 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 84 | month = {June}, 85 | year = {2025}, 86 | } 87 | ``` 88 | 89 | ## Acknowledgements 90 | Our codebase is built using multiple open-source contributions; please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details. 91 | 92 | ## License 93 | Please check out the repository [LICENSE](LICENSE) before using the provided code, and 94 | [LICENSE_MODEL](LICENSE_MODEL) for the released models. 95 | -------------------------------------------------------------------------------- /app/Configuration/Build.xcconfig: -------------------------------------------------------------------------------- 1 | // The `DISAMBIGUATOR` configuration is to make it easier to build 2 | // and run a sample code project. Once you set your project's development team, 3 | // you'll have a unique bundle identifier. This is because the bundle identifier 4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this 5 | // approach in your own projects—it's only useful for example projects because 6 | // they are frequently downloaded and don't have a development team set. 7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM} 8 | -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "FastVLM - 150 Blue - Light@2x.png", 5 | "idiom" : "universal", 6 | "platform" : "ios", 7 | "size" : "1024x1024" 8 | }, 9 | { 10 | "appearances" : [ 11 | { 12 | "appearance" : "luminosity", 13 | "value" : "dark" 14 | } 15 | ], 16 | "filename" : "FastVLM - 150 Blue - Dark@2x.png", 17 | "idiom" : "universal", 18 | "platform" : "ios", 19 | "size" : "1024x1024" 20 | }, 21 | { 22 | "appearances" : [ 23 | { 24 | "appearance" : "luminosity", 25 | "value" : "tinted" 26 | } 27 | ], 28 | "filename" : "FastVLM - 150 White - Tinted@2x.png", 29 | "idiom" : "universal", 30 | "platform" : "ios", 31 | "size" : "1024x1024" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "1x", 36 | "size" : "16x16" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "2x", 41 | "size" : "16x16" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "1x", 46 | "size" : "32x32" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "2x", 51 | "size" : "32x32" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "1x", 56 | "size" : "128x128" 57 | }, 58 | { 59 | "idiom" : "mac", 60 | "scale" : "2x", 61 | "size" : "128x128" 62 | }, 63 | { 64 | "idiom" : "mac", 65 | "scale" : "1x", 66 | "size" : "256x256" 67 | }, 68 | { 69 | "idiom" : "mac", 70 | "scale" : "2x", 71 | "size" : "256x256" 72 | }, 73 | { 74 |
"filename" : "FastVLM - MacOS - Dark@1x.png", 75 | "idiom" : "mac", 76 | "scale" : "1x", 77 | "size" : "512x512" 78 | }, 79 | { 80 | "filename" : "FastVLM - MacOS - Dark@2x.png", 81 | "idiom" : "mac", 82 | "scale" : "2x", 83 | "size" : "512x512" 84 | } 85 | ], 86 | "info" : { 87 | "author" : "xcode", 88 | "version" : 1 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Dark@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Dark@2x.png -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Light@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Light@2x.png -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 White - Tinted@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 White - Tinted@2x.png -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@1x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@1x.png -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@2x.png -------------------------------------------------------------------------------- /app/FastVLM App/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /app/FastVLM App/FastVLM.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.camera 10 | 11 | com.apple.security.files.user-selected.read-only 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /app/FastVLM App/FastVLMApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 
4 | // 5 | 6 | import SwiftUI 7 | 8 | @main 9 | struct FastVLMApp: App { 10 | var body: some Scene { 11 | WindowGroup { 12 | ContentView() 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /app/FastVLM App/FastVLMModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import CoreImage 7 | import FastVLM 8 | import Foundation 9 | import MLX 10 | import MLXLMCommon 11 | import MLXRandom 12 | import MLXVLM 13 | 14 | @Observable 15 | @MainActor 16 | class FastVLMModel { 17 | 18 | public var running = false 19 | public var modelInfo = "" 20 | public var output = "" 21 | public var promptTime: String = "" 22 | 23 | enum LoadState { 24 | case idle 25 | case loaded(ModelContainer) 26 | } 27 | 28 | private let modelConfiguration = FastVLM.modelConfiguration 29 | 30 | /// parameters controlling the output 31 | let generateParameters = GenerateParameters(temperature: 0.0) 32 | let maxTokens = 240 33 | 34 | /// update the display every N tokens -- 4 looks like it updates continuously 35 | /// and is low overhead. observed ~15% reduction in tokens/s when updating 36 | /// on every token 37 | let displayEveryNTokens = 4 38 | 39 | private var loadState = LoadState.idle 40 | private var currentTask: Task? 41 | 42 | enum EvaluationState: String, CaseIterable { 43 | case idle = "Idle" 44 | case processingPrompt = "Processing Prompt" 45 | case generatingResponse = "Generating Response" 46 | } 47 | 48 | public var evaluationState = EvaluationState.idle 49 | 50 | public init() { 51 | FastVLM.register(modelFactory: VLMModelFactory.shared) 52 | } 53 | 54 | private func _load() async throws -> ModelContainer { 55 | switch loadState { 56 | case .idle: 57 | // limit the buffer cache 58 | MLX.GPU.set(cacheLimit: 20 * 1024 * 1024) 59 | 60 | let modelContainer = try await VLMModelFactory.shared.loadContainer( 61 | configuration: modelConfiguration 62 | ) { 63 | [modelConfiguration] progress in 64 | Task { @MainActor in 65 | self.modelInfo = 66 | "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%" 67 | } 68 | } 69 | self.modelInfo = "Loaded" 70 | loadState = .loaded(modelContainer) 71 | return modelContainer 72 | 73 | case .loaded(let modelContainer): 74 | return modelContainer 75 | } 76 | } 77 | 78 | public func load() async { 79 | do { 80 | _ = try await _load() 81 | } catch { 82 | self.modelInfo = "Error loading model: \(error)" 83 | } 84 | } 85 | 86 | public func generate(_ userInput: UserInput) async -> Task { 87 | if let currentTask, running { 88 | return currentTask 89 | } 90 | 91 | running = true 92 | 93 | // Cancel any existing task 94 | currentTask?.cancel() 95 | 96 | // Create new task and store reference 97 | let task = Task { 98 | do { 99 | let modelContainer = try await _load() 100 | 101 | // each time you generate you will get something new 102 | MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000)) 103 | 104 | // Check if task was cancelled 105 | if Task.isCancelled { return } 106 | 107 | let result = try await modelContainer.perform { context in 108 | // Measure the time it takes to prepare the input 109 | 110 | Task { @MainActor in 111 | evaluationState = .processingPrompt 112 | } 113 | 114 | let llmStart = Date() 115 | let input = try await context.processor.prepare(input: userInput) 116 | 117 | var seenFirstToken = false 118 | 119 | 
// FastVLM generates the output 120 | let result = try MLXLMCommon.generate( 121 | input: input, parameters: generateParameters, context: context 122 | ) { tokens in 123 | // Check if task was cancelled 124 | if Task.isCancelled { 125 | return .stop 126 | } 127 | 128 | if !seenFirstToken { 129 | seenFirstToken = true 130 | 131 | // produced first token, update the time to first token, 132 | // the processing state and start displaying the text 133 | let llmDuration = Date().timeIntervalSince(llmStart) 134 | let text = context.tokenizer.decode(tokens: tokens) 135 | Task { @MainActor in 136 | evaluationState = .generatingResponse 137 | self.output = text 138 | self.promptTime = "\(Int(llmDuration * 1000)) ms" 139 | } 140 | } 141 | 142 | // Show the text in the view as it generates 143 | if tokens.count % displayEveryNTokens == 0 { 144 | let text = context.tokenizer.decode(tokens: tokens) 145 | Task { @MainActor in 146 | self.output = text 147 | } 148 | } 149 | 150 | if tokens.count >= maxTokens { 151 | return .stop 152 | } else { 153 | return .more 154 | } 155 | } 156 | 157 | // Return the duration of the LLM and the result 158 | return result 159 | } 160 | 161 | // Check if task was cancelled before updating UI 162 | if !Task.isCancelled { 163 | self.output = result.output 164 | } 165 | 166 | } catch { 167 | if !Task.isCancelled { 168 | output = "Failed: \(error)" 169 | } 170 | } 171 | 172 | if evaluationState == .generatingResponse { 173 | evaluationState = .idle 174 | } 175 | 176 | running = false 177 | } 178 | 179 | currentTask = task 180 | return task 181 | } 182 | 183 | public func cancel() { 184 | currentTask?.cancel() 185 | currentTask = nil 186 | running = false 187 | output = "" 188 | promptTime = "" 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /app/FastVLM App/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /app/FastVLM App/InfoView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | import SwiftUI 8 | 9 | struct InfoView: View { 10 | @Environment(\.dismiss) var dismiss 11 | 12 | let paragraph1 = "**FastVLM¹** is a new family of Vision-Language models that makes use of **FastViTHD**, a hierarchical hybrid vision encoder that produces small number of high quality tokens at low latencies, resulting in significantly faster time-to-first-token (TTFT)." 13 | let paragraph2 = "This app showcases the **FastVLM** model in action, allowing users to freely customize the prompt. FastVLM utilizes Qwen2-Instruct LLMs without additional safety tuning, so please exercise caution when modifying the prompt." 14 | let footer = "1. **FastVLM: Efficient Vision Encoding for Vision Language Models.** (CVPR 2025) Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari" 15 | 16 | var body: some View { 17 | NavigationStack { 18 | VStack(alignment: .leading, spacing: 20.0) { 19 | // I'm not going to lie, this doesn't make sense... 
20 | // Wrapping `String`s with `.init()` turns them into `LocalizedStringKey`s 21 | // which gives us all of the fun Markdown formatting while retaining the 22 | // ability to use `String` variables. ¯\_(ツ)_/¯ 23 | Text("\(.init(paragraph1))\n\n\(.init(paragraph2))\n\n") 24 | .font(.body) 25 | 26 | Spacer() 27 | 28 | Text(.init(footer)) 29 | .font(.caption) 30 | .foregroundStyle(.secondary) 31 | } 32 | .padding() 33 | .frame(maxWidth: .infinity, maxHeight: .infinity, alignment: .top) 34 | .textSelection(.enabled) 35 | .navigationTitle("Information") 36 | #if os(iOS) 37 | .navigationBarTitleDisplayMode(.inline) 38 | #endif 39 | .toolbar { 40 | #if os(iOS) 41 | ToolbarItem(placement: .navigationBarLeading) { 42 | Button { 43 | dismiss() 44 | } label: { 45 | Image(systemName: "xmark.circle") 46 | .resizable() 47 | .frame(width: 25, height: 25) 48 | .foregroundStyle(.secondary) 49 | } 50 | .buttonStyle(.plain) 51 | } 52 | #elseif os(macOS) 53 | ToolbarItem(placement: .cancellationAction) { 54 | Button("Done") { 55 | dismiss() 56 | } 57 | .buttonStyle(.bordered) 58 | } 59 | #endif 60 | } 61 | } 62 | } 63 | } 64 | 65 | #Preview { 66 | InfoView() 67 | } 68 | -------------------------------------------------------------------------------- /app/FastVLM App/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /app/FastVLM.xcodeproj/xcshareddata/xcschemes/FastVLM App.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 45 | 51 | 52 | 53 | 54 | 60 | 62 | 68 | 69 | 70 | 71 | 73 | 74 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /app/FastVLM/FastVLM.h: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | #ifndef FastVLM_h 7 | #define FastVLM_h 8 | 9 | 10 | #endif /* FastVLM_h */ 11 | -------------------------------------------------------------------------------- /app/FastVLM/MediaProcessingExtensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Accelerate 7 | import CoreImage 8 | import MLX 9 | import MLXLMCommon 10 | import MLXVLM 11 | 12 | /// Additions to MediaProcessing -- not currently present in mlx-libraries 13 | enum MediaProcessingExtensions { 14 | 15 | // this function is not exported in current mlx-swift-examples -- local copy until it is exposed 16 | // properly 17 | public static func apply(_ image: CIImage, processing: UserInput.Processing?) 
-> CIImage { 18 | var image = image 19 | 20 | if let resize = processing?.resize { 21 | let scale = MediaProcessing.bestFitScale(image.extent.size, in: resize) 22 | image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale)) 23 | } 24 | 25 | return image 26 | } 27 | 28 | public static func rectSmallerOrEqual(_ extent: CGRect, size: CGSize) -> Bool { 29 | return extent.width <= size.width && extent.height <= size.height 30 | } 31 | 32 | public static func centerCrop(_ extent: CGRect, size: CGSize) -> CGRect { 33 | let targetWidth = min(extent.width, size.width) 34 | let targetHeight = min(extent.height, size.height) 35 | 36 | return CGRect( 37 | x: (extent.maxX - targetWidth) / 2, 38 | y: (extent.maxY - targetHeight) / 2, 39 | width: targetWidth, height: targetHeight 40 | ) 41 | } 42 | 43 | public static func centerCrop(_ image: CIImage, size: CGSize) -> CIImage { 44 | let extent = image.extent 45 | if rectSmallerOrEqual(extent, size: size) { 46 | return image 47 | } 48 | 49 | let crop = centerCrop(extent, size: size) 50 | return 51 | image 52 | .cropped(to: crop) 53 | .transformed(by: CGAffineTransform(translationX: -crop.minX, y: -crop.minY)) 54 | } 55 | 56 | public static func fitIn(_ size: CGSize, shortestEdge: Int) -> CGSize { 57 | let floatShortestEdge = CGFloat(shortestEdge) 58 | 59 | let (short, long) = 60 | size.width <= size.height ? (size.width, size.height) : (size.height, size.width) 61 | let newShort = floatShortestEdge 62 | let newLong = floatShortestEdge * long / short 63 | 64 | return size.width <= size.height 65 | ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort) 66 | } 67 | 68 | public static func fitIn(_ size: CGSize, longestEdge: Int) -> CGSize { 69 | let floatLongestEdge = CGFloat(longestEdge) 70 | 71 | var (newShort, newLong) = 72 | size.width <= size.height ? (size.width, size.height) : (size.height, size.width) 73 | 74 | if newLong > floatLongestEdge { 75 | newLong = floatLongestEdge 76 | newShort = floatLongestEdge * newShort / newLong 77 | } 78 | 79 | return size.width <= size.height 80 | ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort) 81 | } 82 | 83 | // version of function from https://github.com/ml-explore/mlx-swift-examples/pull/222 84 | public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage { 85 | // Create a bicubic scale filter 86 | 87 | let yScale = size.height / image.extent.height 88 | let xScale = size.width / image.extent.width 89 | 90 | let filter = CIFilter.bicubicScaleTransform() 91 | filter.inputImage = image 92 | filter.scale = Float(yScale) 93 | filter.aspectRatio = Float(xScale / yScale) 94 | let scaledImage = filter.outputImage! 95 | 96 | // Create a rect with the exact dimensions we want 97 | let exactRect = CGRect( 98 | x: 0, 99 | y: 0, 100 | width: size.width, 101 | height: size.height 102 | ) 103 | // Crop to ensure exact dimensions 104 | return scaledImage.cropped(to: exactRect) 105 | } 106 | 107 | static let context = CIContext() 108 | 109 | /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`. 110 | /// 111 | /// This physically moves the channels into a planar configuration -- this is 112 | /// required for feeding into the CoreML model and is faster to use 113 | /// dedicated functions than transforming into contiguous memory 114 | /// on readout. 115 | static public func asPlanarMLXArray(_ image: CIImage, colorSpace: CGColorSpace? 
= nil) 116 | -> MLXArray 117 | { 118 | let size = image.extent.size 119 | let w = Int(size.width.rounded()) 120 | let h = Int(size.height.rounded()) 121 | 122 | // probably not strictly necessary, but this is what happens in 123 | // e.g. image_processing_siglip in transformers (float32) 124 | let format = CIFormat.RGBAf 125 | let componentsPerPixel = 4 126 | let bytesPerComponent: Int = MemoryLayout.size 127 | let bytesPerPixel = componentsPerPixel * bytesPerComponent 128 | let bytesPerRow = w * bytesPerPixel 129 | 130 | var data = Data(count: w * h * bytesPerPixel) 131 | var planarData = Data(count: 3 * w * h * bytesPerComponent) 132 | data.withUnsafeMutableBytes { ptr in 133 | context.render( 134 | image, toBitmap: ptr.baseAddress!, rowBytes: bytesPerRow, bounds: image.extent, 135 | format: format, colorSpace: colorSpace) 136 | context.clearCaches() 137 | 138 | let vh = vImagePixelCount(h) 139 | let vw = vImagePixelCount(w) 140 | 141 | // convert from RGBAf -> RGBf in place 142 | let rgbBytesPerRow = w * 3 * bytesPerComponent 143 | var rgbaSrc = vImage_Buffer( 144 | data: ptr.baseAddress!, height: vh, width: vw, rowBytes: bytesPerRow) 145 | var rgbDest = vImage_Buffer( 146 | data: ptr.baseAddress!, height: vh, width: vw, rowBytes: rgbBytesPerRow) 147 | 148 | vImageConvert_RGBAFFFFtoRGBFFF(&rgbaSrc, &rgbDest, vImage_Flags(kvImageNoFlags)) 149 | 150 | // and convert to planar data in a second buffer 151 | planarData.withUnsafeMutableBytes { planarPtr in 152 | let planeBytesPerRow = w * bytesPerComponent 153 | 154 | var rDest = vImage_Buffer( 155 | data: planarPtr.baseAddress!.advanced(by: 0 * planeBytesPerRow * h), height: vh, 156 | width: vw, rowBytes: planeBytesPerRow) 157 | var gDest = vImage_Buffer( 158 | data: planarPtr.baseAddress!.advanced(by: 1 * planeBytesPerRow * h), height: vh, 159 | width: vw, rowBytes: planeBytesPerRow) 160 | var bDest = vImage_Buffer( 161 | data: planarPtr.baseAddress!.advanced(by: 2 * planeBytesPerRow * h), height: vh, 162 | width: vw, rowBytes: planeBytesPerRow) 163 | 164 | vImageConvert_RGBFFFtoPlanarF( 165 | &rgbDest, &rDest, &gDest, &bDest, vImage_Flags(kvImageNoFlags)) 166 | } 167 | } 168 | 169 | return MLXArray(planarData, [1, 3, h, w], type: Float32.self) 170 | } 171 | 172 | } 173 | -------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | # FastVLM 2 | 3 | Demonstrates the performance of **FastVLM** models for on-device, visual question answering. 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
![FastVLM - Counting](../docs/fastvlm-counting.gif) ![FastVLM - Handwriting](../docs/fastvlm-handwriting.gif) ![FastVLM - Emoji](../docs/fastvlm-emoji.gif)
12 | 13 | ## Features 14 | 15 | - FastVLM runs on iOS (18.2+) and macOS (15.2+). 16 | - View Time-To-First-Token (TTFT) with every inference. 17 | - All predictions are processed privately and securely using on-device models. 18 | 19 | ### Flexible Prompting 20 | 21 | ![Flexible prompting](../docs/fastvlm-flexible_prompts.png) 22 | 23 | The app includes a set of built-in prompts to help you get started quickly. Tap the **Prompts** button in the top-right corner to explore them. Selecting a prompt will immediately update the active input. To create new prompts or edit existing ones, choose **Customize…** from the **Prompts** menu. 24 | 25 | ## Pretrained Model Options 26 | 27 | There are 3 pretrained sizes of FastVLM to choose from: 28 | 29 | - **FastVLM 0.5B**: Small and fast - great for mobile devices where speed matters. 30 | - **FastVLM 1.5B**: Well balanced - great for larger devices where both speed and accuracy matter. 31 | - **FastVLM 7B**: Fast and accurate - ideal for situations where accuracy matters more than speed. 32 | 33 | To download any of the FastVLM models listed above, use the [get_pretrained_mlx_model.sh](get_pretrained_mlx_model.sh) script. The script downloads the model from the web and places it in the appropriate location. Once a model has been downloaded using the steps below, no additional steps are needed to build the app in Xcode. 34 | 35 | To explore how the other models work for your use case, simply re-run the `get_pretrained_mlx_model.sh` script with the new model selected, follow the prompts, and rebuild your app in Xcode. 36 | 37 | ### Download Instructions 38 | 39 | 1. Make the script executable 40 | 41 | ```shell 42 | chmod +x app/get_pretrained_mlx_model.sh 43 | ``` 44 | 45 | 2. Download FastVLM 46 | 47 | ```shell 48 | app/get_pretrained_mlx_model.sh --model 0.5b --dest app/FastVLM/model 49 | ``` 50 | 51 | 3. Open the app in Xcode, Build, and Run. 52 | 53 | ### Custom Model 54 | 55 | In addition to the pretrained sizes of FastVLM, you can further quantize or fine-tune FastVLM to best fit your needs. To learn more, check out our documentation on how to [`export the model`](../model_export#export-vlm). 56 | Please clear the existing model in `app/FastVLM/model` before downloading or copying a new model. 57 | -------------------------------------------------------------------------------- /app/Video/CameraController.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import AVFoundation 7 | import CoreImage 8 | 9 | #if os(iOS) 10 | import UIKit 11 | #endif 12 | 13 | @Observable 14 | public class CameraController: NSObject { 15 | 16 | private var framesContinuation: AsyncStream<CMSampleBuffer>.Continuation? 17 | 18 | public var backCamera = true { 19 | didSet { 20 | stop() 21 | start() 22 | } 23 | } 24 | 25 | public var devices = [AVCaptureDevice]() 26 | 27 | public var device: AVCaptureDevice = AVCaptureDevice.default(for: .video)! { 28 | didSet { 29 | stop() 30 | start() 31 | } 32 | } 33 | 34 | private var permissionGranted = true 35 | private var captureSession: AVCaptureSession? 36 | private let sessionQueue = DispatchQueue(label: "sessionQueue") 37 | @objc dynamic private var rotationCoordinator : AVCaptureDevice.RotationCoordinator? 38 | private var rotationObservation: NSKeyValueObservation?
39 | 40 | public func attach(continuation: AsyncStream.Continuation) { 41 | sessionQueue.async { 42 | self.framesContinuation = continuation 43 | } 44 | } 45 | 46 | public func detatch() { 47 | sessionQueue.async { 48 | self.framesContinuation = nil 49 | } 50 | } 51 | 52 | public func stop() { 53 | sessionQueue.sync { [self] in 54 | captureSession?.stopRunning() 55 | captureSession = nil 56 | } 57 | 58 | } 59 | 60 | public func start() { 61 | sessionQueue.async { [self] in 62 | let captureSession = AVCaptureSession() 63 | self.captureSession = captureSession 64 | 65 | self.checkPermission() 66 | self.setupCaptureSession(position: backCamera ? .back : .front) 67 | captureSession.startRunning() 68 | } 69 | } 70 | 71 | #if os(iOS) 72 | private func setOrientation(_ orientation: UIDeviceOrientation) { 73 | guard let captureSession else { return } 74 | 75 | let angle: Double? 76 | switch orientation { 77 | case .unknown, .faceDown: 78 | angle = nil 79 | case .portrait, .faceUp: 80 | angle = 90 81 | case .portraitUpsideDown: 82 | angle = 270 83 | case .landscapeLeft: 84 | angle = 0 85 | case .landscapeRight: 86 | angle = 180 87 | @unknown default: 88 | angle = nil 89 | } 90 | 91 | if let angle { 92 | for output in captureSession.outputs { 93 | output.connection(with: .video)?.videoRotationAngle = angle 94 | } 95 | } 96 | } 97 | 98 | private func updateRotation(rotation : CGFloat) { 99 | guard let captureSession else { return } 100 | for output in captureSession.outputs { 101 | output.connection(with: .video)?.videoRotationAngle = rotation 102 | } 103 | } 104 | #endif 105 | 106 | func checkPermission() { 107 | switch AVCaptureDevice.authorizationStatus(for: .video) { 108 | case .authorized: 109 | // The user has previously granted access to the camera. 110 | self.permissionGranted = true 111 | 112 | case .notDetermined: 113 | // The user has not yet been asked for camera access. 114 | self.requestPermission() 115 | 116 | // Combine the two other cases into the default case 117 | default: 118 | self.permissionGranted = false 119 | } 120 | } 121 | 122 | func requestPermission() { 123 | // Strong reference not a problem here but might become one in the future. 124 | AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in 125 | self.permissionGranted = granted 126 | } 127 | } 128 | 129 | func setupCaptureSession(position: AVCaptureDevice.Position) { 130 | guard let captureSession else { return } 131 | 132 | let videoOutput = AVCaptureVideoDataOutput() 133 | 134 | guard permissionGranted else { 135 | print("No permission for camera") 136 | return 137 | } 138 | 139 | let deviceTypes: [AVCaptureDevice.DeviceType] 140 | #if os(iOS) 141 | deviceTypes = [.builtInDualCamera, .builtInWideAngleCamera] 142 | #else 143 | deviceTypes = [.external, .continuityCamera, .builtInWideAngleCamera] 144 | #endif 145 | 146 | let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession( 147 | deviceTypes: deviceTypes, 148 | mediaType: .video, 149 | position: position) 150 | 151 | let videoDevice: AVCaptureDevice? 152 | if videoDeviceDiscoverySession.devices.contains(self.device) { 153 | videoDevice = self.device 154 | } else { 155 | videoDevice = videoDeviceDiscoverySession.devices.first 156 | } 157 | 158 | if devices.isEmpty { 159 | self.devices = videoDeviceDiscoverySession.devices 160 | } 161 | 162 | guard 163 | let videoDevice 164 | else { 165 | print("Unable to find video device") 166 | return 167 | } 168 | guard let videoDeviceInput = try? 
AVCaptureDeviceInput(device: videoDevice) else { 169 | print("Unable to create AVCaptureDeviceInput") 170 | return 171 | } 172 | guard captureSession.canAddInput(videoDeviceInput) else { 173 | print("Unable to add input") 174 | return 175 | } 176 | captureSession.addInput(videoDeviceInput) 177 | 178 | videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue")) 179 | captureSession.addOutput(videoOutput) 180 | captureSession.sessionPreset = AVCaptureSession.Preset.hd1920x1080 181 | 182 | #if os(iOS) 183 | rotationCoordinator = AVCaptureDevice.RotationCoordinator(device: videoDevice, previewLayer: nil) 184 | rotationObservation = observe(\.rotationCoordinator!.videoRotationAngleForHorizonLevelCapture, options: [.initial, .new]) { [weak self] _, change in 185 | if let nv = change.newValue { 186 | self?.updateRotation(rotation: nv) 187 | } 188 | } 189 | #endif 190 | } 191 | } 192 | 193 | extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate { 194 | public func captureOutput( 195 | _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, 196 | from connection: AVCaptureConnection 197 | ) { 198 | if sampleBuffer.isValid && sampleBuffer.imageBuffer != nil { 199 | framesContinuation?.yield(sampleBuffer) 200 | } 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /app/Video/CameraControlsView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import AVFoundation 7 | import SwiftUI 8 | 9 | public struct CameraControlsView: View { 10 | 11 | @Binding public var backCamera: Bool 12 | @Binding public var device: AVCaptureDevice 13 | @Binding public var devices: [AVCaptureDevice] 14 | 15 | public init( 16 | backCamera: Binding, 17 | device: Binding, 18 | devices: Binding<[AVCaptureDevice]> 19 | ) { 20 | self._backCamera = backCamera 21 | self._device = device 22 | self._devices = devices 23 | } 24 | 25 | public var body: some View { 26 | Button { 27 | backCamera.toggle() 28 | } label: { 29 | RoundedRectangle(cornerRadius: 8.0) 30 | .fill(.regularMaterial) 31 | .frame(width: 32.0, height: 32.0) 32 | .overlay(alignment: .center) { 33 | // Switch cameras image 34 | Image(systemName: "arrow.triangle.2.circlepath.camera.fill") 35 | .foregroundStyle(.primary) 36 | .padding(6.0) 37 | } 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /app/Video/CameraType.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import Foundation 7 | 8 | public enum CameraType: String, CaseIterable { 9 | case continuous 10 | case single 11 | } 12 | -------------------------------------------------------------------------------- /app/Video/Video.h: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | #import 7 | 8 | //! Project version number for Video. 9 | FOUNDATION_EXPORT double VideoVersionNumber; 10 | 11 | //! Project version string for Video. 
12 | FOUNDATION_EXPORT const unsigned char VideoVersionString[]; 13 | -------------------------------------------------------------------------------- /app/Video/VideoFrameView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // For licensing see accompanying LICENSE file. 3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | // 5 | 6 | import AVFoundation 7 | import CoreImage 8 | import Foundation 9 | import SwiftUI 10 | 11 | /// Displays a stream of video frames 12 | public struct VideoFrameView: View { 13 | @Environment(\.colorScheme) private var colorScheme 14 | 15 | public let frames: AsyncStream<CVImageBuffer> 16 | public let cameraType: CameraType 17 | public let action: ((CVImageBuffer) -> Void)? 18 | 19 | @State private var hold: Bool = false 20 | @State private var videoFrame: CVImageBuffer? 21 | 22 | private var backgroundColor: Color { 23 | #if os(iOS) 24 | return Color(.secondarySystemBackground) 25 | #elseif os(macOS) 26 | return Color(.secondarySystemFill) 27 | #else 28 | // When in doubt, use these values that I captured to match iOS' secondarySystemBackground 29 | if colorScheme == .dark { 30 | return Color(red: 0.11, green: 0.11, blue: 0.12) 31 | } else { 32 | return Color(red: 0.95, green: 0.95, blue: 0.97) 33 | } 34 | #endif 35 | } 36 | 37 | public init( 38 | frames: AsyncStream<CVImageBuffer>, 39 | cameraType: CameraType, 40 | action: ((CVImageBuffer) -> Void)? 41 | ) { 42 | self.frames = frames 43 | self.cameraType = cameraType 44 | self.action = action 45 | } 46 | 47 | public var body: some View { 48 | Group { 49 | if let videoFrame { 50 | _ImageView(image: videoFrame) 51 | .overlay(alignment: .bottom) { 52 | if cameraType == .single { 53 | Button { 54 | tap() 55 | } label: { 56 | if hold { 57 | Label("Resume", systemImage: "play.fill") 58 | } else { 59 | Label("Capture Photo", systemImage: "camera.fill") 60 | } 61 | } 62 | .clipShape(.capsule) 63 | .buttonStyle(.borderedProminent) 64 | .tint(hold ?
.gray : .accentColor) 65 | .foregroundColor(.white) 66 | .padding() 67 | } 68 | } 69 | } else { 70 | // spinner before the camera comes up 71 | ProgressView() 72 | .controlSize(.large) 73 | } 74 | } 75 | // This ensures that we take up the full 4/3 aspect ratio 76 | // even if we don't have an image to display 77 | .frame(maxWidth: .infinity, maxHeight: .infinity) 78 | .background(backgroundColor) 79 | .clipShape(RoundedRectangle(cornerRadius: 10.0)) 80 | .task { 81 | // feed frames to the _ImageView 82 | if Task.isCancelled { 83 | return 84 | } 85 | for await frame in frames { 86 | if !hold { 87 | videoFrame = frame 88 | } 89 | } 90 | } 91 | .onChange(of: cameraType) { _, newType in 92 | // No matter what, when the user switches to .continuous, 93 | // we need to continue showing updated frames 94 | if newType == .continuous { 95 | hold = false 96 | } 97 | } 98 | } 99 | 100 | private func tap() { 101 | if hold { 102 | // resume 103 | hold = false 104 | } else if let videoFrame { 105 | hold = true 106 | if let action { 107 | action(videoFrame) 108 | } 109 | } 110 | } 111 | } 112 | 113 | #if os(iOS) 114 | /// Internal view to display a CVImageBuffer 115 | private struct _ImageView: UIViewRepresentable { 116 | 117 | let image: Any 118 | var gravity = CALayerContentsGravity.resizeAspectFill 119 | 120 | func makeUIView(context: Context) -> UIView { 121 | let view = UIView() 122 | view.layer.contentsGravity = gravity 123 | return view 124 | } 125 | 126 | func updateUIView(_ uiView: UIView, context: Context) { 127 | uiView.layer.contents = image 128 | } 129 | } 130 | #else 131 | private struct _ImageView: NSViewRepresentable { 132 | 133 | let image: Any 134 | var gravity = CALayerContentsGravity.resizeAspectFill 135 | 136 | func makeNSView(context: Context) -> NSView { 137 | let view = NSView() 138 | view.wantsLayer = true 139 | view.layer?.contentsGravity = gravity 140 | return view 141 | } 142 | 143 | func updateNSView(_ uiView: NSView, context: Context) { 144 | uiView.layer?.contents = image 145 | } 146 | } 147 | 148 | #endif 149 | -------------------------------------------------------------------------------- /app/get_pretrained_mlx_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # For licensing see accompanying LICENSE_MODEL file. 4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
5 | # 6 | set -e 7 | 8 | # Help function 9 | show_help() { 10 | local is_error=${1:-true} # Default to error mode if no argument provided 11 | 12 | echo "Usage: $0 --model <size> --dest <path>" 13 | echo 14 | echo "Required arguments:" 15 | echo " --model <size> Size of the model to download" 16 | echo " --dest <path> Directory where the model will be downloaded" 17 | echo 18 | echo "Available model sizes:" 19 | echo " 0.5b - 0.5B parameter model (FP16)" 20 | echo " 1.5b - 1.5B parameter model (INT8)" 21 | echo " 7b - 7B parameter model (INT4)" 22 | echo 23 | echo "Options:" 24 | echo " --help Show help message" 25 | 26 | # Exit with success (0) for help flag, error (1) for usage errors 27 | if [ "$is_error" = "false" ]; then 28 | exit 0 29 | else 30 | exit 1 31 | fi 32 | } 33 | 34 | # Parse command line arguments 35 | while [[ "$#" -gt 0 ]]; do 36 | case $1 in 37 | --model) model_size="$2"; shift ;; 38 | --dest) dest_dir="$2"; shift ;; 39 | --help) show_help false ;; # Explicit help request 40 | *) echo -e "Unknown parameter: $1\n"; show_help true ;; # Error case 41 | esac 42 | shift 43 | done 44 | 45 | # Validate required parameters 46 | if [ -z "$model_size" ]; then 47 | echo -e "Error: --model parameter is required\n" 48 | show_help true 49 | fi 50 | 51 | if [ -z "$dest_dir" ]; then 52 | echo -e "Error: --dest parameter is required\n" 53 | show_help true 54 | fi 55 | 56 | # Map model size to full model name 57 | case "$model_size" in 58 | "0.5b") model="llava-fastvithd_0.5b_stage3_llm.fp16" ;; 59 | "1.5b") model="llava-fastvithd_1.5b_stage3_llm.int8" ;; 60 | "7b") model="llava-fastvithd_7b_stage3_llm.int4" ;; 61 | *) 62 | echo -e "Error: Invalid model size '$model_size'\n" 63 | show_help true 64 | ;; 65 | esac 66 | 67 | cleanup() { 68 | rm -rf "$tmp_dir" 69 | } 70 | 71 | download_model() { 72 | # Download directory 73 | tmp_dir=$(mktemp -d) 74 | 75 | # Model paths 76 | base_url="https://ml-site.cdn-apple.com/datasets/fastvlm" 77 | 78 | # Create destination directory if it doesn't exist 79 | if [ ! -d "$dest_dir" ]; then 80 | echo "Creating destination directory: $dest_dir" 81 | mkdir -p "$dest_dir" 82 | elif [ "$(ls -A "$dest_dir")" ]; then 83 | echo -e "Destination directory '$dest_dir' exists and is not empty.\n" 84 | read -p "Do you want to clear it and continue? [y/N]: " confirm 85 | if [[ ! "$confirm" =~ ^[Yy]$ ]]; then 86 | echo -e "\nStopping." 87 | exit 1 88 | fi 89 | echo -e "\nClearing existing contents in '$dest_dir'" 90 | rm -rf "${dest_dir:?}"/* 91 | fi 92 | 93 | # Create temp variables 94 | tmp_zip_file="${tmp_dir}/${model}.zip" 95 | tmp_extract_dir="${tmp_dir}/${model}" 96 | 97 | # Create temp extract directory 98 | mkdir -p "$tmp_extract_dir" 99 | 100 | # Download model 101 | echo -e "\nDownloading '${model}' model ...\n" 102 | wget -q --progress=bar:noscroll --show-progress -O "$tmp_zip_file" "$base_url/$model.zip" 103 | 104 | # Unzip model 105 | echo -e "\nUnzipping model..." 106 | unzip -q "$tmp_zip_file" -d "$tmp_extract_dir" 107 | 108 | # Copy model files to destination directory 109 | echo -e "\nCopying model files to destination directory..." 110 | cp -r "$tmp_extract_dir/$model"/* "$dest_dir" 111 | 112 | # Verify destination directory exists and is not empty 113 | if [ ! -d "$dest_dir" ] || [ -z "$(ls -A "$dest_dir")" ]; then 114 | echo -e "\nModel extraction failed. Destination directory '$dest_dir' is missing or empty."
115 | exit 1 116 | fi 117 | 118 | echo -e "\nModel downloaded and extracted to '$dest_dir'" 119 | } 120 | 121 | # Cleanup download directory on exit 122 | trap cleanup EXIT INT TERM 123 | 124 | # Download models 125 | download_model -------------------------------------------------------------------------------- /docs/acc_vs_latency_qwen-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/acc_vs_latency_qwen-2.png -------------------------------------------------------------------------------- /docs/fastvlm-counting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-counting.gif -------------------------------------------------------------------------------- /docs/fastvlm-emoji.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-emoji.gif -------------------------------------------------------------------------------- /docs/fastvlm-flexible_prompts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-flexible_prompts.png -------------------------------------------------------------------------------- /docs/fastvlm-handwriting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-handwriting.gif -------------------------------------------------------------------------------- /get_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # For licensing see accompanying LICENSE_MODEL file. 4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
5 | # 6 | 7 | mkdir -p checkpoints 8 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip -P checkpoints 9 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip -P checkpoints 10 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip -P checkpoints 11 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip -P checkpoints 12 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip -P checkpoints 13 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip -P checkpoints 14 | 15 | # Extract models 16 | cd checkpoints 17 | unzip -qq llava-fastvithd_0.5b_stage2.zip 18 | unzip -qq llava-fastvithd_0.5b_stage3.zip 19 | unzip -qq llava-fastvithd_1.5b_stage2.zip 20 | unzip -qq llava-fastvithd_1.5b_stage3.zip 21 | unzip -qq llava-fastvithd_7b_stage2.zip 22 | unzip -qq llava-fastvithd_7b_stage3.zip 23 | 24 | # Clean up 25 | rm llava-fastvithd_0.5b_stage2.zip 26 | rm llava-fastvithd_0.5b_stage3.zip 27 | rm llava-fastvithd_1.5b_stage2.zip 28 | rm llava-fastvithd_1.5b_stage3.zip 29 | rm llava-fastvithd_7b_stage2.zip 30 | rm llava-fastvithd_7b_stage3.zip 31 | cd - 32 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from PIL import Image 3 | PIL.Image.MAX_IMAGE_PIXELS=500000000 4 | from io import BytesIO 5 | import base64 6 | import torch 7 | import math 8 | import ast 9 | 10 | from transformers import StoppingCriteria 11 | from llava.constants import IMAGE_TOKEN_INDEX 12 | 13 | 14 | def select_best_resolution(original_size, possible_resolutions): 15 | """ 16 | Selects the best resolution from a list of possible resolutions based on the original size. 17 | 18 | Args: 19 | original_size (tuple): The original size of the image in the format (width, height). 20 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 21 | 22 | Returns: 23 | tuple: The best fit resolution in the format (width, height).
24 | """ 25 | original_width, original_height = original_size 26 | best_fit = None 27 | max_effective_resolution = 0 28 | min_wasted_resolution = float('inf') 29 | 30 | for width, height in possible_resolutions: 31 | scale = min(width / original_width, height / original_height) 32 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 33 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 34 | wasted_resolution = (width * height) - effective_resolution 35 | 36 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 37 | max_effective_resolution = effective_resolution 38 | min_wasted_resolution = wasted_resolution 39 | best_fit = (width, height) 40 | 41 | return best_fit 42 | 43 | 44 | def resize_and_pad_image(image, target_resolution): 45 | """ 46 | Resize and pad an image to a target resolution while maintaining aspect ratio. 47 | 48 | Args: 49 | image (PIL.Image.Image): The input image. 50 | target_resolution (tuple): The target resolution (width, height) of the image. 51 | 52 | Returns: 53 | PIL.Image.Image: The resized and padded image. 54 | """ 55 | original_width, original_height = image.size 56 | target_width, target_height = target_resolution 57 | 58 | scale_w = target_width / original_width 59 | scale_h = target_height / original_height 60 | 61 | if scale_w < scale_h: 62 | new_width = target_width 63 | new_height = min(math.ceil(original_height * scale_w), target_height) 64 | else: 65 | new_height = target_height 66 | new_width = min(math.ceil(original_width * scale_h), target_width) 67 | 68 | # Resize the image 69 | resized_image = image.resize((new_width, new_height)) 70 | 71 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 72 | paste_x = (target_width - new_width) // 2 73 | paste_y = (target_height - new_height) // 2 74 | new_image.paste(resized_image, (paste_x, paste_y)) 75 | 76 | return new_image 77 | 78 | 79 | def divide_to_patches(image, patch_size): 80 | """ 81 | Divides an image into patches of a specified size. 82 | 83 | Args: 84 | image (PIL.Image.Image): The input image. 85 | patch_size (int): The size of each patch. 86 | 87 | Returns: 88 | list: A list of PIL.Image.Image objects representing the patches. 89 | """ 90 | patches = [] 91 | width, height = image.size 92 | for i in range(0, height, patch_size): 93 | for j in range(0, width, patch_size): 94 | box = (j, i, j + patch_size, i + patch_size) 95 | patch = image.crop(box) 96 | patches.append(patch) 97 | 98 | return patches 99 | 100 | 101 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 102 | """ 103 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 104 | 105 | Args: 106 | image_size (tuple): The size of the input image in the format (width, height). 107 | grid_pinpoints (str): A string representation of a list of possible resolutions. 108 | patch_size (int): The size of each image patch. 109 | 110 | Returns: 111 | tuple: The shape of the image patch grid in the format (width, height). 
112 | """ 113 | if type(grid_pinpoints) is list: 114 | possible_resolutions = grid_pinpoints 115 | else: 116 | possible_resolutions = ast.literal_eval(grid_pinpoints) 117 | width, height = select_best_resolution(image_size, possible_resolutions) 118 | return width // patch_size, height // patch_size 119 | 120 | 121 | def process_anyres_image(image, processor, grid_pinpoints): 122 | """ 123 | Process an image with variable resolutions. 124 | 125 | Args: 126 | image (PIL.Image.Image): The input image to be processed. 127 | processor: The image processor object. 128 | grid_pinpoints (str): A string representation of a list of possible resolutions. 129 | 130 | Returns: 131 | torch.Tensor: A tensor containing the processed image patches. 132 | """ 133 | if type(grid_pinpoints) is list: 134 | possible_resolutions = grid_pinpoints 135 | else: 136 | possible_resolutions = ast.literal_eval(grid_pinpoints) 137 | best_resolution = select_best_resolution(image.size, possible_resolutions) 138 | image_padded = resize_and_pad_image(image, best_resolution) 139 | 140 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 141 | 142 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 143 | 144 | image_patches = [image_original_resize] + patches 145 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 146 | for image_patch in image_patches] 147 | return torch.stack(image_patches, dim=0) 148 | 149 | 150 | def load_image_from_base64(image): 151 | return Image.open(BytesIO(base64.b64decode(image))) 152 | 153 | 154 | def expand2square(pil_img, background_color): 155 | width, height = pil_img.size 156 | if width == height: 157 | return pil_img 158 | elif width > height: 159 | result = Image.new(pil_img.mode, (width, width), background_color) 160 | result.paste(pil_img, (0, (width - height) // 2)) 161 | return result 162 | else: 163 | result = Image.new(pil_img.mode, (height, height), background_color) 164 | result.paste(pil_img, ((height - width) // 2, 0)) 165 | return result 166 | 167 | 168 | def process_images(images, image_processor, model_cfg): 169 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 170 | new_images = [] 171 | if image_aspect_ratio == 'pad': 172 | for image in images: 173 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 174 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 175 | new_images.append(image) 176 | elif image_aspect_ratio == "anyres": 177 | for image in images: 178 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 179 | new_images.append(image) 180 | else: 181 | return image_processor(images, return_tensors='pt')['pixel_values'] 182 | if all(x.shape == new_images[0].shape for x in new_images): 183 | new_images = torch.stack(new_images, dim=0) 184 | return new_images 185 | 186 | 187 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 188 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 189 | 190 | def insert_separator(X, sep): 191 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 192 | 193 | input_ids = [] 194 | offset = 0 195 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 196 | offset = 1 197 | input_ids.append(prompt_chunks[0][0]) 198 | 199 | for x in 
insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 200 | input_ids.extend(x[offset:]) 201 | 202 | if return_tensors is not None: 203 | if return_tensors == 'pt': 204 | return torch.tensor(input_ids, dtype=torch.long) 205 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 206 | return input_ids 207 | 208 | 209 | def get_model_name_from_path(model_path): 210 | model_path = model_path.strip("/") 211 | model_paths = model_path.split("/") 212 | if model_paths[-1].startswith('checkpoint-'): 213 | return model_paths[-2] + "_" + model_paths[-1] 214 | else: 215 | return model_paths[-1] 216 | 217 | 218 | class KeywordsStoppingCriteria(StoppingCriteria): 219 | def __init__(self, keywords, tokenizer, input_ids): 220 | self.keywords = keywords 221 | self.keyword_ids = [] 222 | self.max_keyword_len = 0 223 | for keyword in keywords: 224 | cur_keyword_ids = tokenizer(keyword).input_ids 225 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 226 | cur_keyword_ids = cur_keyword_ids[1:] 227 | if len(cur_keyword_ids) > self.max_keyword_len: 228 | self.max_keyword_len = len(cur_keyword_ids) 229 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 230 | self.tokenizer = tokenizer 231 | self.start_len = input_ids.shape[1] 232 | 233 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 234 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 235 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 236 | for keyword_id in self.keyword_ids: 237 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 238 | if torch.equal(truncated_output_ids, keyword_id): 239 | return True 240 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 241 | for keyword in self.keywords: 242 | if keyword in outputs: 243 | return True 244 | return False 245 | 246 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 247 | outputs = [] 248 | for i in range(output_ids.shape[0]): 249 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 250 | return all(outputs) 251 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | # try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig 6 | # except: 7 | # pass 8 | 9 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, 
low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | if 'llava' in model_name.lower(): 49 | # Load LLaVA model 50 | if 'lora' in model_name.lower() and model_base is None: 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. 
If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 52 | if 'lora' in model_name.lower() and model_base is not None: 53 | from llava.model.language_model.llava_llama import LlavaConfig 54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 56 | print('Loading LLaVA from base model...') 57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 58 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 59 | if model.lm_head.weight.shape[0] != token_num: 60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 62 | 63 | print('Loading additional LLaVA weights...') 64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 66 | else: 67 | # this is probably from HF Hub 68 | from huggingface_hub import hf_hub_download 69 | 70 | def load_from_hf(repo_id, filename, subfolder=None): 71 | cache_file = hf_hub_download( 72 | repo_id=repo_id, 73 | filename=filename, 74 | subfolder=subfolder) 75 | return torch.load(cache_file, map_location='cpu') 76 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 77 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 78 | if any(k.startswith('model.model.') for k in non_lora_trainables): 79 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 80 | model.load_state_dict(non_lora_trainables, strict=False) 81 | 82 | from peft import PeftModel 83 | print('Loading LoRA weights...') 84 | model = PeftModel.from_pretrained(model, model_path) 85 | print('Merging LoRA weights...') 86 | model = model.merge_and_unload() 87 | print('Model is loaded...') 88 | elif model_base is not None: 89 | # this may be mm projector only 90 | print('Loading LLaVA from base model...') 91 | if 'mpt' in model_name.lower(): 92 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 93 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 94 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 95 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 96 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 97 | else: 98 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 99 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 100 | # model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 101 | model = LlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 102 | 103 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 104 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 105 | model.load_state_dict(mm_projector_weights, 
strict=False) 106 | else: 107 | if 'mpt' in model_name.lower(): 108 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 109 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 110 | elif 'mistral' in model_name.lower(): 111 | tokenizer = AutoTokenizer.from_pretrained(model_path) 112 | model = LlavaMistralForCausalLM.from_pretrained( 113 | model_path, 114 | low_cpu_mem_usage=True, 115 | **kwargs 116 | ) 117 | elif 'dclm' in model_name.lower(): 118 | tokenizer = AutoTokenizer.from_pretrained(model_path) 119 | model = LlavaOpenlmForCausalLM.from_pretrained( 120 | model_path, 121 | low_cpu_mem_usage=True, 122 | **kwargs 123 | ) 124 | else: 125 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 126 | # model = LlavaLlamaForCausalLM.from_pretrained( 127 | # model_path, 128 | # low_cpu_mem_usage=True, 129 | # **kwargs 130 | # ) 131 | model = LlavaQwen2ForCausalLM.from_pretrained( 132 | model_path, 133 | low_cpu_mem_usage=True, 134 | **kwargs 135 | ) 136 | else: 137 | # Load language model 138 | if model_base is not None: 139 | # PEFT model 140 | from peft import PeftModel 141 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 142 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 143 | print(f"Loading LoRA weights from {model_path}") 144 | model = PeftModel.from_pretrained(model, model_path) 145 | print(f"Merging weights") 146 | model = model.merge_and_unload() 147 | print('Convert to FP16...') 148 | model.to(torch.float16) 149 | else: 150 | use_fast = False 151 | if 'mpt' in model_name.lower(): 152 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 153 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 154 | else: 155 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 156 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 157 | 158 | image_processor = None 159 | 160 | if 'llava' in model_name.lower(): 161 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 162 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 163 | if mm_use_im_patch_token: 164 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 165 | if mm_use_im_start_end: 166 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 167 | model.resize_token_embeddings(len(tokenizer)) 168 | 169 | vision_tower = model.get_vision_tower() 170 | if not vision_tower.is_loaded: 171 | vision_tower.load_model(device_map=device_map) 172 | if device_map != 'auto': 173 | vision_tower.to(device=device_map, dtype=torch.float16) 174 | image_processor = vision_tower.image_processor 175 | 176 | if hasattr(model.config, "max_sequence_length"): 177 | context_len = model.config.max_sequence_length 178 | else: 179 | context_len = 2048 180 | 181 | return tokenizer, model, image_processor, context_len 182 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | 
from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | cache_position=None, 72 | ) -> Union[Tuple, CausalLMOutputWithPast]: 73 | 74 | if 
inputs_embeds is None: 75 | ( 76 | input_ids, 77 | position_ids, 78 | attention_mask, 79 | past_key_values, 80 | inputs_embeds, 81 | labels 82 | ) = self.prepare_inputs_labels_for_multimodal( 83 | input_ids, 84 | position_ids, 85 | attention_mask, 86 | past_key_values, 87 | labels, 88 | images, 89 | image_sizes 90 | ) 91 | 92 | return super().forward( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | labels=labels, 99 | use_cache=use_cache, 100 | output_attentions=output_attentions, 101 | output_hidden_states=output_hidden_states, 102 | return_dict=return_dict 103 | ) 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | inputs: Optional[torch.Tensor] = None, 109 | images: Optional[torch.Tensor] = None, 110 | image_sizes: Optional[torch.Tensor] = None, 111 | **kwargs, 112 | ) -> Union[GenerateOutput, torch.LongTensor]: 113 | position_ids = kwargs.pop("position_ids", None) 114 | attention_mask = kwargs.pop("attention_mask", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if images is not None: 119 | ( 120 | inputs, 121 | position_ids, 122 | attention_mask, 123 | _, 124 | inputs_embeds, 125 | _ 126 | ) = self.prepare_inputs_labels_for_multimodal( 127 | inputs, 128 | position_ids, 129 | attention_mask, 130 | None, 131 | None, 132 | images, 133 | image_sizes=image_sizes 134 | ) 135 | else: 136 | inputs_embeds = self.get_model().embed_tokens(inputs) 137 | 138 | return super().generate( 139 | position_ids=position_ids, 140 | attention_mask=attention_mask, 141 | inputs_embeds=inputs_embeds, 142 | **kwargs 143 | ) 144 | 145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 146 | inputs_embeds=None, **kwargs): 147 | images = kwargs.pop("images", None) 148 | image_sizes = kwargs.pop("image_sizes", None) 149 | inputs = super().prepare_inputs_for_generation( 150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 151 | ) 152 | if images is not None: 153 | inputs['images'] = images 154 | if image_sizes is not None: 155 | inputs['image_sizes'] = image_sizes 156 | return inputs 157 | 158 | AutoConfig.register("llava_llama", LlavaConfig) 159 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 160 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
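The AutoConfig.register / AutoModelForCausalLM.register calls that close llava_llama.py above (and each of the other language-model wrappers in this directory) make the custom "llava_llama" / "llava_qwen2" model types resolvable through the standard transformers auto classes. A minimal, hedged sketch of that flow; the checkpoint path below is a placeholder, not a file shipped with this repository:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

import llava.model  # noqa: F401  # importing the package runs the register(...) calls above

ckpt = "path/to/llava-checkpoint"  # placeholder path
cfg = AutoConfig.from_pretrained(ckpt)
print(cfg.model_type)  # e.g. "llava_qwen2" for the Qwen2-based checkpoints

model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

In practice, load_pretrained_model in llava/model/builder.py wraps this lookup and additionally handles tokenizer setup, vision-tower loading, and the LoRA / mm-projector merge paths shown earlier.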
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, 
inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_qwen.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
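Each of these wrappers overrides generate() so that, when images are supplied, prepare_inputs_labels_for_multimodal (defined in llava_arch.py) splices the projected vision features in place of the image placeholder tokens, and the base generate() then runs purely on inputs_embeds. A rough, hedged sketch of how the pieces from mm_utils.py and builder.py fit together at inference time; the checkpoint path and image file are placeholders, and the chat-template formatting from llava/conversation.py is omitted for brevity:

import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

model_path = "path/to/llava-fastvithd_0.5b_stage3"  # placeholder checkpoint directory
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=get_model_name_from_path(model_path)
)

image = Image.open("example.jpg").convert("RGB")  # placeholder image
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# The <image> placeholder is replaced by IMAGE_TOKEN_INDEX (-200) in the input ids.
prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe this image."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=False,
        max_new_tokens=128,
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])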
15 | 16 | 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(Qwen2Config): 31 | model_type = "llava_qwen2" 32 | 33 | 34 | class LlavaQwen2Model(LlavaMetaModel, Qwen2Model): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: Qwen2Config): 38 | super(LlavaQwen2Model, self).__init__(config) 39 | 40 | 41 | class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(Qwen2ForCausalLM, self).__init__(config) 46 | self.model = LlavaQwen2Model(config) 47 | # self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | cache_position=None, 72 | ) -> Union[Tuple, CausalLMOutputWithPast]: 73 | 74 | if inputs_embeds is None: 75 | ( 76 | input_ids, 77 | position_ids, 78 | attention_mask, 79 | past_key_values, 80 | inputs_embeds, 81 | labels 82 | ) = self.prepare_inputs_labels_for_multimodal( 83 | input_ids, 84 | position_ids, 85 | attention_mask, 86 | past_key_values, 87 | labels, 88 | images, 89 | image_sizes 90 | ) 91 | 92 | return super().forward( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | labels=labels, 99 | use_cache=use_cache, 100 | output_attentions=output_attentions, 101 | output_hidden_states=output_hidden_states, 102 | return_dict=return_dict 103 | ) 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | inputs: Optional[torch.Tensor] = None, 109 | images: Optional[torch.Tensor] = None, 110 | image_sizes: Optional[torch.Tensor] = None, 111 | **kwargs, 112 | ) -> Union[GenerateOutput, torch.LongTensor]: 113 | position_ids = kwargs.pop("position_ids", None) 114 | attention_mask = kwargs.pop("attention_mask", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if images is not None: 119 | ( 120 | inputs, 121 | position_ids, 122 | attention_mask, 123 | _, 124 | inputs_embeds, 125 | _ 126 | ) = self.prepare_inputs_labels_for_multimodal( 127 | inputs, 128 | position_ids, 129 | attention_mask, 130 | None, 131 | None, 132 | images, 133 | image_sizes=image_sizes 134 | ) 135 | else: 136 | inputs_embeds = 
self.get_model().embed_tokens(inputs) 137 | 138 | return super().generate( 139 | position_ids=position_ids, 140 | attention_mask=attention_mask, 141 | inputs_embeds=inputs_embeds, 142 | **kwargs 143 | ) 144 | 145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 146 | inputs_embeds=None, **kwargs): 147 | images = kwargs.pop("images", None) 148 | image_sizes = kwargs.pop("image_sizes", None) 149 | inputs = super().prepare_inputs_for_generation( 150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 151 | ) 152 | if images is not None: 153 | inputs['images'] = images 154 | if image_sizes is not None: 155 | inputs['image_sizes'] = image_sizes 156 | return inputs 157 | 158 | 159 | AutoConfig.register("llava_qwen2", LlavaConfig) 160 | AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM) 161 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | from .mobileclip_encoder import MobileCLIPVisionTower 4 | 5 | 6 | def build_vision_tower(vision_tower_cfg, **kwargs): 7 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 8 | is_absolute_path_exists = os.path.exists(vision_tower) 9 | use_s2 = getattr(vision_tower_cfg, 's2', False) 10 | 11 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 12 | if use_s2: 13 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 14 | else: 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | elif "mobileclip" in vision_tower.lower(): 17 | return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 18 | 19 | raise ValueError(f'Unknown vision tower: {vision_tower}') 20 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) 17 | self.input_image_size = getattr(args, 'input_image_size', None) 18 | 19 | if self.tune_vision_tower: 20 | print("CLIP Vision tower is set to tunable") 21 | 22 | if not delay_load: 23 | self.load_model() 24 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 25 | self.load_model() 26 | else: 27 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 28 | if self.input_image_size is not None: 29 | self.cfg_only.image_size = self.input_image_size 30 | 31 | def load_model(self, device_map=None): 32 | if self.is_loaded: 33 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 34 | return 35 | 36 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 37 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 38 | if not self.tune_vision_tower: 39 | self.vision_tower.requires_grad_(False) 40 | 41 | if self.input_image_size is not None: 42 | print("Using input image size: {}".format(self.input_image_size)) 43 | self.image_processor.size['shortest_edge'] = self.input_image_size 44 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.input_image_size 45 | 46 | self.is_loaded = True 47 | 48 | def feature_select(self, image_forward_outs): 49 | image_features = image_forward_outs.hidden_states[self.select_layer] 50 | if self.select_feature == 'patch': 51 | image_features = image_features[:, 1:] 52 | elif self.select_feature == 'cls_patch': 53 | image_features = image_features 54 | else: 55 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 56 | return image_features 57 | 58 | def forward(self, images): 59 | if self.tune_vision_tower: 60 | 
return self.forward_images(images) 61 | else: 62 | with torch.no_grad(): 63 | return self.forward_images(images) 64 | 65 | def forward_images(self, images): 66 | if type(images) is list: 67 | image_features = [] 68 | for image in images: 69 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 70 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 71 | image_features.append(image_feature) 72 | else: 73 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 74 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 75 | 76 | return image_features 77 | 78 | @property 79 | def dummy_feature(self): 80 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 81 | 82 | @property 83 | def dtype(self): 84 | return self.vision_tower.dtype 85 | 86 | @property 87 | def device(self): 88 | return self.vision_tower.device 89 | 90 | @property 91 | def config(self): 92 | if self.is_loaded: 93 | return self.vision_tower.config 94 | else: 95 | return self.cfg_only 96 | 97 | @property 98 | def hidden_size(self): 99 | return self.config.hidden_size 100 | 101 | @property 102 | def num_patches_per_side(self): 103 | return self.config.image_size // self.config.patch_size 104 | 105 | @property 106 | def num_patches(self): 107 | return (self.config.image_size // self.config.patch_size) ** 2 108 | 109 | 110 | 111 | class CLIPVisionTowerS2(CLIPVisionTower): 112 | def __init__(self, vision_tower, args, delay_load=False): 113 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 114 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 115 | self.s2_scales.sort() 116 | self.s2_split_size = self.s2_scales[0] 117 | self.s2_image_size = self.s2_scales[-1] 118 | 119 | super().__init__(vision_tower, args, delay_load) 120 | 121 | try: 122 | from s2wrapper import forward as multiscale_forward 123 | except ImportError: 124 | raise ImportError('Package s2wrapper not found! 
Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') 125 | self.multiscale_forward = multiscale_forward 126 | 127 | # change resize/crop size in preprocessing to the largest image size in s2_scale 128 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 129 | self.image_processor.size['shortest_edge'] = self.s2_image_size 130 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 131 | 132 | def load_model(self, device_map=None): 133 | if self.is_loaded: 134 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 135 | return 136 | 137 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 138 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 139 | self.vision_tower.requires_grad_(False) 140 | 141 | self.image_processor.size['shortest_edge'] = self.s2_image_size 142 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 143 | 144 | self.is_loaded = True 145 | 146 | @torch.no_grad() 147 | def forward_feature(self, images): 148 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 149 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 150 | return image_features 151 | 152 | @torch.no_grad() 153 | def forward(self, images): 154 | if type(images) is list: 155 | image_features = [] 156 | for image in images: 157 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 158 | image_features.append(image_feature) 159 | else: 160 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 161 | 162 | return image_features 163 | 164 | @property 165 | def hidden_size(self): 166 | return self.config.hidden_size * len(self.s2_scales) 167 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/mobileclip/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | import json 7 | from typing import Any 8 | 9 | import torch.nn as nn 10 | from timm.models import create_model 11 | 12 | from .mci import GlobalPool2D 13 | 14 | 15 | def load_model_config( 16 | model_name: str, 17 | ) -> Any: 18 | # Strip suffixes to model name 19 | model_name = "_".join(model_name.split("_")[0:2]) 20 | 21 | # Config files 22 | root_dir = os.path.dirname(os.path.abspath(__file__)) 23 | configs_dir = os.path.join(root_dir, "configs") 24 | model_cfg_file = os.path.join(configs_dir, model_name + ".json") 25 | 26 | # Get config from yaml file 27 | if not os.path.exists(model_cfg_file): 28 | raise ValueError(f"Unsupported model name: {model_name}") 29 | model_cfg = json.load(open(model_cfg_file, "r")) 30 | 31 | return model_cfg 32 | 33 | 34 | class MCi(nn.Module): 35 | """ 36 | This class implements `MCi Models `_ 37 | """ 38 | 39 | def __init__(self, model_name: str, *args, **kwargs) -> None: 40 | super().__init__() 41 | self.projection_dim = None 42 | if "projection_dim" in kwargs: 43 | self.projection_dim = kwargs.get("projection_dim") 44 | 45 | # Create model 46 | self.model = create_model(model_name, projection_dim=self.projection_dim) 47 | 48 | # Build out projection head. 49 | if self.projection_dim is not None: 50 | if hasattr(self.model, "head"): 51 | self.model.head = MCi._update_image_classifier( 52 | image_classifier=self.model.head, projection_dim=self.projection_dim 53 | ) 54 | 55 | def forward(self, x: Any, *args, **kwargs) -> Any: 56 | """A forward function of the model.""" 57 | x = self.model(x, *args, **kwargs) 58 | return x 59 | 60 | @staticmethod 61 | def _get_in_feature_dimension(image_classifier: nn.Module) -> int: 62 | """Return the input feature dimension to the image classification head.""" 63 | in_features = None 64 | if isinstance(image_classifier, nn.Sequential): 65 | # Classifier that uses nn.Sequential usually has global pooling and 66 | # multiple linear layers. Find the first linear layer and get its 67 | # in_features 68 | for layer in image_classifier: 69 | if isinstance(layer, nn.Linear): 70 | in_features = layer.in_features 71 | break 72 | elif isinstance(image_classifier, nn.Linear): 73 | in_features = image_classifier.in_features 74 | 75 | if in_features is None: 76 | raise NotImplementedError( 77 | f"Cannot get input feature dimension of {image_classifier}." 
78 | ) 79 | return in_features 80 | 81 | @staticmethod 82 | def _update_image_classifier( 83 | image_classifier: nn.Module, projection_dim: int, *args, **kwargs 84 | ) -> nn.Module: 85 | in_features = MCi._get_in_feature_dimension(image_classifier) 86 | new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim) 87 | return new_img_classifier 88 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "image_cfg": { 4 | "image_size": 1024, 5 | "model_name": "fastvithd", 6 | "embed_dim": 3072, 7 | "patch_size": 64 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "dim": 768, 13 | "ffn_multiplier_per_layer": 4.0, 14 | "n_heads_per_layer": 12, 15 | "n_transformer_layers": 12, 16 | "norm_layer": "layer_norm_fp32", 17 | "causal_masking": false, 18 | "model_name": "base" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/mobileclip_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from transformers import CLIPImageProcessor 10 | import llava.model.multimodal_encoder.mobileclip as mobileclip 11 | 12 | 13 | class MobileCLIPVisionTower(nn.Module): 14 | def __init__(self, vision_tower, args, delay_load=False): 15 | super().__init__() 16 | 17 | self.is_loaded = False 18 | self.vision_tower_name = vision_tower 19 | self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) 20 | self.input_image_size = int(vision_tower.split("_")[-1]) 21 | 22 | # Delay load is disabled for now 23 | if not delay_load: 24 | self.load_model() 25 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 26 | self.load_model() 27 | else: 28 | model_cfg = mobileclip.load_model_config(self.vision_tower_name) 29 | self.cfg_only = model_cfg 30 | 31 | def load_model(self, device_map=None): 32 | if self.is_loaded: 33 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 34 | return 35 | 36 | # Load model config 37 | model_cfg = mobileclip.load_model_config(self.vision_tower_name) 38 | 39 | # Override default image resolution 40 | model_cfg["image_cfg"]["image_size"] = self.input_image_size 41 | 42 | self.cfg_only = model_cfg 43 | 44 | # Build HF CLIPImageProcessor with MobileCLIP parameters 45 | self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"], 46 | "width": model_cfg["image_cfg"]["image_size"]}, 47 | image_mean=[0.0, 0.0, 0.0], 48 | image_std=[1.0, 1.0, 1.0], 49 | size={"shortest_edge": model_cfg["image_cfg"]["image_size"]}) 50 | 51 | # Instantiate the image encoder 52 | self.vision_tower = mobileclip.MCi(model_name=model_cfg["image_cfg"]["model_name"], 53 | projection_dim=model_cfg["embed_dim"]) 54 | 55 | if not self.tune_vision_tower: 56 | self.vision_tower.requires_grad_(False) 57 | 58 | self.is_loaded = True 59 | 60 | def feature_select(self, image_forward_outs): 61 | # Features from penultimate layer 62 | image_features = image_forward_outs["image_embeddings"] 63 | 64 | # Reshape 4D tensor to 3D 65 | B, C, H, W = image_features.shape 
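        # Illustrative note (added comment, not in the original source): with the
        # bundled mobileclip_l.json config shown above (image_size 1024,
        # patch_size 64, image embed_dim 3072), H = W = 1024 // 64 = 16, so the
        # reshape and transpose below turn a (B, 3072, 16, 16) feature map into a
        # (B, 256, 3072) token sequence -- matching the num_patches property
        # further down. If the resolution is overridden via input_image_size in
        # load_model, the token count scales as (image_size // patch_size) ** 2.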
66 | image_features = image_features.reshape(B, C, H*W) 67 | image_features = image_features.transpose(1, 2) 68 | return image_features 69 | 70 | def forward(self, images): 71 | if self.tune_vision_tower: 72 | return self.forward_images(images) 73 | else: 74 | with torch.no_grad(): 75 | return self.forward_images(images) 76 | 77 | def forward_images(self, images): 78 | if type(images) is list: 79 | image_features = [] 80 | for image in images: 81 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True) 82 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 83 | image_features.append(image_feature) 84 | else: 85 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True) 86 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 87 | 88 | return image_features 89 | 90 | @property 91 | def dummy_feature(self): 92 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 93 | 94 | @property 95 | def dtype(self): 96 | return next(self.vision_tower.parameters()).dtype 97 | 98 | @property 99 | def device(self): 100 | return next(self.vision_tower.parameters()).device 101 | 102 | @property 103 | def config(self): 104 | return self.cfg_only 105 | 106 | @property 107 | def hidden_size(self): 108 | return self.config["image_cfg"]["embed_dim"] 109 | 110 | @property 111 | def num_patches_per_side(self): 112 | return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"] 113 | 114 | @property 115 | def num_patches(self): 116 | return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2 117 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import re 3 | 4 | 5 | class IdentityMap(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, *args, **kwargs): 10 | return x 11 | 12 | @property 13 | def config(self): 14 | return {"mm_projector_type": 'identity'} 15 | 16 | 17 | def build_vision_projector(config, delay_load=False, **kwargs): 18 | projector_type = getattr(config, 'mm_projector_type', 'linear') 19 | 20 | if projector_type == 'linear': 21 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 22 | 23 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 24 | if mlp_gelu_match: 25 | mlp_depth = int(mlp_gelu_match.group(1)) 26 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 27 | for _ in range(1, mlp_depth): 28 | modules.append(nn.GELU()) 29 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 30 | return nn.Sequential(*modules) 31 | 32 | if projector_type == 'identity': 33 | return IdentityMap() 34 | 35 | raise ValueError(f'Unknown projector type: {projector_type}') 36 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | 
print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "mistral" in model_name.lower(): 37 | conv_mode = "mistral_instruct" 38 | elif "v1.6-34b" in model_name.lower(): 39 | conv_mode = "chatml_direct" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | image = load_image(args.image_file) 59 | image_size = image.size 60 | # Similar operation in model_worker.py 61 | image_tensor = process_images([image], image_processor, model.config) 62 | if type(image_tensor) is list: 63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 64 | else: 65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 66 | 67 | while True: 68 | try: 69 | inp = input(f"{roles[0]}: ") 70 | except EOFError: 71 
| inp = "" 72 | if not inp: 73 | print("exit...") 74 | break 75 | 76 | print(f"{roles[1]}: ", end="") 77 | 78 | if image is not None: 79 | # first message 80 | if model.config.mm_use_im_start_end: 81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 82 | else: 83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 84 | image = None 85 | 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompt = conv.get_prompt() 89 | 90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 94 | 95 | with torch.inference_mode(): 96 | output_ids = model.generate( 97 | input_ids, 98 | images=image_tensor, 99 | image_sizes=[image_size], 100 | do_sample=True if args.temperature > 0 else False, 101 | temperature=args.temperature, 102 | max_new_tokens=args.max_new_tokens, 103 | streamer=streamer, 104 | use_cache=True) 105 | 106 | outputs = tokenizer.decode(output_ids[0]).strip() 107 | conv.messages[-1][-1] = outputs 108 | 109 | if args.debug: 110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 116 | parser.add_argument("--model-base", type=str, default=None) 117 | parser.add_argument("--image-file", type=str, required=True) 118 | parser.add_argument("--device", type=str, default="cuda") 119 | parser.add_argument("--conv-mode", type=str, default=None) 120 | parser.add_argument("--temperature", type=float, default=0.2) 121 | parser.add_argument("--max-new-tokens", type=int, default=512) 122 | parser.add_argument("--load-8bit", action="store_true") 123 | parser.add_argument("--load-4bit", action="store_true") 124 | parser.add_argument("--debug", action="store_true") 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /llava/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from llava.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,), daemon=True) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in 
self.worker_info.items(): 125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. {worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | # Let the controller act as a worker to achieve hierarchical 218 | # management. This can be used to connect isolated sub networks. 
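    # Illustrative sketch (added comments, not in the original source): because
    # worker_api_get_status() below returns the same {model_names, speed,
    # queue_length} payload that an individual worker reports, a parent
    # controller could treat this controller as a worker by POSTing this
    # controller's address to the parent's /register_worker endpoint, mirroring
    # what register_worker.py does, e.g.
    #   requests.post(parent_addr + "/register_worker",
    #                 json={"worker_name": this_controller_addr,
    #                       "check_heart_beat": False,
    #                       "worker_status": None})
    # where parent_addr and this_controller_addr are placeholder variables, not
    # names defined in this module.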
219 | 220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 299 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import time 8 | import threading 9 | import uuid 10 | 11 | from fastapi import FastAPI, Request, BackgroundTasks 12 | from fastapi.responses import StreamingResponse 13 | import requests 14 | import torch 15 | import uvicorn 16 | from functools import partial 17 | 18 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 19 | from llava.utils import (build_logger, server_error_msg, 20 | pretty_print_semaphore) 21 | from llava.model.builder import load_pretrained_model 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token 23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | from transformers import TextIteratorStreamer 25 | from threading import Thread 26 | 27 | 28 | GB = 1 << 30 29 | 30 | worker_id = str(uuid.uuid4())[:6] 31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 32 | global_counter = 0 33 | 34 | model_semaphore = None 35 | 36 | 37 | def heart_beat_worker(controller): 38 | 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | class ModelWorker: 45 | def __init__(self, controller_addr, worker_addr, 46 | worker_id, no_register, 47 | model_path, model_base, model_name, 48 | load_8bit, load_4bit, device, use_flash_attn=False): 49 | self.controller_addr = controller_addr 50 | self.worker_addr = worker_addr 51 | self.worker_id = worker_id 52 | if model_path.endswith("/"): 53 | model_path = model_path[:-1] 54 | if model_name is None: 55 | model_paths = model_path.split("/") 56 | if model_paths[-1].startswith('checkpoint-'): 57 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 58 | else: 59 | self.model_name = model_paths[-1] 60 | else: 61 | self.model_name = model_name 62 | 63 | self.device = device 64 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 65 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( 66 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn) 67 | self.is_multimodal = 'llava' in self.model_name.lower() 68 | 69 | if not no_register: 70 | self.register_to_controller() 71 | self.heart_beat_thread = threading.Thread( 72 | target=heart_beat_worker, args=(self,), daemon=True) 73 | self.heart_beat_thread.start() 74 | 75 | def register_to_controller(self): 76 | logger.info("Register to controller") 77 | 78 | url = self.controller_addr + "/register_worker" 79 | data = { 80 | "worker_name": self.worker_addr, 81 | "check_heart_beat": True, 82 | "worker_status": self.get_status() 83 | } 84 | r = requests.post(url, json=data) 85 | assert r.status_code == 200 86 | 87 | def send_heart_beat(self): 88 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 89 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" 90 | f"global_counter: {global_counter}") 91 | 92 | url = self.controller_addr + "/receive_heart_beat" 93 | 94 | while True: 95 | try: 96 | ret = requests.post(url, json={ 97 | "worker_name": self.worker_addr, 98 | "queue_length": self.get_queue_length()}, timeout=5) 99 | exist = ret.json()["exist"] 100 | break 101 | except requests.exceptions.RequestException as e: 102 | logger.error(f"heart beat error: {e}") 103 | time.sleep(5) 104 | 105 | if not exist: 106 | self.register_to_controller() 107 | 108 | def get_queue_length(self): 109 | if model_semaphore is None: 110 | return 0 111 | else: 112 | return args.limit_model_concurrency - model_semaphore._value + (len( 113 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 114 | 115 | def get_status(self): 116 | return { 117 | "model_names": [self.model_name], 118 | "speed": 1, 119 | "queue_length": self.get_queue_length(), 120 | } 121 | 122 | @torch.inference_mode() 123 | def generate_stream(self, params): 124 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 125 | 126 | prompt = params["prompt"] 127 | ori_prompt = prompt 128 | images = params.get("images", None) 129 | num_image_tokens = 0 130 | if images is not None and len(images) > 0 and self.is_multimodal: 131 | if len(images) > 0: 132 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 133 | raise ValueError("Number of images does not match number of tokens in prompt") 134 | 135 | images = [load_image_from_base64(image) for image in images] 136 | image_sizes = [image.size for image in images] 137 | images = process_images(images, image_processor, model.config) 138 | 139 | if type(images) is list: 140 | images = [image.to(self.model.device, dtype=torch.float16) for image in images] 141 | else: 142 | images = images.to(self.model.device, dtype=torch.float16) 143 | 144 | replace_token = DEFAULT_IMAGE_TOKEN 145 | if getattr(self.model.config, 'mm_use_im_start_end', False): 146 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 147 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 148 | 149 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches 150 | else: 151 | images = None 152 | image_sizes = None 153 | image_args = {"images": images, "image_sizes": image_sizes} 154 | else: 155 | images = None 156 | image_args = {} 157 | 158 | temperature = float(params.get("temperature", 1.0)) 159 | top_p = float(params.get("top_p", 1.0)) 160 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 161 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 162 | stop_str = params.get("stop", None) 163 | do_sample = True if temperature > 0.001 else False 164 | 165 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) 166 | keywords = [stop_str] 167 | # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 168 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 169 | 170 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 171 | 172 | if max_new_tokens < 1: 173 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 174 | return 175 | 176 | thread = Thread(target=model.generate, kwargs=dict( 177 | inputs=input_ids, 178 | do_sample=do_sample, 179 | temperature=temperature, 180 | top_p=top_p, 181 | max_new_tokens=max_new_tokens, 182 | streamer=streamer, 183 | use_cache=True, 184 | **image_args 185 | )) 186 | thread.start() 187 | 188 | generated_text = ori_prompt 189 | for new_text in streamer: 190 | generated_text += new_text 191 | if generated_text.endswith(stop_str): 192 | generated_text = generated_text[:-len(stop_str)] 193 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 194 | 195 | def generate_stream_gate(self, params): 196 | try: 197 | for x in self.generate_stream(params): 198 | yield x 199 | except ValueError as e: 200 | print("Caught ValueError:", e) 201 | ret = { 202 | "text": server_error_msg, 203 | "error_code": 1, 204 | } 205 | yield json.dumps(ret).encode() + b"\0" 206 | except torch.cuda.CudaError as e: 207 | print("Caught torch.cuda.CudaError:", e) 208 | ret = { 209 | "text": server_error_msg, 210 | "error_code": 1, 211 | } 212 | yield json.dumps(ret).encode() + b"\0" 213 | except Exception as e: 214 | print("Caught Unknown Error", e) 215 | ret = { 216 | "text": server_error_msg, 217 | "error_code": 1, 218 | } 219 | yield json.dumps(ret).encode() + b"\0" 220 | 221 | 222 | app = FastAPI() 223 | 224 | 225 | def release_model_semaphore(fn=None): 226 | model_semaphore.release() 227 | if fn is not None: 228 | fn() 229 | 230 | 231 | @app.post("/worker_generate_stream") 232 | async def generate_stream(request: Request): 233 | global model_semaphore, global_counter 234 | global_counter += 1 235 | params = await request.json() 236 | 237 | if model_semaphore is None: 238 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 239 | await model_semaphore.acquire() 240 | worker.send_heart_beat() 241 | generator = worker.generate_stream_gate(params) 242 | background_tasks = BackgroundTasks() 243 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 244 | return StreamingResponse(generator, background=background_tasks) 245 | 246 | 247 | @app.post("/worker_get_status") 248 | async def get_status(request: Request): 249 | return worker.get_status() 250 | 251 | 252 | if __name__ == "__main__": 253 | parser = argparse.ArgumentParser() 254 | parser.add_argument("--host", type=str, default="localhost") 255 | parser.add_argument("--port", type=int, default=21002) 256 | parser.add_argument("--worker-address", type=str, 257 | default="http://localhost:21002") 258 | parser.add_argument("--controller-address", type=str, 259 | default="http://localhost:21001") 260 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 261 | parser.add_argument("--model-base", type=str, default=None) 262 | parser.add_argument("--model-name", type=str) 263 | parser.add_argument("--device", type=str, default="cuda") 264 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 265 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 266 | parser.add_argument("--stream-interval", type=int, default=1) 267 | parser.add_argument("--no-register", action="store_true") 268 | parser.add_argument("--load-8bit", action="store_true") 269 | parser.add_argument("--load-4bit", action="store_true") 270 | 
parser.add_argument("--use-flash-attn", action="store_true") 271 | args = parser.parse_args() 272 | logger.info(f"args: {args}") 273 | 274 | if args.multi_modal: 275 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 276 | 277 | worker = ModelWorker(args.controller_address, 278 | args.worker_address, 279 | worker_id, 280 | args.no_register, 281 | args.model_path, 282 | args.model_base, 283 | args.model_name, 284 | args.load_8bit, 285 | args.load_4bit, 286 | args.device, 287 | use_flash_attn=args.use_flash_attn) 288 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 289 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/sglang_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | import argparse 5 | import asyncio 6 | from concurrent.futures import ThreadPoolExecutor 7 | import json 8 | import time 9 | import threading 10 | import uuid 11 | 12 | from fastapi import FastAPI, Request, BackgroundTasks 13 | from fastapi.responses import StreamingResponse 14 | import requests 15 | import re 16 | import uvicorn 17 | from functools import partial 18 | 19 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 20 | from llava.utils import (build_logger, server_error_msg, 21 | pretty_print_semaphore) 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square 23 | from llava.constants import DEFAULT_IMAGE_TOKEN 24 | 25 | import sglang as sgl 26 | from sglang.backend.runtime_endpoint import RuntimeEndpoint 27 | 28 | 29 | GB = 1 << 30 30 | 31 | worker_id = str(uuid.uuid4())[:6] 32 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 33 | global_counter = 0 34 | 35 | model_semaphore = None 36 | 37 | 38 | def heart_beat_worker(controller): 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | @sgl.function 45 | def pipeline(s, prompt, max_tokens): 46 | for p in prompt: 47 | if type(p) is str: 48 | s += p 49 | else: 50 | s += sgl.image(p) 51 | s += sgl.gen("response", max_tokens=max_tokens) 52 | 53 | 54 | class ModelWorker: 55 | def __init__(self, controller_addr, worker_addr, sgl_endpoint, 56 | worker_id, no_register, model_name): 57 | self.controller_addr = controller_addr 58 | self.worker_addr = worker_addr 59 | self.worker_id = worker_id 60 | 61 | # Select backend 62 | backend = RuntimeEndpoint(sgl_endpoint) 63 | sgl.set_default_backend(backend) 64 | model_path = backend.model_info["model_path"] 65 | 66 | if model_path.endswith("/"): 67 | model_path = model_path[:-1] 68 | if model_name is None: 69 | model_paths = model_path.split("/") 70 | if model_paths[-1].startswith('checkpoint-'): 71 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 72 | else: 73 | self.model_name = model_paths[-1] 74 | else: 75 | self.model_name = model_name 76 | 77 | logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...") 78 | 79 | if not no_register: 80 | self.register_to_controller() 81 | self.heart_beat_thread = threading.Thread( 82 | target=heart_beat_worker, args=(self,), daemon=True) 83 | self.heart_beat_thread.start() 84 | 85 | def register_to_controller(self): 86 | logger.info("Register to controller") 87 | 88 | url = self.controller_addr + "/register_worker" 89 | data = { 90 | "worker_name": self.worker_addr, 91 | "check_heart_beat": True, 92 | "worker_status": self.get_status() 93 | } 94 | r = requests.post(url, json=data) 95 | assert r.status_code == 200 96 | 97 | def send_heart_beat(self): 98 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 99 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" 100 | f"global_counter: {global_counter}") 101 | 102 | url = self.controller_addr + "/receive_heart_beat" 103 | 104 | while True: 105 | try: 106 | ret = requests.post(url, json={ 107 | "worker_name": self.worker_addr, 108 | "queue_length": self.get_queue_length()}, timeout=5) 109 | exist = ret.json()["exist"] 110 | break 111 | except requests.exceptions.RequestException as e: 112 | logger.error(f"heart beat error: {e}") 113 | time.sleep(5) 114 | 115 | if not exist: 116 | self.register_to_controller() 117 | 118 | def get_queue_length(self): 119 | if model_semaphore is None: 120 | return 0 121 | else: 122 | return args.limit_model_concurrency - model_semaphore._value + (len( 123 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 124 | 125 | def get_status(self): 126 | return { 127 | "model_names": [self.model_name], 128 | "speed": 1, 129 | "queue_length": self.get_queue_length(), 130 | } 131 | 132 | async def generate_stream(self, params): 133 | ori_prompt = prompt = params["prompt"] 134 | images = params.get("images", None) 135 | if images is not None and len(images) > 0: 136 | if len(images) > 0: 137 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 138 | raise ValueError("Number of images does not match number of tokens in prompt") 139 | 140 | images = [load_image_from_base64(image) for image in images] 141 | 142 | # FIXME: for image-start/end token 143 | # replace_token = DEFAULT_IMAGE_TOKEN 144 | # if getattr(self.model.config, 'mm_use_im_start_end', False): 145 | # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 146 | # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 147 | prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN) 148 | prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN) 149 | prompt = [] 150 | for i in range(len(prompt_split)): 151 | prompt.append(prompt_split[i]) 152 | if i < len(images): 153 | prompt.append(images[i]) 154 | else: 155 | prompt = [prompt] 156 | 157 | temperature = float(params.get("temperature", 1.0)) 158 | top_p = float(params.get("top_p", 1.0)) 159 | # max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 160 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 161 | stop_str = params.get("stop", None) 162 | stop_str = [stop_str] if stop_str is not None else None 163 | 164 | print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}) 165 | state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True) 166 | 167 | generated_text = ori_prompt 168 | async for text_outputs in state.text_async_iter(var_name="response"): 169 | generated_text += text_outputs 170 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 171 | 172 | async def generate_stream_gate(self, params): 173 | try: 174 | async for x in self.generate_stream(params): 175 | yield x 176 | except ValueError as e: 177 | print("Caught ValueError:", e) 178 | ret = { 179 | "text": server_error_msg, 180 | "error_code": 1, 181 | } 182 | yield json.dumps(ret).encode() + b"\0" 183 | except Exception as e: 184 | print("Caught Unknown Error", e) 185 | ret = { 186 | "text": server_error_msg, 187 | "error_code": 1, 188 | } 189 | yield json.dumps(ret).encode() + b"\0" 190 | 191 | 192 | app = FastAPI() 193 | 194 | 195 | def release_model_semaphore(fn=None): 196 | model_semaphore.release() 197 | if fn is not None: 198 | fn() 199 | 200 | 201 | @app.post("/worker_generate_stream") 
202 | async def generate_stream(request: Request): 203 | global model_semaphore, global_counter 204 | global_counter += 1 205 | params = await request.json() 206 | 207 | if model_semaphore is None: 208 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 209 | await model_semaphore.acquire() 210 | worker.send_heart_beat() 211 | generator = worker.generate_stream_gate(params) 212 | background_tasks = BackgroundTasks() 213 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 214 | return StreamingResponse(generator, background=background_tasks) 215 | 216 | 217 | @app.post("/worker_get_status") 218 | async def get_status(request: Request): 219 | return worker.get_status() 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument("--host", type=str, default="localhost") 225 | parser.add_argument("--port", type=int, default=21002) 226 | parser.add_argument("--worker-address", type=str, 227 | default="http://localhost:21002") 228 | parser.add_argument("--controller-address", type=str, 229 | default="http://localhost:21001") 230 | parser.add_argument("--model-name", type=str) 231 | parser.add_argument("--sgl-endpoint", type=str) 232 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 233 | parser.add_argument("--stream-interval", type=int, default=1) 234 | parser.add_argument("--no-register", action="store_true") 235 | args = parser.parse_args() 236 | logger.info(f"args: {args}") 237 | 238 | worker = ModelWorker(args.controller_address, 239 | args.worker_address, 240 | args.sgl_endpoint, 241 | worker_id, 242 | args.no_register, 243 | args.model_name) 244 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 245 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | 
parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 59 | args = parser.parse_args() 60 | 61 | main() 62 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, 
cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! 
Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.utils.data import Sampler 6 | 7 | import transformers 8 | from transformers import Trainer 9 | from transformers.trainer import ( 10 | is_sagemaker_mp_enabled, 11 | get_parameter_names, 12 | has_length, 13 | # ALL_LAYERNORM_LAYERS, 14 | logger, 15 | ) 16 | from typing import List, Optional 17 | 18 | 19 | ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.BatchNorm2d] 20 | 21 | 22 | def maybe_zero_3(param, ignore_status=False, name=None): 23 | from deepspeed import zero 24 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 25 | if hasattr(param, "ds_id"): 26 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 27 | if not ignore_status: 28 | print(name, 'no ignore status') 29 | with zero.GatheredParameters([param]): 30 | param = param.data.detach().cpu().clone() 31 | else: 32 | param = param.detach().cpu().clone() 33 | return param 34 | 35 | 36 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 37 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 38 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 39 | return to_return 40 | 41 | 42 | def split_to_even_chunks(indices, lengths, num_chunks): 43 | """ 44 | Split a list 
of indices into `chunks` chunks of roughly equal lengths. 45 | """ 46 | 47 | if len(indices) % num_chunks != 0: 48 | return [indices[i::num_chunks] for i in range(num_chunks)] 49 | 50 | num_indices_per_chunk = len(indices) // num_chunks 51 | 52 | chunks = [[] for _ in range(num_chunks)] 53 | chunks_lengths = [0 for _ in range(num_chunks)] 54 | for index in indices: 55 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 56 | chunks[shortest_chunk].append(index) 57 | chunks_lengths[shortest_chunk] += lengths[index] 58 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 59 | chunks_lengths[shortest_chunk] = float("inf") 60 | 61 | return chunks 62 | 63 | 64 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 65 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 66 | assert all(l != 0 for l in lengths), "Should not have zero length." 67 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 68 | # all samples are in the same modality 69 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 70 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 71 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 72 | 73 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 74 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 75 | megabatch_size = world_size * batch_size 76 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 77 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 78 | 79 | last_mm = mm_megabatches[-1] 80 | last_lang = lang_megabatches[-1] 81 | additional_batch = last_mm + last_lang 82 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 83 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 84 | megabatches = [megabatches[i] for i in megabatch_indices] 85 | 86 | if len(additional_batch) > 0: 87 | megabatches.append(sorted(additional_batch)) 88 | 89 | return [i for megabatch in megabatches for i in megabatch] 90 | 91 | 92 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 93 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 94 | indices = torch.randperm(len(lengths), generator=generator) 95 | megabatch_size = world_size * batch_size 96 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 97 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 98 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 99 | 100 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 101 | 102 | 103 | class LengthGroupedSampler(Sampler): 104 | r""" 105 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 106 | keeping a bit of randomness. 
107 | """ 108 | 109 | def __init__( 110 | self, 111 | batch_size: int, 112 | world_size: int, 113 | lengths: Optional[List[int]] = None, 114 | generator=None, 115 | group_by_modality: bool = False, 116 | ): 117 | if lengths is None: 118 | raise ValueError("Lengths must be provided.") 119 | 120 | self.batch_size = batch_size 121 | self.world_size = world_size 122 | self.lengths = lengths 123 | self.generator = generator 124 | self.group_by_modality = group_by_modality 125 | 126 | def __len__(self): 127 | return len(self.lengths) 128 | 129 | def __iter__(self): 130 | if self.group_by_modality: 131 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 132 | else: 133 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 134 | return iter(indices) 135 | 136 | 137 | class LLaVATrainer(Trainer): 138 | 139 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 140 | if self.train_dataset is None or not has_length(self.train_dataset): 141 | return None 142 | 143 | if self.args.group_by_modality_length: 144 | lengths = self.train_dataset.modality_lengths 145 | return LengthGroupedSampler( 146 | self.args.train_batch_size, 147 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 148 | lengths=lengths, 149 | group_by_modality=True, 150 | ) 151 | else: 152 | return super()._get_train_sampler() 153 | 154 | def create_optimizer(self): 155 | """ 156 | Setup the optimizer. 157 | 158 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 159 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 160 | """ 161 | if is_sagemaker_mp_enabled(): 162 | return super().create_optimizer() 163 | 164 | opt_model = self.model 165 | 166 | if self.optimizer is None: 167 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 168 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 169 | 170 | lr_mapper = {} 171 | if self.args.mm_projector_lr is not None: 172 | lr_mapper["mm_projector"] = self.args.mm_projector_lr 173 | if self.args.mm_vision_tower_lr is not None: 174 | lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr 175 | 176 | if len(lr_mapper) > 0: 177 | special_lr_parameters = [name for name, _ in opt_model.named_parameters() if 178 | any(module_keyword in name for module_keyword in lr_mapper)] 179 | optimizer_grouped_parameters = [ 180 | { 181 | "params": [p for n, p in opt_model.named_parameters() if 182 | (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)], 183 | "weight_decay": self.args.weight_decay, 184 | }, 185 | { 186 | "params": [p for n, p in opt_model.named_parameters() if 187 | (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)], 188 | "weight_decay": 0.0, 189 | }, 190 | ] 191 | for module_keyword, lr in lr_mapper.items(): 192 | module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name] 193 | optimizer_grouped_parameters.extend( 194 | [ 195 | { 196 | "params": [p for n, p in opt_model.named_parameters() if 197 | (n in decay_parameters and n in module_parameters and p.requires_grad)], 198 | "weight_decay": self.args.weight_decay, 199 | "lr": lr, 200 | }, 201 | { 202 | "params": [p for n, p in opt_model.named_parameters() if 203 | (n not in decay_parameters and n in module_parameters and 
p.requires_grad)], 204 | "weight_decay": 0.0, 205 | "lr": lr, 206 | }, 207 | ] 208 | ) 209 | else: 210 | optimizer_grouped_parameters = [ 211 | { 212 | "params": [ 213 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 214 | ], 215 | "weight_decay": self.args.weight_decay, 216 | }, 217 | { 218 | "params": [ 219 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 220 | ], 221 | "weight_decay": 0.0, 222 | }, 223 | ] 224 | 225 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 226 | 227 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 228 | if optimizer_cls.__name__ == "Adam8bit": 229 | import bitsandbytes 230 | 231 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 232 | 233 | skipped = 0 234 | for module in opt_model.modules(): 235 | if isinstance(module, nn.Embedding): 236 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 237 | logger.info(f"skipped {module}: {skipped/2**20}M params") 238 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 239 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 240 | logger.info(f"skipped: {skipped/2**20}M params") 241 | 242 | return self.optimizer 243 | 244 | def _save_checkpoint(self, model, trial, metrics=None): 245 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 246 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 247 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 248 | 249 | run_dir = self._get_output_dir(trial=trial) 250 | output_dir = os.path.join(run_dir, checkpoint_folder) 251 | 252 | # Only save Adapter 253 | keys_to_match = ['mm_projector', 'vision_resampler'] 254 | if getattr(self.args, "use_im_start_end", False): 255 | keys_to_match.extend(['embed_tokens', 'embed_in']) 256 | 257 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 258 | 259 | if self.args.local_rank == 0 or self.args.local_rank == -1: 260 | self.model.config.save_pretrained(output_dir) 261 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 262 | else: 263 | # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144 264 | model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None) 265 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 266 | 267 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 268 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 269 | pass 270 | else: 271 | # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144 272 | self.model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None) 273 | super(LLaVATrainer, self)._save(output_dir, state_dict) 274 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train_qwen import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 
2 | 3 | # Need to call this before importing transformers. 4 | from llava.train.train import train 5 | from llava.train.llama_xformers_attn_monkey_patch import ( 6 | replace_llama_attn_with_xformers_attn, 7 | ) 8 | 9 | replace_llama_attn_with_xformers_attn() 10 | 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | 65 | def __init__(self, logger, log_level=logging.INFO): 66 | self.terminal = sys.stdout 67 | self.logger = logger 68 | self.log_level = log_level 69 | self.linebuf = '' 70 | 71 | def __getattr__(self, attr): 72 | return getattr(self.terminal, attr) 73 | 74 | def write(self, buf): 75 | temp_linebuf = self.linebuf + buf 76 | self.linebuf = '' 77 | for line in temp_linebuf.splitlines(True): 78 | # From the io.TextIOWrapper docs: 79 | # On output, if newline is None, any '\n' characters written 80 | # are translated to the system default line separator. 81 | # By default sys.stdout.write() expects '\n' newlines and then 82 | # translates them so this is still cross platform. 83 | if line[-1] == '\n': 84 | self.logger.log(self.log_level, line.rstrip()) 85 | else: 86 | self.linebuf += line 87 | 88 | def flush(self): 89 | if self.linebuf != '': 90 | self.logger.log(self.log_level, self.linebuf.rstrip()) 91 | self.linebuf = '' 92 | 93 | 94 | def disable_torch_init(): 95 | """ 96 | Disable the redundant torch default initialization to accelerate model creation. 
97 | """ 98 | import torch 99 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 100 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 101 | 102 | 103 | def violates_moderation(text): 104 | """ 105 | Check whether the text violates OpenAI moderation API. 106 | """ 107 | url = "https://api.openai.com/v1/moderations" 108 | headers = {"Content-Type": "application/json", 109 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 110 | text = text.replace("\n", "") 111 | data = "{" + '"input": ' + f'"{text}"' + "}" 112 | data = data.encode("utf-8") 113 | try: 114 | ret = requests.post(url, headers=headers, data=data, timeout=5) 115 | flagged = ret.json()["results"][0]["flagged"] 116 | except requests.exceptions.RequestException as e: 117 | flagged = False 118 | except KeyError as e: 119 | flagged = False 120 | 121 | return flagged 122 | 123 | 124 | def pretty_print_semaphore(semaphore): 125 | if semaphore is None: 126 | return "None" 127 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 128 | -------------------------------------------------------------------------------- /model_export/README.md: -------------------------------------------------------------------------------- 1 | # Model Export for inference on Apple Silicon 2 | Disclaimer: this is not an official recommendation, just research and exploration. 3 | 4 | ## Export Vision Encoder 5 | We found that LLaVA trainer does not save all the states needed for auto inference, 6 | predominantly used in third party libraries like `mlx-vlm`. We save additional metadata 7 | to model checkpoint directory and export the vision model using coremltools. 8 | Export vision encoder and patch the checkpoint using the instruction below. 9 | ```bash 10 | python export_vision_encoder.py --model-path /path/to/fastvlm-checkpoint 11 | ``` 12 | 13 | ## Export VLM 14 | 15 | ### Install mlx-vlm 16 | We provide a patch to `mlx-vlm` to support inference of FastVLM. 17 | ```bash 18 | git clone https://github.com/Blaizzy/mlx-vlm.git 19 | cd mlx-vlm 20 | git checkout 1884b551bc741f26b2d54d68fa89d4e934b9a3de 21 | git apply ../fastvlm_mlx-vlm.patch 22 | pip install -e . 23 | ``` 24 | 25 | Export model using the following instruction. 26 | ```bash 27 | python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \ 28 | --mlx-path /path/to/exported-fastvlm \ 29 | --only-llm 30 | ``` 31 | To quantize the LLM, additional options can be provided as shown below. 32 | `--q-bits` specifies bits per weight, the command below exports the LLM with 8-bit quantization. 33 | ```bash 34 | python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \ 35 | --mlx-path /path/to/exported-fastvlm \ 36 | --only-llm \ 37 | -q \ 38 | --q-bits 8 # For 4-bit quantization, specify 4 39 | ``` 40 | 41 | ### Generate 42 | The exported model can be used for inference in a python environment following the instruction below. 43 | ```bash 44 | python -m mlx_vlm.generate --model /path/to/exported-fastvlm \ 45 | --image /path/to/image.png \ 46 | --prompt "Describe the image." \ 47 | --max-tokens 256 \ 48 | --temp 0.0 49 | ``` 50 | 51 | ## Troubleshooting 52 | We noticed that sometimes `config.json` for the LLaVA model incorrectly sets the value for `tie_word_embeddings`. 53 | This causes the following error during conversion, `ValueError: Received parameters not in model: language_model.lm_head.weight.` 54 | If you encounter this error, set the value of `tie_word_embeddings` accordingly. 
55 | -------------------------------------------------------------------------------- /model_export/export_vision_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import json 7 | import copy 8 | import argparse 9 | 10 | import torch 11 | import numpy as np 12 | import coremltools 13 | 14 | from llava.model.builder import load_pretrained_model 15 | from llava.utils import disable_torch_init 16 | from llava.mm_utils import get_model_name_from_path 17 | 18 | 19 | def export(args): 20 | # Load model 21 | disable_torch_init() 22 | model_path = os.path.expanduser(args.model_path) 23 | model_name = get_model_name_from_path(model_path) 24 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, 25 | args.model_base, 26 | model_name, 27 | device="mps") 28 | 29 | # Save extra metadata that is not saved during LLaVA training 30 | # required by HF for auto-loading model and for mlx-vlm preprocessing 31 | 32 | # Save image processing config 33 | setattr(image_processor, "processor_class", "LlavaProcessor") 34 | output_path = os.path.join(model_path, "preprocessor_config.json") 35 | image_processor.to_json_file(output_path) 36 | 37 | # Create processor config 38 | processor_config = dict() 39 | processor_config["image_token"] = "<image>" 40 | processor_config["num_additional_image_tokens"] = 0 41 | processor_config["processor_class"] = "LlavaProcessor" 42 | processor_config["patch_size"] = 64 43 | output_path = os.path.join(model_path, "processor_config.json") 44 | json.dump(processor_config, open(output_path, "w"), indent=2) 45 | 46 | # Modify tokenizer to include the <image> special token.
47 | tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json") 48 | tokenizer_config = json.load(open(tokenizer_config_path, 'r')) 49 | token_ids = list() 50 | image_token_is_present = False 51 | for k, v in tokenizer_config['added_tokens_decoder'].items(): 52 | token_ids.append(int(k)) 53 | if v["content"] == "<image>": 54 | image_token_is_present = True 55 | token_ids.pop() 56 | 57 | # Append only if token is not present 58 | if not image_token_is_present: 59 | tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy( 60 | tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}']) 61 | tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>" 62 | json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2) 63 | 64 | # Modify config to contain token id for <image> 65 | config_path = os.path.join(model_path, "config.json") 66 | model_config = json.load(open(config_path, 'r')) 67 | model_config["image_token_index"] = max(token_ids) + 1 68 | json.dump(model_config, open(config_path, 'w'), indent=2) 69 | 70 | # Export the vision encoder to CoreML 71 | image_res = image_processor.to_dict()['size']['shortest_edge'] 72 | inputs = torch.rand(1, 3, image_res, image_res) 73 | inputs_tensor = [ 74 | coremltools.TensorType( 75 | name="images", 76 | shape=inputs.shape, 77 | ) 78 | ] 79 | vision_model = model.get_vision_tower() 80 | vision_model = vision_model.float() 81 | traced_model = torch.jit.trace(vision_model, torch.Tensor(inputs)) 82 | pt_name = "fastvithd.pt" 83 | traced_model.save(pt_name) 84 | 85 | # Export 86 | ml_model = coremltools.convert( 87 | model=pt_name, 88 | outputs=[coremltools.TensorType(name="image_features", dtype=np.float32)], 89 | inputs=inputs_tensor, 90 | convert_to="mlprogram", 91 | debug=False, 92 | compute_units=coremltools.ComputeUnit.CPU_AND_GPU, 93 | minimum_deployment_target=coremltools.target.iOS16, 94 | compute_precision=coremltools.precision.FLOAT32 95 | ) 96 | ml_model_path = os.path.join(model_path, "fastvithd.mlpackage") 97 | ml_model.save(ml_model_path) 98 | 99 | # Remove traced model 100 | os.remove(pt_name) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--model-path", type=str, required=True) 106 | parser.add_argument("--model-base", type=str, default=None) 107 | parser.add_argument("--conv-mode", type=str, default="qwen_2") 108 | 109 | args = parser.parse_args() 110 | 111 | export(args) 112 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # 2 | # Modified from LLaVA/predict.py 3 | # Please see ACKNOWLEDGEMENTS for details about LICENSE 4 | # 5 | import os 6 | import argparse 7 | 8 | import torch 9 | from PIL import Image 10 | 11 | from llava.utils import disable_torch_init 12 | from llava.conversation import conv_templates 13 | from llava.model.builder import load_pretrained_model 14 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 15 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 16 | 17 | 18 | def predict(args): 19 | # Remove generation config from model folder 20 | # to read generation parameters from args 21 | model_path = os.path.expanduser(args.model_path) 22 | generation_config = None 23 | if os.path.exists(os.path.join(model_path, 'generation_config.json')): 24 | generation_config =
os.path.join(model_path, '.generation_config.json') 25 | os.rename(os.path.join(model_path, 'generation_config.json'), 26 | generation_config) 27 | 28 | # Load model 29 | disable_torch_init() 30 | model_name = get_model_name_from_path(model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device="mps") 32 | 33 | # Construct prompt 34 | qs = args.prompt 35 | if model.config.mm_use_im_start_end: 36 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 37 | else: 38 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 39 | conv = conv_templates[args.conv_mode].copy() 40 | conv.append_message(conv.roles[0], qs) 41 | conv.append_message(conv.roles[1], None) 42 | prompt = conv.get_prompt() 43 | 44 | # Set the pad token id for generation 45 | model.generation_config.pad_token_id = tokenizer.pad_token_id 46 | 47 | # Tokenize prompt 48 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch.device("mps")) 49 | 50 | # Load and preprocess image 51 | image = Image.open(args.image_file).convert('RGB') 52 | image_tensor = process_images([image], image_processor, model.config)[0] 53 | 54 | # Run inference 55 | with torch.inference_mode(): 56 | output_ids = model.generate( 57 | input_ids, 58 | images=image_tensor.unsqueeze(0).half(), 59 | image_sizes=[image.size], 60 | do_sample=True if args.temperature > 0 else False, 61 | temperature=args.temperature, 62 | top_p=args.top_p, 63 | num_beams=args.num_beams, 64 | max_new_tokens=256, 65 | use_cache=True) 66 | 67 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 68 | print(outputs) 69 | 70 | # Restore generation config 71 | if generation_config is not None: 72 | os.rename(generation_config, os.path.join(model_path, 'generation_config.json')) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--model-path", type=str, default="./llava-v1.5-13b") 78 | parser.add_argument("--model-base", type=str, default=None) 79 | parser.add_argument("--image-file", type=str, default=None, help="location of image file") 80 | parser.add_argument("--prompt", type=str, default="Describe the image.", help="Prompt for VLM.") 81 | parser.add_argument("--conv-mode", type=str, default="qwen_2") 82 | parser.add_argument("--temperature", type=float, default=0.2) 83 | parser.add_argument("--top_p", type=float, default=None) 84 | parser.add_argument("--num_beams", type=int, default=1) 85 | args = parser.parse_args() 86 | 87 | predict(args) 88 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.2.2.post1" 8 | description = "Towards GPT-4 like large language and visual assistant." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.6.0", "torchvision==0.21.0", 17 | "transformers==4.48.3", "tokenizers==0.21.0", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==1.6.0", "peft>=0.10.0,<0.14.0", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", 20 | "gradio==5.11.0", "requests", "uvicorn", "fastapi", 21 | "einops==0.6.1", "einops-exts==0.0.4", "timm==1.0.15", 22 | "coremltools==8.2" 23 | ] 24 | 25 | [project.optional-dependencies] 26 | train = ["deepspeed==0.13.1", "ninja", "wandb"] 27 | build = ["build", "twine"] 28 | 29 | [tool.setuptools.packages.find] 30 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 31 | 32 | [tool.wheel] 33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 34 | --------------------------------------------------------------------------------