├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── airplanes.png ├── company_name.png ├── crowd.png ├── crowd_output_1.png ├── crowd_output_2.png ├── dog.png ├── donuts.png ├── donuts_output.png ├── overview.png ├── pipeline.png ├── stand_higher.png └── stand_higher_output.png ├── evaluation ├── calculate_coco_ap.py ├── calculate_counting.py ├── calculate_iou.py ├── calculate_iou_with_bbox.py ├── coco_gt │ └── instances_val2017.json ├── eval_coco.sh ├── eval_count.sh ├── eval_segmentation.sh ├── evaluation_coco.py ├── evaluation_count.py └── evaluation_segmentation.py ├── requirements.txt ├── task_categorization.md └── vision_reasoner ├── __init__.py ├── inference.py ├── models ├── base_model.py ├── qwen_vl.py ├── task_router.py └── vision_reasoner_model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # outputs 174 | outputs/ 175 | checkpoints/ 176 | wandb/ 177 | workdir 178 | reasonseg_eval_results 179 | prepare_dataset 180 | result_visualization.png 181 | detection_eval_results/ 182 | pretrained_models/ 183 | test_scripts.sh 184 | test_yoloe.py 185 | mobileclip_blt.ts 186 | yolov8l-world.pt 187 | yolov8x-worldv2.pt 188 | result_visualization_fail.png 189 | crowd_people.png 190 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-added-large-files 7 | args: ['--maxkb=25000'] 8 | - id: check-merge-conflict 9 | - id: check-yaml 10 | - id: debug-statements 11 | - id: end-of-file-fixer 12 | - id: requirements-txt-fixer 13 | - id: trailing-whitespace 14 | args: [--markdown-linebreak-ext=md] 15 | - id: no-commit-to-branch 16 | args: ['--branch', 'main'] 17 | 18 | - repo: https://github.com/asottile/pyupgrade 19 | rev: v3.17.0 20 | hooks: 21 | - id: pyupgrade 22 | args: [--py38-plus] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisionReasoner: Unified Visual Perception and Reasoning via Reinforcement Learning 2 | 3 | > Current VLMs are primarily used for visual captioning or visual QA tasks. In this project, we take a step further by demonstrating the potential of a single VLM to solve diverse vision tasks. We hope this work will advance the frontier of VLM research and expand the boundaries of what these models can achieve. 4 | 5 | Paper: [📖 VisionReasoner](https://arxiv.org/pdf/2505.12081) [📖 Seg-Zero](https://arxiv.org/pdf/2503.06520) 6 | HuggingFace Daily: [🤗 VisionReasoner](https://huggingface.co/papers/2505.12081) 7 | Model: [🤗 VisionReasoner-7B](https://huggingface.co/Ricky06662/VisionReasoner-7B) [🤗 TaskRouter-1.5B](https://huggingface.co/Ricky06662/TaskRouter-1.5B) 8 | Relative Link: [Seg-Zero![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) 9 | 10 | Overview of VisionReasoner: 11 | 12 |
13 | 14 |
15 | 16 | VisionReasoner demonstrates the following features: 17 | 1. **VisionReasoner** is a unified framework for visual perception tasks. Through carefully crafted rewards and training strategies, VisionReasoner has strong multi-task capability, addressing diverse visual perception tasks within a shared model. 18 | 2. We select several representative tasks to evaluate a model's unified visual ability, including detection tasks (e.g., [COCO](https://cocodataset.org/#home), [RefCOCOg](https://github.com/lichengunc/refer)), segmentation tasks (e.g., [ReasonSeg](https://github.com/dvlab-research/LISA)), counting tasks (e.g., [CountBench](https://teaching-clip-to-count.github.io/)) and VQA tasks (e.g., [DocVQA](https://www.docvqa.org/)). 19 | 3. Experimental results show that VisionReasoner achieves superior performance across ten diverse visual perception tasks within a single unified framework, outperforming baseline models by a significant margin. 20 | 4. We have reformulated dozens of visual task types categorized in [Papers With Code](https://paperswithcode.com/datasets?mod=images&page=1). Please refer to [task categorization](task_categorization.md) for details. These task types are grouped into four fundamental task types: detection, segmentation, counting and VQA. More supported task types and more fundamental task types, such as 3D or medical image processing, can be added to this framework. 21 | 22 | 23 | ## News 24 | 25 | [May 17th, 2025] 🔥 [📖 Paper](https://arxiv.org/pdf/2505.12081) is coming! 26 | [May 17th, 2025] 🔥 VisionReasoner is coming! VisionReasoner is based on our previous [Seg-Zero](https://github.com/dvlab-research/Seg-Zero). 27 | 28 | 29 | ## Contents 30 | - [Model](#model) 31 | - [Installation](#installation) 32 | - [Inference](#inference) 33 | - [Hybrid Mode](#hybrid-mode) 34 | - [Image Generation](#image-generation) 35 | - [Evaluation](#evaluation) 36 | - [Training](#training) 37 | - [Citation](#citation) 38 | - [Acknowledgement](#acknowledgement) 39 | 40 | 41 | 42 | ## Model 43 |
44 | 45 |
46 | 47 | The VisionReasoner model incorporates a reasoning module, which processes the image and locates target objects, and a segmentation module that produces segmentation masks when needed. 48 | In addition, we train a task router that converts diverse vision tasks into the four fundamental task types. 49 | 50 | 51 | 56 | 57 | 58 | ## Installation 59 | > [!NOTE] 60 | > If you train VisionReasoner using the code in [Seg-Zero](https://github.com/dvlab-research/Seg-Zero), you can directly reuse the environment of the training code. 61 | 62 | ```bash 63 | git clone https://github.com/dvlab-research/VisionReasoner.git 64 | cd VisionReasoner 65 | conda create -n visionreasoner_test python=3.12 66 | conda activate visionreasoner_test 67 | pip3 install torch torchvision 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | 72 | ## Inference 73 | Download the models using the following commands: 74 | ```bash 75 | mkdir pretrained_models 76 | cd pretrained_models 77 | git lfs install 78 | git clone https://huggingface.co/Ricky06662/VisionReasoner-7B 79 | git clone https://huggingface.co/Ricky06662/TaskRouter-1.5B 80 | ``` 81 | > [!TIP] 82 | > If you encounter issues connecting to Hugging Face, consider using `export HF_ENDPOINT=https://hf-mirror.com`. 83 | 84 | 85 | Then run inference using: 86 | ```bash 87 | python vision_reasoner/inference.py 88 | ``` 89 | ### The default task is a counting task. 90 | > "How many airplanes are there in this image?" 91 | 92 |
93 | 94 |
95 | 96 | 97 | You will get the thinking process in command line, like: 98 | 99 | > "The image shows a formation of airplanes flying in the sky. Each airplane is distinct and can be counted individually. The planes are arranged in a specific pattern, and there are visible trails of smoke behind them, which is typical for airshows or demonstrations." 100 | 101 | And you will get the final answer in command line, like: 102 | 103 | > "Total number of interested objects is: 10" 104 | 105 | 106 | ### You can also try a detection / segmentation task by: 107 | ```bash 108 | python vision_reasoner/inference.py --image_path "assets/donuts.png" --query "please segment the donuts" 109 | ``` 110 | 111 | You will get the thinking process in command line, like: 112 | 113 | > "The task involves identifying and segmenting individual donuts in the image. Each donut is distinct in its color, glaze, and toppings, which helps in distinguishing them from one another. The goal is to identify each donut as a separate object and provide bounding boxes for them." 114 | 115 | And the result will be presented in result_visualization.png. 116 | 117 |
118 | 119 |
120 | 121 | ### Or some tasks that need reasoning: 122 | 123 | ```bash 124 | python vision_reasoner/inference.py --image_path "assets/stand_higher.png" --query "find what can make the woman stand higher?" 125 | ``` 126 | 127 | You will get the thinking process in command line, like: 128 | 129 | > "The question asks for objects that can make the woman stand higher. The woman is already standing on a ladder, which is the object that elevates her. The ladder is the most closely matched object to what can make her stand higher." 130 | 131 | And the result will be presented in result_visualization.png. 132 | 133 |
134 | 135 |
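### You can also call the model directly from Python:

If you want to use VisionReasoner inside your own scripts instead of `vision_reasoner/inference.py`, the evaluation code shows the underlying API. Below is a minimal sketch based on how `evaluation/evaluation_segmentation.py` constructs and queries the model; the checkpoint paths assume you downloaded the models into `pretrained_models/` as shown above, and the exact contents of each result dict may differ slightly from the comments.

```python
from PIL import Image

from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel

# Load the reasoning model, the task router and the SAM2 segmentation model.
model = VisionReasonerModel(
    reasoning_model_path="pretrained_models/VisionReasoner-7B",
    task_router_model_path="pretrained_models/TaskRouter-1.5B",
    segmentation_model_path="facebook/sam2-hiera-large",
)

image = Image.open("assets/donuts.png").convert("RGB")
results = model.segment_objects_batch([image], ["please segment the donuts"])

for result in results:
    print(result["thinking"])  # reasoning text produced by the model
    print(result["bboxes"])    # predicted boxes, treated as [x1, y1, x2, y2] in the evaluation code
    # result["masks"] holds the predicted segmentation mask(s),
    # which the evaluation code compares against the ground truth with numpy.
```

`detect_objects_batch` and `count_objects_batch` follow the same pattern; see `evaluation/evaluation_coco.py` and `evaluation/evaluation_count.py` for how their outputs are consumed.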
136 | 137 | 138 | ### We also support naive visual QA / captioning task: 139 | ```bash 140 | python vision_reasoner/inference.py --image_path "assets/company_name.png" --query "What is name of the company?" 141 | ``` 142 | 143 |
144 | 145 |
146 | 147 | In VQA, there is no reasoning step, and you will get the final answer in command line, like: 148 | 149 | > "The answer is: The name of the company is ITC (Indian Tobacco Company Limited)." 150 | 151 | ### You can also provide your own image_path and text by: 152 | ```bash 153 | python vision_reasoner/inference.py --image_path "your_image_path" --query "your question text" 154 | ``` 155 | 156 | ## Hybrid Mode: 157 | When hybrid reasoning mode is enabled, VisionReasoner intelligently switches between direct detection (using YOLO-World) and reasoning-based approaches based on the complexity of the query. This allows for faster responses on simple queries while maintaining detailed reasoning for complex tasks. 158 | 159 | ### Simple Query Example: 160 | For straightforward queries that can be directly answered by object detection: 161 | 162 | ```bash 163 | python vision_reasoner/inference.py --image_path "assets/crowd.png" --query "person" --hybrid_mode 164 | ``` 165 | 166 | Output: 167 | 168 |
169 | 170 |
171 | 172 | In this case, the model directly uses YOLO-World for detection without going through the reasoning process, resulting in faster response times. 173 | 174 | ### Complex Query Example: 175 | For queries that require spatial reasoning or complex understanding: 176 | 177 | ```bash 178 | python vision_reasoner/inference.py --image_path "assets/crowd.png" --query "the person who is facing to the camera" --hybrid_mode 179 | ``` 180 | 181 | Output: 182 | > Thinking process: The task involves identifying the person who is facing the camera and then finding the most closely matched object. In the image, there is a person in the center wearing a white shirt and a black vest, who appears to be facing the camera directly. The other individuals are walking away from the camera, so they are not the target. The person in the white shirt and black vest is the closest match to the description of facing the camera. 183 | 184 | 185 |
186 | 187 |
188 | 189 | In this case, the model switches to the reasoning-based approach because the query requires understanding spatial relationships and visual attributes. 190 | 191 | 192 | ## Image Generation: 193 | Our framework can also incorporate generation tasks. We adopt [gpt-image-1](https://platform.openai.com/docs/guides/image-generation?image-generation-model=gpt-image-1) for generation in current version. 194 | 195 | > [!NOTE] 196 | > Bugs might arise from API version mismatches. Please debug and customize based on your API key and version. 197 | 198 | ### Text-to-image generation 199 | For text to image generation, you can only input a prompt 200 | ```bash 201 | python vision_reasoner/inference.py --image_prompt "Draw a image of a cute dog." --generation_model_name [your openAI api key] --generation_mode 202 | ``` 203 | 204 | ### Image reference generation 205 | For image reference generation, you should input a prompt and reference image 206 | ```bash 207 | python vision_reasoner/inference.py --refer_image_path "assets/dog.png" --image_prompt "Generate a cute dog in a forest" --generation_model_name [your openAI api key] --generation_mode 208 | ``` 209 | 210 | ## Evaluation 211 | 212 | The evaluation scripts allow you to test VisionReasoner on various datasets. We provide scripts for evaluating segmentation, detection, and counting tasks. 213 | 214 | ### Using the Evaluation Scripts 215 | 216 | Each evaluation script accepts either a HuggingFace dataset path or a local dataset path: 217 | 218 | ```bash 219 | # Using HuggingFace dataset paths (default in examples) 220 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_val 221 | 222 | # Using local dataset paths 223 | bash evaluation/eval_segmentation.sh /path/to/your/local/refcoco_val 224 | ``` 225 | 226 | Additionally, you can customize model paths with the following parameters: 227 | 228 | ```bash 229 | # Using local model paths (instead of downloading from HuggingFace) 230 | bash evaluation/eval_segmentation.sh [dataset_path] \ 231 | --model_path /path/to/local/VisionReasoner-7B \ 232 | --task_router_model_path /path/to/local/TaskRouter-1.5B 233 | ``` 234 | 235 | ### Available Evaluation Scripts 236 | 237 | - `eval_segmentation.sh`: Evaluates segmentation performance on RefCOCO, RefCOCO+, RefCOCOg, and ReasonSeg datasets. When the dataset contains bounding box ground truth annotations, it will also output detection metrics. 
238 | - `eval_coco.sh`: Evaluates detection performance on the COCO dataset 239 | - `eval_count.sh`: Evaluates counting performance on counting benchmarks 240 | 241 | ### Example Commands 242 | 243 | ```bash 244 | # Segmentation/Detection evaluation 245 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_val 246 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_testA 247 | bash evaluation/eval_segmentation.sh Ricky06662/refcocoplus_val 248 | bash evaluation/eval_segmentation.sh Ricky06662/refcocoplus_testA 249 | bash evaluation/eval_segmentation.sh Ricky06662/refcocog_val 250 | bash evaluation/eval_segmentation.sh Ricky06662/refcocog_test 251 | bash evaluation/eval_segmentation.sh Ricky06662/ReasonSeg_val 252 | bash evaluation/eval_segmentation.sh Ricky06662/ReasonSeg_test 253 | 254 | # COCO evaluation 255 | bash evaluation/eval_coco.sh Ricky06662/coco_val 256 | 257 | # Counting evaluation 258 | bash evaluation/eval_count.sh Ricky06662/counting_pixmo_validation 259 | bash evaluation/eval_count.sh Ricky06662/counting_pixmo_test 260 | bash evaluation/eval_count.sh Ricky06662/counting_countbench 261 | ``` 262 | 263 | ## Training 264 | 265 | We refer you to [Seg-Zero](https://github.com/dvlab-research/Seg-Zero) for training VisionReasoner. 266 | 267 | 268 | ## Citation 269 | 270 | ```bibtex 271 | @article{liu2025segzero, 272 | title = {Seg-Zero: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement}, 273 | author = {Liu, Yuqi and Peng, Bohao and Zhong, Zhisheng and Yue, Zihao and Lu, Fanbin and Yu, Bei and Jia, Jiaya}, 274 | journal = {arXiv preprint arXiv:2503.06520}, 275 | year = {2025} 276 | } 277 | 278 | @article{liu2025visionreasoner, 279 | title = {VisionReasoner: Unified Visual Perception and Reasoning via Reinforcement Learning}, 280 | author = {Liu, Yuqi and Qu, Tianyuan and Zhong, Zhisheng and Peng, Bohao and Liu, Shu and Yu, Bei and Jia, Jiaya}, 281 | journal = {arXiv preprint arXiv:2505.12081}, 282 | year = {2025} 283 | } 284 | ``` 285 | 286 | ## Acknowledgement 287 | We would like to thank the following repos for their great work: 288 | 289 | - This work is built upon [Seg-Zero](https://github.com/dvlab-research/Seg-Zero), [EasyR1](https://github.com/hiyouga/EasyR1) and [veRL](https://github.com/volcengine/verl). 290 | - This work utilizes models from [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct), [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), [SAM2](https://huggingface.co/facebook/sam2-hiera-large) and [YOLO-World](https://github.com/AILab-CVC/YOLO-World). 
291 | 292 | 293 | ## Star History 294 | 295 | [![Star History Chart](https://api.star-history.com/svg?repos=dvlab-research/VisionReasoner&type=Date)](https://star-history.com/#dvlab-research/VisionReasoner&Date) -------------------------------------------------------------------------------- /assets/airplanes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/airplanes.png -------------------------------------------------------------------------------- /assets/company_name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/company_name.png -------------------------------------------------------------------------------- /assets/crowd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd.png -------------------------------------------------------------------------------- /assets/crowd_output_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd_output_1.png -------------------------------------------------------------------------------- /assets/crowd_output_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd_output_2.png -------------------------------------------------------------------------------- /assets/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/dog.png -------------------------------------------------------------------------------- /assets/donuts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/donuts.png -------------------------------------------------------------------------------- /assets/donuts_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/donuts_output.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/overview.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/pipeline.png -------------------------------------------------------------------------------- /assets/stand_higher.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/stand_higher.png -------------------------------------------------------------------------------- /assets/stand_higher_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/stand_higher_output.png -------------------------------------------------------------------------------- /evaluation/calculate_coco_ap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | from pycocotools.coco import COCO 7 | from pycocotools.cocoeval import COCOeval 8 | 9 | 10 | def parse_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files") 13 | parser.add_argument("--gt_json_path", type=str, required=True, help="path to COCO ground truth json file") 14 | return parser.parse_args() 15 | 16 | def calculate_metrics(output_dir, gt_json_path): 17 | # get all output files 18 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json"))) 19 | 20 | if not output_files: 21 | print(f"cannot find output files in {output_dir}") 22 | return 23 | 24 | # for accumulating all data 25 | pred_results = [] 26 | 27 | pred_results_constant_score = [] 28 | 29 | pred_results_constant_exist_score = [] 30 | 31 | # for calculating think text length 32 | think_text_lengths = [] 33 | 34 | # read and process all files 35 | for file_path in output_files: 36 | with open(file_path, 'r', encoding='utf-8') as f: 37 | results = json.load(f) 38 | 39 | # process all items in each file 40 | for item in results: 41 | # Calculate think text length if available 42 | if 'think' in item and item['think']: 43 | think_text_lengths.append(len(item['think'])) 44 | 45 | # Original bbox processing code 46 | bbox = item['bbox'] 47 | bbox = [bbox[0], bbox[1], bbox[2]-bbox[0], bbox[3]-bbox[1]] 48 | 49 | pred_results.append({ 50 | "image_id": item['image_id'], 51 | "category_id": item['category_id'], 52 | "bbox": bbox, 53 | "score": item['score'] 54 | }) 55 | 56 | pred_results_constant_score.append({ 57 | "image_id": item['image_id'], 58 | "category_id": item['category_id'], 59 | "bbox": bbox, 60 | "score": 1.0 61 | }) 62 | 63 | pred_results_constant_exist_score.append({ 64 | "image_id": item['image_id'], 65 | "category_id": item['category_id'], 66 | "bbox": bbox, 67 | "score": 1.0 if item['score'] <= 0.2 else 0.0 68 | }) 69 | 70 | # Calculate think text metrics 71 | if think_text_lengths: 72 | avg_think_length = sum(think_text_lengths) / len(think_text_lengths) 73 | min_think_length = min(think_text_lengths) 74 | max_think_length = max(think_text_lengths) 75 | print(f"\n-----------------Think Text Statistics----------------------------------") 76 | print(f"Number of think texts: {len(think_text_lengths)}") 77 | print(f"Average think text length: {avg_think_length:.2f} characters") 78 | print(f"Minimum think text length: {min_think_length} characters") 79 | print(f"Maximum think text length: {max_think_length} characters") 80 | print(f"------------------------------------------------------------------\n") 81 | 82 | coco_gt = COCO(gt_json_path) # load ground truth 83 | coco_dt = coco_gt.loadRes(pred_results) # load prediction results 84 | 85 | # initialize evaluation 
object (task type: bbox/keypoints/segmentation) 86 | coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') # select task type 87 | 88 | # run evaluation 89 | coco_eval.evaluate() # calculate matches 90 | coco_eval.accumulate() # accumulate metrics 91 | coco_eval.summarize() # output results 92 | 93 | 94 | if __name__ == "__main__": 95 | args = parse_args() 96 | calculate_metrics(args.output_dir, args.gt_json_path) 97 | -------------------------------------------------------------------------------- /evaluation/calculate_counting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files") 10 | return parser.parse_args() 11 | 12 | def calculate_metrics(output_dir): 13 | # get all output files 14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json"))) 15 | 16 | if not output_files: 17 | print(f"cannot find output files in {output_dir}") 18 | return 19 | 20 | # for accumulating all data 21 | all_counts = [] 22 | 23 | # read and process all files 24 | for file_path in output_files: 25 | with open(file_path, 'r', encoding='utf-8') as f: 26 | results = json.load(f) 27 | 28 | # process all items in each file 29 | for item in results: 30 | pred_count = item['pred_count'] 31 | gt_count = item['gt_count'] 32 | 33 | all_counts.append({ 34 | 'image_id': item['image_id'], 35 | 'mae': abs(pred_count - gt_count), 36 | 'rmse': (pred_count - gt_count) ** 2, 37 | 'correct_count': pred_count == gt_count 38 | }) 39 | 40 | # calculate mae and rmse 41 | mae = np.mean([item['mae'] for item in all_counts]) 42 | rmse = np.sqrt(np.mean([item['rmse'] for item in all_counts])) 43 | correct_count = np.mean([item['correct_count'] for item in all_counts]) 44 | # print the results 45 | print(f"test len: {len(all_counts)}") 46 | print(f"mae: {mae:.4f}") 47 | print(f"rmse: {rmse:.4f}") 48 | print(f"correct_count: {correct_count:.4f}") 49 | 50 | 51 | if __name__ == "__main__": 52 | args = parse_args() 53 | calculate_metrics(args.output_dir) 54 | -------------------------------------------------------------------------------- /evaluation/calculate_iou.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files") 10 | return parser.parse_args() 11 | 12 | def calculate_metrics(output_dir): 13 | # get all output files 14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json"))) 15 | 16 | if not output_files: 17 | print(f"cannot find output files in {output_dir}") 18 | return 19 | 20 | # for accumulating all data 21 | total_intersection = 0 22 | total_union = 0 23 | all_ious = [] 24 | 25 | # read and process all files 26 | for file_path in output_files: 27 | with open(file_path, 'r', encoding='utf-8') as f: 28 | results = json.load(f) 29 | 30 | # process all items in each file 31 | for item in results: 32 | intersection = item['intersection'] 33 | union = item['union'] 34 | 35 | # calculate IoU of each item 36 | iou = intersection / union if union > 0 else 0 37 | all_ious.append({ 38 | 'image_id': item['image_id'], 39 | 'iou': 
iou 40 | }) 41 | 42 | # accumulate total intersection and union 43 | total_intersection += intersection 44 | total_union += union 45 | 46 | # calculate gIoU 47 | gIoU = np.mean([item['iou'] for item in all_ious]) 48 | # calculate cIoU 49 | cIoU = total_intersection / total_union if total_union > 0 else 0 50 | 51 | # print the results 52 | print(f"gIoU (average of per image IoU): {gIoU:.4f}") 53 | print(f"cIoU (total_intersection / total_union): {cIoU:.4f}") 54 | 55 | 56 | if __name__ == "__main__": 57 | args = parse_args() 58 | calculate_metrics(args.output_dir) 59 | -------------------------------------------------------------------------------- /evaluation/calculate_iou_with_bbox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files") 10 | return parser.parse_args() 11 | 12 | def calculate_metrics(output_dir): 13 | # get all output files 14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json"))) 15 | 16 | if not output_files: 17 | print(f"cannot find output files in {output_dir}") 18 | return 19 | 20 | # for accumulating all data 21 | total_intersection = 0 22 | total_union = 0 23 | total_bbox_iou = 0 24 | all_ious = [] 25 | cnt = 0 26 | 27 | # for calculating think text length 28 | think_text_lengths = [] 29 | 30 | # read and process all files 31 | for file_path in output_files: 32 | with open(file_path, 'r', encoding='utf-8') as f: 33 | results = json.load(f) 34 | 35 | # process all items in each file 36 | for item in results: 37 | # Calculate think text length if available 38 | if 'think' in item and item['think']: 39 | think_text_lengths.append(len(item['think'])) 40 | 41 | intersection = item['intersection'] 42 | union = item['union'] 43 | 44 | # calculate IoU of each item 45 | iou = intersection / union if union > 0 else 0 46 | all_ious.append({ 47 | 'image_id': item['image_id'], 48 | 'iou': iou 49 | }) 50 | 51 | # accumulate total intersection and union 52 | total_intersection += intersection 53 | total_union += union 54 | total_bbox_iou += item['bbox_iou'] 55 | cnt += 1 56 | 57 | # Calculate think text metrics 58 | if think_text_lengths: 59 | avg_think_length = sum(think_text_lengths) / len(think_text_lengths) 60 | min_think_length = min(think_text_lengths) 61 | max_think_length = max(think_text_lengths) 62 | print(f"\n-----------------Think Text Statistics----------------------------------") 63 | print(f"Number of think texts: {len(think_text_lengths)}") 64 | print(f"Average think text length: {avg_think_length:.2f} characters") 65 | print(f"Minimum think text length: {min_think_length} characters") 66 | print(f"Maximum think text length: {max_think_length} characters") 67 | print(f"------------------------------------------------------------------\n") 68 | 69 | # calculate gIoU 70 | gIoU = np.mean([item['iou'] for item in all_ious]) 71 | # calculate cIoU 72 | cIoU = total_intersection / total_union if total_union > 0 else 0 73 | # calculate bbox_iou 74 | bbox_iou = total_bbox_iou / cnt if cnt > 0 else 0 75 | 76 | # print the results 77 | print(f"gIoU (average of per image IoU): {gIoU:.4f}") 78 | print(f"cIoU (total_intersection / total_union): {cIoU:.4f}") 79 | print(f"bbox_AP (average of per image bbox_AP): {bbox_iou:.4f}") 80 | 81 | if __name__ == "__main__": 82 | 
args = parse_args() 83 | calculate_metrics(args.output_dir) 84 | -------------------------------------------------------------------------------- /evaluation/eval_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner 5 | TEST_DATA_PATH=${1:-"Ricky06662/coco_val"} 6 | 7 | # Extract model name and test dataset name for output directory 8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 's/.*\/([^\/]+)$/\1/') 9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}" 10 | 11 | # Customize GPU array here - specify which GPUs to use 12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Example: using GPUs 0, 1, 2, 3 13 | NUM_PARTS=${#GPU_ARRAY[@]} 14 | 15 | # Create output directory 16 | mkdir -p $OUTPUT_PATH 17 | 18 | # Run processes in parallel 19 | for i in $(seq 0 $((NUM_PARTS-1))); do 20 | gpu_id=${GPU_ARRAY[$i]} 21 | process_idx=$i # 0-based indexing for process 22 | 23 | export CUDA_VISIBLE_DEVICES=$gpu_id 24 | ( 25 | python evaluation/evaluation_coco.py \ 26 | --model $MODEL_TYPE \ 27 | --output_path $OUTPUT_PATH \ 28 | --test_data_path $TEST_DATA_PATH \ 29 | --idx $process_idx \ 30 | --num_parts $NUM_PARTS \ 31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; } 32 | ) & 33 | done 34 | 35 | # Wait for all processes to complete 36 | wait 37 | 38 | COCO_GT_JSON_PATH="evaluation/coco_gt/instances_val2017.json" 39 | 40 | # Calculate COCO AP metrics 41 | python evaluation/calculate_coco_ap.py --output_dir $OUTPUT_PATH --gt_json_path $COCO_GT_JSON_PATH 42 | -------------------------------------------------------------------------------- /evaluation/eval_count.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner or qwen2 5 | TEST_DATA_PATH=${1:-"Ricky06662/counting_pixmo_test"} 6 | 7 | # Extract model name and test dataset name for output directory 8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 's/.*\/([^\/]+)$/\1/') 9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}" 10 | 11 | # Customize GPU array here - specify which GPUs to use 12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Example: using GPUs 0, 1, 2, 3 13 | NUM_PARTS=${#GPU_ARRAY[@]} 14 | 15 | # Create output directory 16 | mkdir -p $OUTPUT_PATH 17 | 18 | # Run processes in parallel 19 | for i in $(seq 0 $((NUM_PARTS-1))); do 20 | gpu_id=${GPU_ARRAY[$i]} 21 | process_idx=$i # 0-based indexing for process 22 | 23 | export CUDA_VISIBLE_DEVICES=$gpu_id 24 | ( 25 | python evaluation/evaluation_count.py \ 26 | --model $MODEL_TYPE \ 27 | --output_path $OUTPUT_PATH \ 28 | --test_data_path $TEST_DATA_PATH \ 29 | --idx $process_idx \ 30 | --num_parts $NUM_PARTS \ 31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; } 32 | ) & 33 | done 34 | 35 | # Wait for all processes to complete 36 | wait 37 | 38 | python evaluation/calculate_counting.py --output_dir $OUTPUT_PATH -------------------------------------------------------------------------------- /evaluation/eval_segmentation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner or qwen2 5 | TEST_DATA_PATH=${1:-"Ricky06662/refcocog_test"} 6 | 7 | # Extract model name and test dataset name for output directory 8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 
's/.*\/([^\/]+)$/\1/') 9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}" 10 | 11 | # Customize GPU array here - specify which GPUs to use 12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Example: using GPUs 0, 1, 2, 3 13 | NUM_PARTS=${#GPU_ARRAY[@]} 14 | 15 | # Create output directory 16 | mkdir -p $OUTPUT_PATH 17 | 18 | # Run processes in parallel 19 | for i in $(seq 0 $((NUM_PARTS-1))); do 20 | gpu_id=${GPU_ARRAY[$i]} 21 | process_idx=$i # 0-based indexing for process 22 | 23 | export CUDA_VISIBLE_DEVICES=$gpu_id 24 | ( 25 | python evaluation/evaluation_segmentation.py \ 26 | --model $MODEL_TYPE \ 27 | --output_path $OUTPUT_PATH \ 28 | --test_data_path $TEST_DATA_PATH \ 29 | --idx $process_idx \ 30 | --num_parts $NUM_PARTS \ 31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; } 32 | ) & 33 | done 34 | 35 | # Wait for all processes to complete 36 | wait 37 | 38 | python evaluation/calculate_iou_with_bbox.py --output_dir $OUTPUT_PATH -------------------------------------------------------------------------------- /evaluation/evaluation_coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import numpy as np 5 | import os 6 | from datasets import load_from_disk, load_dataset 7 | from PIL import Image as PILImage 8 | from tqdm import tqdm 9 | import sys 10 | from scipy.optimize import linear_sum_assignment 11 | 12 | # Add the parent directory to the Python path to import model module 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel 15 | from vision_reasoner.models.qwen_vl import QwenVLModel 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", type=str, default="vision_reasoner") 19 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"]) 20 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B") 21 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large") 22 | parser.add_argument("--output_path", type=str, required=True) 23 | parser.add_argument("--test_data_path", type=str, required=True) 24 | parser.add_argument("--batch_size", type=int, default=1) 25 | 26 | # for parallel evaluation 27 | parser.add_argument("--idx", type=int, required=True) 28 | parser.add_argument("--num_parts", type=int, required=True) 29 | return parser.parse_args() 30 | 31 | def compute_bbox_iou(bboxes1, bboxes2): 32 | """ 33 | Calculate IOU matrix between two sets of bounding boxes 34 | bboxes1: shape (N, 4) prediction boxes 35 | bboxes2: shape (M, 4) ground truth boxes 36 | Returns: shape (N, M) IOU matrix 37 | """ 38 | # Expand dimensions to support broadcasting 39 | bboxes1 = np.array(bboxes1)[:, None, :] # (N, 1, 4) 40 | bboxes2 = np.array(bboxes2)[None, :, :] # (1, M, 4) 41 | 42 | # Calculate intersection area 43 | x1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) 44 | y1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) 45 | x2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) 46 | y2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) 47 | 48 | # Calculate intersection area 49 | intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1) 50 | 51 | # Calculate the areas of the two sets of bboxes 52 | area1 = 
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) 53 | area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) 54 | 55 | # Calculate union area 56 | union = area1 + area2 - intersection 57 | 58 | # Avoid division by zero 59 | iou = np.where(union > 0, intersection / union, 0) 60 | 61 | return iou 62 | 63 | def main(): 64 | args = parse_args() 65 | 66 | # Initialize model 67 | if args.model == "qwen": 68 | model = QwenVLModel(model_path=args.model_path) 69 | elif args.model == "qwen2": 70 | model = QwenVLModel(model_path=args.model_path) 71 | elif args.model == "vision_reasoner": 72 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 73 | task_router_model_path=args.task_router_model_path, 74 | segmentation_model_path=args.segmentation_model_path) 75 | 76 | # Load dataset 77 | dataset = load_dataset(args.test_data_path, split="test") 78 | total_len = len(dataset) 79 | part_size = total_len // args.num_parts 80 | start_idx = args.idx * part_size 81 | end_idx = min(start_idx + part_size if args.idx < args.num_parts - 1 else total_len, total_len) 82 | range_list = range(start_idx, end_idx) 83 | 84 | dataset = dataset.select(range_list) 85 | 86 | all_outputs = [] 87 | 88 | # Prepare batches 89 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"): 90 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))] 91 | 92 | batch_images = [item["image"].convert("RGB") for item in batch_data] 93 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data] 94 | id_list = [{ 95 | "image_id": item["image_id"], 96 | "ann_id": item["ann_id"], 97 | "img_height": item["img_height"], 98 | "img_width": item["img_width"], 99 | "bbox": item["bbox"], 100 | "cat_id": item["cat_id"] 101 | } for item in batch_data] 102 | 103 | process_batch(model, batch_images, batch_questions, id_list, all_outputs) 104 | 105 | # Save results 106 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json") 107 | with open(output_file, "w") as f: 108 | json.dump(all_outputs, f, indent=2, ensure_ascii=False) 109 | 110 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs): 111 | """Process a batch of images and questions""" 112 | batch_results = model.detect_objects_batch(batch_images, batch_questions) 113 | for i, result in enumerate(batch_results): 114 | try: 115 | thinking = result["thinking"] 116 | bboxes = result["bboxes"] 117 | 118 | gt_bboxes = id_list[i]["bbox"] 119 | 120 | if gt_bboxes and len(bboxes) > 0: 121 | # # Use vectorized calculation of IOU matrix 122 | # cost_matrix = -compute_bbox_iou(bboxes, gt_bboxes) # Use negative IOU as cost 123 | 124 | # # Use Hungarian algorithm for matching 125 | # pred_indices, gt_indices = linear_sum_assignment(cost_matrix) 126 | 127 | # # Assign scores to each predicted box 128 | # scores = np.zeros(len(bboxes)) 129 | # for pred_idx, gt_idx in zip(pred_indices, gt_indices): 130 | # scores[pred_idx] = -cost_matrix[pred_idx, gt_idx] # Convert back to positive IOU value 131 | 132 | # Add results 133 | for pred_idx, pred_bbox in enumerate(bboxes): 134 | all_outputs.append({ 135 | "image_id": int(id_list[i]["image_id"]), 136 | "ann_id": int(id_list[i]["ann_id"]), 137 | "think": thinking, 138 | "category_id": int(id_list[i]["cat_id"]), 139 | "bbox": pred_bbox, 140 | #"score": float(max(scores[pred_idx],0.0)) # Use the match score 141 | "score": 
float((pred_bbox[2]-pred_bbox[0])*(pred_bbox[3]-pred_bbox[1])/(id_list[i]["img_width"]*id_list[i]["img_height"])) 142 | }) 143 | else: 144 | # If there are no ground truth boxes or predicted boxes, score is 0 145 | for pred_bbox in bboxes: 146 | all_outputs.append({ 147 | "image_id": int(id_list[i]["image_id"]), 148 | "ann_id": int(id_list[i]["ann_id"]), 149 | "think": thinking, 150 | "category_id": int(id_list[i]["cat_id"]), 151 | "bbox": pred_bbox, 152 | "score": 0.0 153 | }) 154 | 155 | except Exception as e: 156 | # raise 157 | print(f"Error processing result: {e}") 158 | # Skip this because the implementation is different from the original 159 | continue 160 | 161 | if __name__ == "__main__": 162 | main() -------------------------------------------------------------------------------- /evaluation/evaluation_count.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import numpy as np 5 | import os 6 | from datasets import load_from_disk, load_dataset 7 | from tqdm import tqdm 8 | import sys 9 | 10 | # Add the parent directory to the Python path to import model module 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel 13 | from vision_reasoner.models.qwen_vl import QwenVLModel 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model", type=str, default="vision_reasoner") 17 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"]) 18 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B") 19 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large") 20 | parser.add_argument("--output_path", type=str, required=True) 21 | parser.add_argument("--test_data_path", type=str, required=True) 22 | parser.add_argument("--batch_size", type=int, default=1) 23 | 24 | # for parallel evaluation 25 | parser.add_argument("--idx", type=int, required=True) 26 | parser.add_argument("--num_parts", type=int, required=True) 27 | return parser.parse_args() 28 | 29 | def main(): 30 | args = parse_args() 31 | 32 | # Initialize model 33 | if args.model == "qwen": 34 | model = QwenVLModel(model_path=args.model_path) 35 | elif args.model == "qwen2": 36 | model = QwenVLModel(model_path=args.model_path) 37 | elif args.model == "vision_reasoner": 38 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 39 | task_router_model_path=args.task_router_model_path, 40 | segmentation_model_path=args.segmentation_model_path) 41 | 42 | # Load dataset 43 | dataset = load_dataset(args.test_data_path, split="test") 44 | total_len = len(dataset) 45 | part_size = total_len // args.num_parts 46 | start_idx = args.idx * part_size 47 | end_idx = start_idx + part_size if args.idx < args.num_parts - 1 else total_len 48 | 49 | dataset = dataset.select(range(start_idx, end_idx)) 50 | all_outputs = [] 51 | 52 | # Process in batches 53 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"): 54 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))] 55 | 56 | batch_images = [item["image"].convert("RGB") for item in batch_data] 57 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data] 
58 | id_list = [{ 59 | "image_id": item["image_id"], 60 | "ann_id": item["ann_id"], 61 | "img_height": item["img_height"], 62 | "img_width": item["img_width"], 63 | "gt_count": item["count"] 64 | } for item in batch_data] 65 | process_batch(model, batch_images, batch_questions, id_list, all_outputs) 66 | 67 | # Save results 68 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json") 69 | with open(output_file, "w") as f: 70 | json.dump(all_outputs, f, indent=2, ensure_ascii=False) 71 | 72 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs): 73 | """Process a batch of images and questions""" 74 | batch_results = model.count_objects_batch(batch_images, batch_questions) 75 | 76 | for i, result in enumerate(batch_results): 77 | try: 78 | thinking = result["thinking"] 79 | bboxes = result["bboxes"] 80 | pred_count = result["count"] 81 | 82 | all_outputs.append({ 83 | "image_id": id_list[i]["image_id"], 84 | "ann_id": id_list[i]["ann_id"], 85 | "think": thinking, 86 | "pred_count": pred_count, 87 | "gt_count": id_list[i]["gt_count"] 88 | }) 89 | 90 | except Exception as e: 91 | # raise 92 | print(f"Error processing result: {e}") 93 | # Add penalty in this situation 94 | all_outputs.append({ 95 | "image_id": id_list[i]["image_id"], 96 | "ann_id": id_list[i]["ann_id"], 97 | "think": "", 98 | "pred_count": 1, 99 | "gt_count": id_list[i]["gt_count"] 100 | }) 101 | 102 | print(f"Processed batch of {len(batch_images)} images") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() -------------------------------------------------------------------------------- /evaluation/evaluation_segmentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import numpy as np 5 | import os 6 | from datasets import load_from_disk, load_dataset 7 | from PIL import Image as PILImage 8 | from tqdm import tqdm 9 | import sys 10 | 11 | # Add the parent directory to the Python path to import model module 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel 14 | from vision_reasoner.models.qwen_vl import QwenVLModel 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", type=str, default="vision_reasoner") 19 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"]) 20 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B") 21 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large") 22 | parser.add_argument("--output_path", type=str, required=True) 23 | parser.add_argument("--test_data_path", type=str, required=True) 24 | parser.add_argument("--batch_size", type=int, default=1) 25 | 26 | # for parallel evaluation 27 | parser.add_argument("--idx", type=int, required=True) 28 | parser.add_argument("--num_parts", type=int, required=True) 29 | return parser.parse_args() 30 | 31 | def compute_iou(mask1, mask2): 32 | intersection = np.logical_and(mask1, mask2).sum() 33 | union = np.logical_or(mask1, mask2).sum() 34 | if union == 0: 35 | return 0, 0 36 | return intersection, union 37 | 38 | def compute_bbox_iou(bbox1, bbox2): 39 | # Calculate the intersection area of two bboxes 40 | x1 = max(bbox1[0], 
bbox2[0]) 41 | y1 = max(bbox1[1], bbox2[1]) 42 | x2 = min(bbox1[2], bbox2[2]) 43 | y2 = min(bbox1[3], bbox2[3]) 44 | 45 | # Calculate intersection area 46 | intersection = max(0, x2 - x1) * max(0, y2 - y1) 47 | 48 | # Calculate areas of the two bboxes 49 | area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) 50 | area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) 51 | 52 | # Calculate union area 53 | union = area1 + area2 - intersection 54 | 55 | # Avoid division by zero 56 | if union == 0: 57 | return 0 58 | 59 | return intersection / union 60 | 61 | def main(): 62 | args = parse_args() 63 | 64 | # Initialize model 65 | if args.model == "qwen": 66 | model = QwenVLModel(model_path=args.model_path) 67 | elif args.model == "qwen2": 68 | model = QwenVLModel(model_path=args.model_path) 69 | elif args.model == "vision_reasoner": 70 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 71 | task_router_model_path=args.task_router_model_path, 72 | segmentation_model_path=args.segmentation_model_path) 73 | 74 | # Load dataset 75 | dataset = load_dataset(args.test_data_path, split="test") 76 | total_len = len(dataset) 77 | part_size = total_len // args.num_parts 78 | start_idx = args.idx * part_size 79 | end_idx = start_idx + part_size if args.idx < args.num_parts - 1 else total_len 80 | 81 | dataset = dataset.select(range(start_idx, end_idx)) 82 | 83 | # Check if dataset has bbox information 84 | has_bbox = 'bbox' in dataset[0] 85 | 86 | all_outputs = [] 87 | 88 | # Prepare batches 89 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"): 90 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))] 91 | 92 | batch_images = [item["image"].convert("RGB") for item in batch_data] 93 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data] 94 | id_list = [{ 95 | "image_id": item["image_id"], 96 | "ann_id": item["ann_id"], 97 | "mask": item["mask"], 98 | "img_height": item["img_height"], 99 | "img_width": item["img_width"], 100 | "bbox": item["bbox"] if has_bbox else None 101 | } for item in batch_data] 102 | 103 | process_batch(model, batch_images, batch_questions, id_list, all_outputs, has_bbox) 104 | 105 | # Save results 106 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json") 107 | with open(output_file, "w") as f: 108 | json.dump(all_outputs, f, indent=2, ensure_ascii=False) 109 | 110 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs, has_bbox): 111 | """Process a batch of images and questions""" 112 | batch_results = model.segment_objects_batch(batch_images, batch_questions) 113 | 114 | for i, result in enumerate(batch_results): 115 | try: 116 | thinking = result["thinking"] 117 | bboxes = result["bboxes"] 118 | mask_all = result["masks"] 119 | gt_mask = np.array(id_list[i]["mask"]) 120 | 121 | intersection, union = compute_iou(mask_all, gt_mask) 122 | 123 | bbox_iou = 0.0 124 | if has_bbox: 125 | try: 126 | gt_bbox = id_list[i]["bbox"] 127 | for pred_bbox in bboxes: 128 | if compute_bbox_iou(pred_bbox, gt_bbox) > 0.5: 129 | bbox_iou = 1.0 130 | break 131 | except Exception as e: 132 | print(f"Bbox error: {e}, Image ID: {id_list[i]['image_id']}, Ann ID: {id_list[i]['ann_id']}") 133 | bbox_iou = 0.0 134 | 135 | all_outputs.append({ 136 | "image_id": id_list[i]["image_id"], 137 | "ann_id": id_list[i]["ann_id"], 138 | "think": thinking, 139 | "intersection": int(intersection), 140 | "union": int(union), 141 | "bbox_iou": bbox_iou 142 | }) 143 | 144 | 
except Exception as e: 145 | print(f"Error processing result: {e}") 146 | # Add penalty in this situation 147 | all_outputs.append({ 148 | "image_id": id_list[i]["image_id"], 149 | "ann_id": id_list[i]["ann_id"], 150 | "think": "", 151 | "intersection": 0, 152 | "union": np.array(id_list[i]["mask"]).sum(), 153 | "bbox_iou": 0.0 154 | }) 155 | 156 | 157 | if __name__ == "__main__": 158 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | codetiming 3 | datasets 4 | flash-attn 5 | liger-kernel 6 | mathruler 7 | numpy 8 | omegaconf 9 | pandas 10 | peft 11 | pillow 12 | pyarrow>=15.0.0 13 | pylatexenc 14 | qwen-vl-utils 15 | scipy 16 | matplotlib 17 | ipykernel 18 | pycocotools 19 | scikit-image 20 | ray 21 | tensordict 22 | transformers 23 | vllm 24 | wandb 25 | sam2 26 | ultralytics 27 | openai -------------------------------------------------------------------------------- /task_categorization.md: -------------------------------------------------------------------------------- 1 | # Current task types 2 | > This list shows how we reformulate existing visual perception tasks. We are working on supporting more task types. 3 | 4 | ## Detection 5 | 6 | - [Visual Grounding](https://paperswithcode.com/datasets?mod=images&task=visual-grounding) 7 | - [Object Detection](https://paperswithcode.com/task/2d-object-detection/latest) 8 | - [2D Object Detection](https://paperswithcode.com/task/2d-object-detection/latest) 9 | - [Small Object Detection](https://paperswithcode.com/task/object-detection) 10 | - [Defect Detection](https://paperswithcode.com/task/defect-detection) 11 | - [Face Detection](https://astro.paperswithcode.com/task/occluded-face-detection) 12 | - [License Plate Detection](https://paperswithcode.com/task/license-plate-detection) 13 | - [Anomaly Detection](https://paperswithcode.com/task/anomaly-detection) 14 | - [Human Detection](https://paperswithcode.com/task/human-detection) 15 | - [Surgical Tool Detection](https://paperswithcode.com/task/surgical-tool-detection) 16 | - [Dense Object Detection](https://paperswithcode.com/task/dense-object-detection) 17 | - [Open World Object Detection](https://paperswithcode.com/task/open-world-object-detection) 18 | - [Zero-Shot Object Detection](https://paperswithcode.com/task/zero-shot-object-detection) 19 | - [Animal Action Recognition](https://paperswithcode.com/task/animal-action-recognition) 20 | - [Robotic Grasping](https://paperswithcode.com/task/robotic-grasping) 21 | - [Object Localization](https://paperswithcode.com/task/object-localization) 22 | - [Hand Detection](https://paperswithcode.com/task/hand-detection) 23 | - [Visual Relationship Detection](https://paperswithcode.com/task/visual-relationship-detection) 24 | - [Open Vocabulary Object Detection](https://paperswithcode.com/task/open-vocabulary-object-detection) 25 | - [Oriented Object Detection](https://paperswithcode.com/task/oriented-object-detection) 26 | - [Object Detection in Indoor Scenes](https://paperswithcode.com/task/object-detection-in-indoor-scenes) 27 | - [Object Detection in Aerial Images](https://paperswithcode.com/task/object-detection-in-aerial-images) 28 | - [Person Search](https://paperswithcode.com/task/person-search) 29 | - [Object Recognition](https://paperswithcode.com/datasets?mod=images&page=1) 30 | 31 | 32 | ## Segmentation 33 | 34 | 35 | - [Semantic 
Segmentation](https://paperswithcode.com/task/semantic-segmentation) 36 | - [Instance Segmentation](https://paperswithcode.com/task/3d-instance-segmentation-1) 37 | - [Lane Detection](https://paperswithcode.com/task/lane-detection/social) 38 | - [2D Semantic Segmentation](https://paperswithcode.com/task/2d-semantic-segmentation) 39 | - [Medical Image Segmentation](https://paperswithcode.com/task/medical-image-segmentation/latest) 40 | - [Human Part Segmentation](https://paperswithcode.com/task/human-part-segmentation) 41 | - [Action Segmentation](https://paperswithcode.com/task/action-segmentation) 42 | - [Video Object Segmentation](https://paperswithcode.com/task/interactive-video-object-segmentation) 43 | - [Referring Expression Segmentation](https://paperswithcode.com/task/referring-expression-segmentation) 44 | - [Saliency Detection](https://paperswithcode.com/task/saliency-detection/latest) 45 | - [Salient Object Detection](https://paperswithcode.com/task/salient-object-detection) 46 | - [The Semantic Segmentation of Remote Sensing Imagery](https://paperswithcode.com/task/the-semantic-segmentation-of-remote-sensing) 47 | - [Crack Segmentation](https://paperswithcode.com/task/crack-segmentation) 48 | - [Action Unit Detection](https://paperswithcode.com/task/action-unit-detection) 49 | - [RGB Salient Object Detection](https://paperswithcode.com/task/salient-object-detection) 50 | - [Boundary Detection](https://paperswithcode.com/task/boundary-detection) 51 | - [Crack Segmentation for Infrastructure](https://paperswithcode.com/task/crack-segmentation) 52 | - [Surgical Tool Segmentation](https://paperswithcode.com/task/surgical-tool-detection) 53 | 54 | ## Counting 55 | 56 | - [Object Counting](https://paperswithcode.com/task/object-counting) 57 | - [Crowd Counting](https://paperswithcode.com/task/crowd-counting) 58 | - [Density Estimation](https://paperswithcode.com/task/density-estimation) 59 | - [Pedestrian Detection](https://paperswithcode.com/task/pedestrian-detection) 60 | - [Crowd Estimation in Dense Scenes](https://paperswithcode.com/task/crowd-counting/codeless) 61 | - [Traffic Counting in Surveillance](https://paperswithcode.com/task/crowd-counting) 62 | 63 | ## VQA 64 | 65 | - [Visual Question Answering (VQA)](https://paperswithcode.com/datasets?mod=images&task=visual-question-answering) 66 | - [Classification](https://paperswithcode.com/datasets?mod=images&task=classification-1) 67 | - [Image Captioning](https://paperswithcode.com/datasets?mod=images&task=image-captioning) 68 | - [Question Answering](https://paperswithcode.com/datasets?mod=images&task=question-answering) 69 | - [Visual Reasoning](https://paperswithcode.com/datasets?mod=images&task=visual-reasoning) 70 | - [Visual Question Answering](https://paperswithcode.com/datasets?mod=images&task=visual-question-answering-1) 71 | - [Relational Reasoning](https://paperswithcode.com/datasets?mod=images&task=relational-reasoning) 72 | 73 | -------------------------------------------------------------------------------- /vision_reasoner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/vision_reasoner/__init__.py -------------------------------------------------------------------------------- /vision_reasoner/inference.py: -------------------------------------------------------------------------------- 1 | # test_qwen_vl_model.py 2 | import argparse 3 | from PIL import Image 4 | 
import os 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 6 | from models.vision_reasoner_model import VisionReasonerModel 7 | from utils import visualize_results_enhanced 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="Test unified vision model on a single image") 11 | parser.add_argument("--model_path", type=str, default='pretrained_models/VisionReasoner-7B', help="Path to the model") 12 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B") 13 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large") 14 | parser.add_argument("--image_path", type=str, default="assets/airplanes.png", help="Path to the input image") 15 | parser.add_argument("--query", type=str, default="How many airplanes are there in this image?", help="Query/instruction for the model") 16 | parser.add_argument("--task", type=str, choices=["auto", "detection", "segmentation", "counting", "vqa", "generation"], 17 | default="auto", help="Task type (default: auto)") 18 | parser.add_argument("--output_path", type=str, default="result_visualization.png", help="Path to save the output visualization") 19 | parser.add_argument("--hybrid_mode", action="store_true", help="Whether to use YOLO for object detection") 20 | parser.add_argument("--yolo_model_path", type=str, default="yolov8x-worldv2.pt", help="Path to the YOLO model") 21 | parser.add_argument("--refer_image_path", type=str, default="", help="Path to the reference image") 22 | parser.add_argument("--image_prompt", type=str, default="", help="Prompt for image generation") 23 | parser.add_argument("--generation_mode", action="store_true", help="Whether to use generation model") 24 | parser.add_argument("--generation_model_name", type=str, default="gpt-image-1", help="Name of the generation model") 25 | args = parser.parse_args() 26 | 27 | # Determine task type 28 | if args.image_prompt != "": 29 | assert args.generation_mode, "Please set --generation_mode to True when using image prompt" 30 | task_type = "generation" 31 | else: 32 | task_type = args.task 33 | 34 | # Load model 35 | print(f"Loading model from {args.model_path}...") 36 | if args.generation_mode: 37 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 38 | task_router_model_path=args.task_router_model_path, 39 | segmentation_model_path=args.segmentation_model_path, 40 | generation_model_path=args.generation_model_name) 41 | elif args.hybrid_mode: 42 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 43 | task_router_model_path=args.task_router_model_path, 44 | segmentation_model_path=args.segmentation_model_path, 45 | yolo_model_path=args.yolo_model_path) 46 | else: 47 | model = VisionReasonerModel(reasoning_model_path=args.model_path, 48 | task_router_model_path=args.task_router_model_path, 49 | segmentation_model_path=args.segmentation_model_path) 50 | 51 | 52 | if task_type != "generation": 53 | # Load image 54 | print(f"Loading image from {args.image_path}...") 55 | image = Image.open(args.image_path).convert("RGB") 56 | 57 | if task_type == "auto": 58 | result, task_type = model.process_single_image(image, args.query, return_task_type=True) 59 | elif task_type == "detection": 60 | result = model.detect_objects(image, args.query) 61 | elif task_type == "segmentation": 62 | result = model.segment_objects(image, args.query) 63 | elif task_type == "counting": 64 | result = model.count_objects(image, args.query) 65 | elif task_type == "generation": 66 | result = 
model.generate_image(args.refer_image_path, args.image_prompt) 67 | else: # VQA 68 | result = model.answer_question(image, args.query) 69 | 70 | # Print results 71 | print("\n===== Results =====") 72 | print("Task type: ", task_type) 73 | if args.image_prompt != "": 74 | print("User prompt: ", args.image_prompt) 75 | else: 76 | print("User question: ", args.query) 77 | if 'thinking' in result and result['thinking'].strip() != "": 78 | print("Thinking process: ", result['thinking']) 79 | 80 | # print("Response: ", result) 81 | 82 | if task_type == "detection": 83 | print(f"Total number of detected objects: {len(result['bboxes'])}") 84 | elif task_type == "segmentation": 85 | print(f"Total number of segmented objects: {len(result['bboxes'])}") 86 | elif task_type == "counting": 87 | print(f"Total number of interested objects is: {result['count']}") 88 | elif task_type == "generation": 89 | result.save(args.output_path, format="PNG") 90 | print(f"The generated image is saved as '{args.output_path}'") 91 | else: # QA 92 | print(f"The answer is: {result['answer']}") 93 | 94 | if task_type != "generation" and task_type != "vqa" and task_type != "counting": 95 | # Visualize results with the new three-panel layout 96 | visualize_results_enhanced(image, result, task_type, args.output_path) 97 | print(f"\nResult visualization saved as '{args.output_path}'") 98 | 99 | 100 | 101 | if __name__ == "__main__": 102 | main() -------------------------------------------------------------------------------- /vision_reasoner/models/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class BaseVisionModel(ABC): 4 | """Abstract base class for vision models that process images and instructions""" 5 | 6 | @abstractmethod 7 | def process_single_image(self, image, instruction): 8 | """ 9 | Process a single image and instruction 10 | 11 | Args: 12 | image: Input image 13 | instruction: Text instruction/query 14 | 15 | Returns: 16 | dict: Results dictionary 17 | """ 18 | pass 19 | 20 | @abstractmethod 21 | def process_batch(self, batch_images, batch_instructions): 22 | """ 23 | Process a batch of images and instructions 24 | 25 | Args: 26 | batch_images: List of input images 27 | batch_instructions: List of text instructions/queries 28 | 29 | Returns: 30 | list: List of result dictionaries 31 | """ 32 | pass 33 | 34 | 35 | class DetectionModel(ABC): 36 | """Interface for object detection tasks""" 37 | 38 | @abstractmethod 39 | def detect_objects(self, image, query): 40 | """ 41 | Detect objects in an image based on a query 42 | 43 | Args: 44 | image: Input image 45 | query: Text query describing what to detect 46 | 47 | Returns: 48 | dict: Results containing at least: 49 | - bboxes: List of bounding boxes [x1, y1, x2, y2] 50 | - scores: List of confidence scores 51 | - thinking: Reasoning process (if available) 52 | """ 53 | pass 54 | 55 | @abstractmethod 56 | def detect_objects_batch(self, images, queries): 57 | """ 58 | Detect objects in a batch of images 59 | 60 | Args: 61 | images: List of input images 62 | queries: List of text queries 63 | 64 | Returns: 65 | list: List of result dictionaries 66 | """ 67 | pass 68 | 69 | 70 | class SegmentationModel(ABC): 71 | """Interface for segmentation tasks""" 72 | 73 | @abstractmethod 74 | def segment_objects(self, image, query): 75 | """ 76 | Segment objects in an image based on a query 77 | 78 | Args: 79 | image: Input image 80 | query: Text query describing what to segment 81 | 82 
| Returns: 83 | dict: Results containing at least: 84 | - masks: Segmentation masks 85 | - bboxes: List of bounding boxes 86 | - thinking: Reasoning process (if available) 87 | """ 88 | pass 89 | 90 | @abstractmethod 91 | def segment_objects_batch(self, images, queries): 92 | """ 93 | Segment objects in a batch of images 94 | 95 | Args: 96 | images: List of input images 97 | queries: List of text queries 98 | 99 | Returns: 100 | list: List of result dictionaries 101 | """ 102 | pass 103 | 104 | 105 | class CountingModel(ABC): 106 | """Interface for counting tasks""" 107 | 108 | @abstractmethod 109 | def count_objects(self, image, query): 110 | """ 111 | Count objects in an image based on a query 112 | 113 | Args: 114 | image: Input image 115 | query: Text query describing what to count 116 | 117 | Returns: 118 | dict: Results containing at least: 119 | - count: Number of objects 120 | - bboxes: List of bounding boxes (optional) 121 | - thinking: Reasoning process (if available) 122 | """ 123 | pass 124 | 125 | @abstractmethod 126 | def count_objects_batch(self, images, queries): 127 | """ 128 | Count objects in a batch of images 129 | 130 | Args: 131 | images: List of input images 132 | queries: List of text queries 133 | 134 | Returns: 135 | list: List of result dictionaries 136 | """ 137 | pass 138 | 139 | 140 | class QAModel(ABC): 141 | """Interface for visual question answering tasks""" 142 | 143 | @abstractmethod 144 | def answer_question(self, image, question): 145 | """ 146 | Answer a question about an image 147 | 148 | Args: 149 | image: Input image 150 | question: Text question 151 | 152 | Returns: 153 | dict: Results containing at least: 154 | - answer: Text answer 155 | - thinking: Reasoning process (if available) 156 | """ 157 | pass 158 | 159 | @abstractmethod 160 | def answer_questions_batch(self, images, questions): 161 | """ 162 | Answer questions about a batch of images 163 | 164 | Args: 165 | images: List of input images 166 | questions: List of text questions 167 | 168 | Returns: 169 | list: List of result dictionaries 170 | """ 171 | pass -------------------------------------------------------------------------------- /vision_reasoner/models/qwen_vl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | import json 5 | from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration 6 | from PIL import Image as PILImage 7 | from sam2.sam2_image_predictor import SAM2ImagePredictor 8 | from qwen_vl_utils import process_vision_info 9 | from .base_model import ( 10 | BaseVisionModel, 11 | DetectionModel, 12 | SegmentationModel, 13 | CountingModel, 14 | QAModel 15 | ) 16 | 17 | class QwenVLModel(BaseVisionModel, DetectionModel, SegmentationModel, CountingModel, QAModel): 18 | """ 19 | QwenVL model implementing all task interfaces with custom prompts 20 | """ 21 | 22 | def __init__(self, model_path='Qwen/Qwen2.5-VL-7B-Instruct', 23 | segmentation_model_path="facebook/sam2-hiera-large"): 24 | """ 25 | Initialize the QwenVL model 26 | 27 | Args: 28 | model_path (str): Path to the model 29 | """ 30 | self.model_path = model_path 31 | 32 | # Initialize model 33 | if 'Qwen2.5' in model_path: 34 | self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 35 | model_path, 36 | torch_dtype=torch.float16, 37 | device_map="auto", 38 | trust_remote_code=True 39 | ) 40 | else: 41 | self.model = Qwen2VLForConditionalGeneration.from_pretrained( 42 | 
model_path, 43 | torch_dtype=torch.float16, 44 | device_map="auto", 45 | trust_remote_code=True 46 | ) 47 | self.model.eval() 48 | 49 | # Initialize processor 50 | self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 51 | 52 | # Initialize segmentation model 53 | self.segmentation_model = SAM2ImagePredictor.from_pretrained(segmentation_model_path) 54 | 55 | # Task-specific prompts 56 | if 'Qwen2.5' in model_path: 57 | self.DETECTION_PROMPT = """ 58 | Locate "{query}", report the bboxes coordinates in JSON format. 59 | """ 60 | else: 61 | self.DETECTION_PROMPT = """ 62 | Please identify "{query}" in the image. 63 | Return a JSON array where each object is represented as: 64 | {{"bbox": [x1, y1, x2, y2], "label": label}} 65 | 66 | Your response should be formatted as: 67 | ```json 68 | [ 69 | {{"bbox": [x1, y1, x2, y2], "label": label}}, 70 | ... 71 | ] 72 | ``` 73 | """ 74 | self.COUNTING_PROMPT = """ 75 | Locate "{query}", report the bboxes coordinates in JSON format. 76 | """ 77 | 78 | self.QA_PROMPT = """{query}""" 79 | 80 | def extract_json_from_response(self, response_text): 81 | """ 82 | Extract JSON content from model response using triple quotes 83 | 84 | Args: 85 | response_text (str): Model response text 86 | 87 | Returns: 88 | dict or list: Parsed JSON content 89 | """ 90 | json_pattern = r"```json\s*(.*?)\s*```" 91 | match = re.search(json_pattern, response_text, re.DOTALL) 92 | 93 | if match: 94 | json_str = match.group(1).strip() 95 | try: 96 | return json.loads(json_str) 97 | except json.JSONDecodeError: 98 | print(f"Error parsing JSON: {json_str}") 99 | return None 100 | 101 | # Fallback: try to find any JSON-like structure 102 | try: 103 | # Look for arrays 104 | array_pattern = r'\[(.*?)\]' 105 | array_match = re.search(array_pattern, response_text, re.DOTALL) 106 | if array_match: 107 | return json.loads(f"[{array_match.group(1)}]") 108 | 109 | # Look for objects 110 | object_pattern = r'\{(.*?)\}' 111 | object_match = re.search(object_pattern, response_text, re.DOTALL) 112 | if object_match: 113 | return json.loads(f"{{{object_match.group(1)}}}") 114 | except: 115 | pass 116 | 117 | return None 118 | 119 | def generate_response(self, image, prompt): 120 | """ 121 | Generate response from the model 122 | 123 | Args: 124 | image (PIL.Image): Input image 125 | prompt (str): Text prompt 126 | 127 | Returns: 128 | str: Model response 129 | """ 130 | # Resize image while maintaining aspect ratio 131 | messages = [ 132 | { 133 | "role": "system", 134 | "content": "You are a helpful assistant" 135 | }, 136 | { 137 | "role": "user", 138 | "content": [ 139 | { 140 | "type": "image", 141 | "image": image 142 | }, 143 | { 144 | "type": "text", 145 | "text": prompt 146 | } 147 | ] 148 | } 149 | ] 150 | text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 151 | 152 | image_inputs, video_inputs = process_vision_info(messages) 153 | inputs = self.processor( 154 | text=[text], 155 | images=image_inputs, 156 | videos=video_inputs, 157 | padding=True, 158 | return_tensors="pt", 159 | ).to('cuda') 160 | 161 | 162 | # Generate response 163 | with torch.inference_mode(): 164 | output_ids = self.model.generate( 165 | **inputs, 166 | max_new_tokens=1024, 167 | do_sample=False 168 | ) 169 | 170 | generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] 171 | output_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True, 
clean_up_tokenization_spaces=True) 172 | 173 | input_height = inputs['image_grid_thw'][0][1]*14 174 | input_width = inputs['image_grid_thw'][0][2]*14 175 | 176 | return output_text[0], (input_height, input_width) 177 | 178 | def _determine_task_type(self, instruction): 179 | """ 180 | Determine the task type based on the instruction 181 | 182 | Args: 183 | instruction (str): Text instruction or query 184 | 185 | Returns: 186 | str: Task type ("detection", "segmentation", "counting", or "qa") 187 | """ 188 | instruction_lower = instruction.lower() 189 | 190 | if "how many" in instruction_lower: 191 | return "counting" 192 | elif any(word in instruction_lower for word in ["segment", "mask", "outline"]): 193 | return "segmentation" 194 | elif any(word in instruction_lower for word in ["find", "locate", "detect", "identify"]): 195 | return "detection" 196 | else: 197 | return "vqa" 198 | 199 | def process_single_image(self, image, instruction, return_task_type=False): 200 | """ 201 | Process a single image with given instruction 202 | 203 | Args: 204 | image (PIL.Image): Input image 205 | instruction (str): Text instruction or query 206 | 207 | Returns: 208 | dict: Results dictionary 209 | """ 210 | task_type = self._determine_task_type(instruction) 211 | 212 | 213 | 214 | if task_type == "detection": 215 | result = self.detect_objects(image, instruction) 216 | elif task_type == "segmentation": 217 | result = self.segment_objects(image, instruction) 218 | elif task_type == "counting": 219 | result = self.count_objects(image, instruction) 220 | else: # Default to QA 221 | result = self.answer_question(image, instruction) 222 | 223 | if return_task_type: 224 | return result, task_type 225 | else: 226 | return result 227 | 228 | def process_batch(self, batch_images, batch_instructions): 229 | """ 230 | Process a batch of images with given instructions 231 | 232 | Args: 233 | batch_images (list): List of PIL Images 234 | batch_instructions (list): List of text instructions or queries 235 | 236 | Returns: 237 | list: List of result dictionaries 238 | """ 239 | results = [] 240 | for image, instruction in zip(batch_images, batch_instructions): 241 | result = self.process_single_image(image, instruction) 242 | results.append(result) 243 | return results 244 | 245 | def detect_objects(self, image, query): 246 | """ 247 | Detect objects in an image based on a query 248 | 249 | Args: 250 | image: Input image 251 | query: Text query describing what to detect 252 | 253 | Returns: 254 | dict: Results with bounding boxes and scores 255 | """ 256 | prompt = self.DETECTION_PROMPT.format(query=query) 257 | response, scale_factors = self.generate_response(image, prompt) 258 | 259 | # Get original image dimensions 260 | img_height, img_width = image.size[1], image.size[0] 261 | 262 | json_data = self.extract_json_from_response(response) 263 | 264 | bboxes = [] 265 | points = [] 266 | scores = [] 267 | 268 | if json_data and isinstance(json_data, list): 269 | for item in json_data: 270 | for key in ['bbox', 'bbox_2d']: 271 | if key in item: 272 | # For Qwen2-VL, convert from normalized coordinates (0-1000) to actual image coordinates 273 | if 'Qwen2.5' not in self.model_path: 274 | bbox = [ 275 | int(item[key][0] * img_width / 1000), 276 | int(item[key][1] * img_height / 1000), 277 | int(item[key][2] * img_width / 1000), 278 | int(item[key][3] * img_height / 1000) 279 | ] 280 | else: 281 | # Original scaling for Qwen2.5-VL 282 | bbox = [ 283 | int(item[key][0]), 284 | int(item[key][1]), 285 | int(item[key][2]), 
286 | int(item[key][3]) 287 | ] 288 | bboxes.append(bbox) 289 | 290 | for key in ['point', 'point_2d']: 291 | if key in item: 292 | # Similarly handle points 293 | if 'Qwen2.5' not in self.model_path: 294 | point = [ 295 | int(item[key][0] * img_width / 1000), 296 | int(item[key][1] * img_height / 1000) 297 | ] 298 | else: 299 | point = [ 300 | int(item[key][0]), 301 | int(item[key][1]) 302 | ] 303 | points.append(point) 304 | 305 | for key in ['score', 'score_2d']: 306 | if key in item: 307 | scores.append(item[key]) 308 | 309 | if len(scores) == 0: 310 | scores.append(0.0) 311 | 312 | return { 313 | "bboxes": bboxes, 314 | "points": points, 315 | "scores": scores, 316 | "thinking": "", 317 | "full_response": response, 318 | "json_data": json_data 319 | } 320 | 321 | def detect_objects_batch(self, images, queries): 322 | """ 323 | Detect objects in a batch of images 324 | 325 | Args: 326 | images: List of input images 327 | queries: List of text queries 328 | 329 | Returns: 330 | list: List of detection results 331 | """ 332 | results = [] 333 | for image, query in zip(images, queries): 334 | result = self.detect_objects(image, query) 335 | results.append(result) 336 | return results 337 | 338 | 339 | def generate_masks(self, image, bboxes, points): 340 | """ 341 | Generate segmentation masks for given image, bounding boxes and points 342 | 343 | Args: 344 | image (PIL.Image): Input image 345 | bboxes (list): List of bounding boxes 346 | points (list): List of points 347 | 348 | Returns: 349 | numpy.ndarray: Combined segmentation mask 350 | """ 351 | img_height, img_width = image.height, image.width 352 | mask_all = np.zeros((img_height, img_width), dtype=bool) 353 | 354 | if not bboxes: 355 | return mask_all 356 | 357 | try: 358 | self.segmentation_model.set_image(image) 359 | if not points: 360 | points = [] 361 | if len(points) != len(bboxes): 362 | points.extend([None] * (len(bboxes) - len(points))) 363 | 364 | for bbox, point in zip(bboxes, points): 365 | masks, scores, _ = self.segmentation_model.predict( 366 | box=bbox 367 | ) 368 | sorted_ind = np.argsort(scores)[::-1] 369 | masks = masks[sorted_ind] 370 | mask = masks[0].astype(bool) 371 | mask_all = np.logical_or(mask_all, mask) 372 | 373 | return mask_all 374 | except Exception as e: 375 | print(f"Error generating masks: {e}") 376 | return mask_all 377 | 378 | def segment_objects(self, image, query): 379 | """ 380 | Segment objects in an image based on a query 381 | 382 | Args: 383 | image: Input image 384 | query: Text query describing what to segment 385 | 386 | Returns: 387 | dict: Results with masks and bounding boxes 388 | """ 389 | try: 390 | prompt = self.DETECTION_PROMPT.format(query=query) 391 | response, scale_factors = self.generate_response(image, prompt) 392 | img_height, img_width = image.size[1], image.size[0] 393 | x_factor, y_factor = (1, 1) if 'Qwen2.5' in self.model_path \ 394 | else (img_width / 1000, img_height / 1000) 395 | 396 | json_data = self.extract_json_from_response(response) 397 | 398 | bboxes = [] 399 | points = [] 400 | 401 | if json_data and isinstance(json_data, list): 402 | for item in json_data: 403 | for key in ['bbox', 'bbox_2d']: 404 | if key in item and len(item[key]) == 4: 405 | bbox = [ 406 | int(item[key][0] * x_factor), 407 | int(item[key][1] * y_factor), 408 | int(item[key][2] * x_factor), 409 | int(item[key][3] * y_factor) 410 | ] 411 | bboxes.append(bbox) 412 | 413 | for key in ['point', 'point_2d']: 414 | if key in item and len(item[key]) == 2: 415 | point = [ 416 | 
int(item[key][0] * x_factor), 417 | int(item[key][1] * y_factor) 418 | ] 419 | points.append(point) 420 | 421 | masks = self.generate_masks(image, bboxes, points) 422 | 423 | return { 424 | "masks": masks, 425 | "bboxes": bboxes, 426 | "points": points, 427 | "thinking": "", 428 | "full_response": response, 429 | "json_data": json_data 430 | } 431 | except Exception as e: 432 | raise 433 | print(f"Error in segmentation: {e}") 434 | return { 435 | "masks": np.zeros((image.height, image.width), dtype=bool), 436 | "bboxes": [], 437 | "points": [], 438 | "thinking": "", 439 | "full_response": "", 440 | "json_data": None 441 | } 442 | 443 | def segment_objects_batch(self, images, queries): 444 | """ 445 | Segment objects in a batch of images 446 | 447 | Args: 448 | images: List of input images 449 | queries: List of text queries 450 | 451 | Returns: 452 | list: List of segmentation results 453 | """ 454 | results = [] 455 | for image, query in zip(images, queries): 456 | result = self.segment_objects(image, query) 457 | results.append(result) 458 | return results 459 | 460 | def count_objects(self, image, query): 461 | """ 462 | Count objects in an image based on a query 463 | 464 | Args: 465 | image: Input image 466 | query: Text query describing what to count 467 | 468 | Returns: 469 | dict: Results with count and bounding boxes 470 | """ 471 | try: 472 | prompt = self.COUNTING_PROMPT.format(query=query) 473 | response, scale_factors = self.generate_response(image, prompt) 474 | 475 | # Get original image dimensions 476 | img_height, img_width = image.size[1], image.size[0] 477 | x_factor, y_factor = (1, 1) if 'Qwen2.5' in self.model_path \ 478 | else (img_width / 1000, img_height / 1000) 479 | 480 | json_data = self.extract_json_from_response(response) 481 | 482 | bboxes = [] 483 | points = [] 484 | 485 | if json_data and isinstance(json_data, list): 486 | for item in json_data: 487 | for key in ['bbox', 'bbox_2d']: 488 | if key in item: 489 | bbox = [ 490 | int(item[key][0] * x_factor), 491 | int(item[key][1] * y_factor), 492 | int(item[key][2] * x_factor), 493 | int(item[key][3] * y_factor) 494 | ] 495 | bboxes.append(bbox) 496 | 497 | for key in ['point', 'point_2d']: 498 | if key in item: 499 | point = [ 500 | int(item[key][0] * x_factor), 501 | int(item[key][1] * y_factor) 502 | ] 503 | points.append(point) 504 | 505 | return { 506 | "count": len(bboxes), 507 | "bboxes": bboxes, 508 | "thinking": "", 509 | "points": points, 510 | "full_response": response, 511 | "json_data": json_data 512 | } 513 | 514 | # If JSON extraction fails, extract count from text using regex 515 | # Match digits, number words and Roman numerals 516 | count = 0 517 | 518 | # Match cardinal numbers (e.g. 
"5", "five", "10") 519 | number_words = { 520 | 'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 521 | 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10, 'eleven': 11, 522 | 'twelve': 12, 'thirteen': 13, 'fourteen': 14, 'fifteen': 15, 523 | 'sixteen': 16, 'seventeen': 17, 'eighteen': 18, 'nineteen': 19, 'twenty': 20 524 | } 525 | 526 | # Look for phrases like "there are X" or "I count X" or "X objects" 527 | count_patterns = [ 528 | r'there (?:are|is) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)', 529 | r'i (?:can |)(?:see|count|find) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)', 530 | r'(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty) (?:objects|items)', 531 | r'count(?:ing)? (?:is|of) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)', 532 | r'total (?:of|is) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)', 533 | r'(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty) in total', 534 | r'found (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)', 535 | # Roman numerals 536 | r'there (?:are|is) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)', 537 | r'count(?:ing)? (?:is|of) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)', 538 | r'total (?:of|is) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)', 539 | ] 540 | 541 | roman_to_int = { 542 | 'I': 1, 'II': 2, 'III': 3, 'IV': 4, 'V': 5, 543 | 'VI': 6, 'VII': 7, 'VIII': 8, 'IX': 9, 'X': 10, 544 | 'XI': 11, 'XII': 12, 'XIII': 13, 'XIV': 14, 'XV': 15, 545 | 'XVI': 16, 'XVII': 17, 'XVIII': 18, 'XIX': 19, 'XX': 20 546 | } 547 | 548 | import re 549 | response_lower = response.lower() 550 | 551 | for pattern in count_patterns: 552 | match = re.search(pattern, response_lower) 553 | if match: 554 | match_text = match.group(1) 555 | if match_text.isdigit(): 556 | count = int(match_text) 557 | break 558 | elif match_text in number_words: 559 | count = number_words[match_text] 560 | break 561 | elif match_text.upper() in roman_to_int: 562 | count = roman_to_int[match_text.upper()] 563 | break 564 | 565 | return { 566 | "count": count, 567 | "bboxes": bboxes, 568 | "thinking": "Extracted count from text response", 569 | "points": points, 570 | "full_response": response, 571 | "json_data": json_data 572 | } 573 | 574 | except Exception as e: 575 | print(f"Error in counting: {e}") 576 | return { 577 | "count": 0, 578 | "bboxes": [], 579 | "thinking": "", 580 | "points": [], 581 | "full_response": "", 582 | "json_data": None 583 | } 584 | 585 | 586 | def count_objects_batch(self, images, queries): 587 | """ 588 | Count objects in a batch of images 589 | 590 | Args: 591 | images: List of input images 592 | queries: List of text queries 593 | 594 | Returns: 595 | list: List of counting results 596 | """ 597 | results = [] 598 | for image, query in zip(images, queries): 599 | result = self.count_objects(image, query) 600 | results.append(result) 601 | 
return results 602 | 603 | def answer_question(self, image, question): 604 | """ 605 | Answer a question about an image 606 | 607 | Args: 608 | image: Input image 609 | question: Text question 610 | 611 | Returns: 612 | dict: Results with answer 613 | """ 614 | try: 615 | prompt = self.QA_PROMPT.format(query=question) 616 | response, _ = self.generate_response(image, prompt) 617 | 618 | answer = response 619 | 620 | return { 621 | "answer": answer, 622 | "thinking": "", 623 | "full_response": response, 624 | "json_data": None 625 | } 626 | except Exception as e: 627 | print(f"Error in QA: {e}") 628 | return { 629 | "answer": "", 630 | "thinking": "", 631 | "full_response": "", 632 | "json_data": None 633 | } 634 | 635 | def answer_questions_batch(self, images, questions): 636 | """ 637 | Answer questions about a batch of images 638 | 639 | Args: 640 | images: List of input images 641 | questions: List of text questions 642 | 643 | Returns: 644 | list: List of QA results 645 | """ 646 | results = [] 647 | for image, question in zip(images, questions): 648 | result = self.answer_question(image, question) 649 | results.append(result) 650 | return results -------------------------------------------------------------------------------- /vision_reasoner/models/task_router.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | import re 4 | 5 | class TaskRouter: 6 | def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct"): 7 | """Initialize task router""" 8 | self.model = AutoModelForCausalLM.from_pretrained( 9 | model_name, 10 | torch_dtype=torch.bfloat16, 11 | attn_implementation="flash_attention_2", 12 | device_map="auto" 13 | ) 14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 15 | self.prompt_template = "Given a user instruction, please classify which type of task it is and output the final answer after \"####\".. The types are: 1) Segmentation/detection, 2) Counting, 3) Editing, 4) Caption/QA. The user instruction is: " 16 | 17 | def route_task(self, instruction): 18 | """Route input instruction to corresponding task category 19 | 20 | Args: 21 | instruction: User input instruction 22 | 23 | Returns: 24 | dict: Dictionary containing predicted category and confidence 25 | """ 26 | # Get model response 27 | response = self._get_model_response(instruction) 28 | 29 | # Extract category 30 | predicted_category = self._extract_category(response) 31 | 32 | return predicted_category 33 | 34 | def route_batch(self, instructions): 35 | """Batch route tasks 36 | 37 | Args: 38 | instructions: List of instructions 39 | 40 | Returns: 41 | list: List of result dictionaries 42 | """ 43 | # Get batch responses 44 | responses = self._get_model_responses(instructions) 45 | 46 | results = [] 47 | for instruction, response in zip(instructions, responses): 48 | category = self._extract_category(response) 49 | results.append(category) 50 | 51 | return results 52 | 53 | def _get_model_response(self, instruction): 54 | """Get model response for a single instruction""" 55 | return self._get_model_responses([instruction])[0] 56 | 57 | def _get_model_responses(self, instructions): 58 | """Get batch model responses 59 | 60 | Args: 61 | instructions: List of instructions 62 | 63 | Returns: 64 | list: List of responses 65 | """ 66 | # Build batch messages 67 | message_batch = [ 68 | [ 69 | {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, 70 | {"role": "user", "content": self.prompt_template + instruction} 71 | ] 72 | for instruction in instructions 73 | ] 74 | 75 | # Process batch template 76 | text_batch = self.tokenizer.apply_chat_template( 77 | message_batch, 78 | tokenize=False, 79 | add_generation_prompt=True 80 | ) 81 | 82 | # Batch tokenize 83 | model_inputs = self.tokenizer( 84 | text_batch, 85 | return_tensors="pt", 86 | padding=True 87 | ).to(self.model.device) 88 | 89 | with torch.no_grad(): 90 | generated_ids = self.model.generate( 91 | **model_inputs, 92 | max_new_tokens=512 93 | ) 94 | # Only keep newly generated tokens 95 | generated_ids = generated_ids[:, model_inputs.input_ids.shape[1]:] 96 | responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 97 | 98 | return responses 99 | 100 | def _extract_category(self, response): 101 | """Extract predicted category from model response""" 102 | response = response.lower() 103 | 104 | # 1. Find all number matches 105 | number_pattern = r'(?:type|category|task)?\s*(?:is|:)?\s*([1-4])' 106 | number_matches = re.findall(number_pattern, response) 107 | if number_matches: 108 | number = int(number_matches[-1]) # Take the last matched number 109 | if number == 1: 110 | return 'segmentation' 111 | elif number == 2: 112 | return 'counting' 113 | elif number == 3: 114 | return 'editing' 115 | elif number == 4: 116 | return 'vqa' 117 | 118 | # 2. Keyword matching - find all matching keywords 119 | matches = [] 120 | if any(word in response for word in ['segmentation', 'detection', 'segment', 'detect', 'grounding']): 121 | matches.append('segmentation') 122 | if any(word in response for word in ['counting', 'count', 'number']): 123 | matches.append('counting') 124 | if any(word in response for word in ['editing', 'edit', 'modify']): 125 | matches.append('editing') 126 | if any(word in response for word in ['caption', 'qa', 'question', 'answer', 'describe']): 127 | matches.append('vqa') 128 | 129 | # Return the last matched category 130 | return matches[-1] if matches else "vqa" -------------------------------------------------------------------------------- /vision_reasoner/models/vision_reasoner_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | import json 5 | import os 6 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 7 | from sam2.sam2_image_predictor import SAM2ImagePredictor 8 | from PIL import Image as PILImage 9 | from ultralytics import YOLOWorld 10 | from openai import OpenAI 11 | from io import BytesIO 12 | import base64 13 | 14 | from .base_model import ( 15 | BaseVisionModel, 16 | DetectionModel, 17 | SegmentationModel, 18 | CountingModel, 19 | QAModel 20 | ) 21 | from qwen_vl_utils import process_vision_info 22 | from .task_router import TaskRouter 23 | 24 | STOP_WORDS = {"is", "are", "find", "the", "segment", "all", "in", "image", 25 | "how", "many", "there", "locate", "please"} 26 | MAX_QUERY_WORDS = 2 27 | 28 | 29 | class VisionReasonerModel(BaseVisionModel, DetectionModel, SegmentationModel, CountingModel, QAModel): 30 | """ 31 | VisionReasoner model implementing all task interfaces 32 | """ 33 | def __init__(self, 34 | reasoning_model_path="Ricky06662/VisionReasoner-7B", 35 | segmentation_model_path="facebook/sam2-hiera-large", 36 | task_router_model_path="Ricky06662/TaskRouter-1.5B", 37 | yolo_model_path=None, 38 | generation_model_path=None): 39 | """ 40 | Initialize the 
VisionReasoner model with reasoning and segmentation components 41 | 42 | Args: 43 | reasoning_model_path (str): Path to the reasoning model 44 | segmentation_model_path (str): Path to the segmentation model 45 | """ 46 | self.resize_size = 840 47 | 48 | # Initialize reasoning model 49 | self.reasoning_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 50 | reasoning_model_path, 51 | torch_dtype=torch.bfloat16, 52 | attn_implementation="flash_attention_2", 53 | device_map="auto", 54 | ) 55 | self.reasoning_model.eval() 56 | 57 | # Initialize processor 58 | self.processor = AutoProcessor.from_pretrained(reasoning_model_path, padding_side="left") 59 | 60 | # Initialize segmentation model 61 | self.segmentation_model = SAM2ImagePredictor.from_pretrained(segmentation_model_path) 62 | 63 | self.task_router = TaskRouter(task_router_model_path) 64 | 65 | # Template for detection/segmentation tasks 66 | self.DETECTION_TEMPLATE = \ 67 | "Please find \"{Question}\" with bboxs and points." \ 68 | "Compare the difference between object(s) and find the most closely matched object(s)." \ 69 | "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags." \ 70 | "Output the bbox(es) and point(s) inside the interested object(s) in JSON format." \ 71 | "i.e., <think> thinking process here </think> " \ 72 | "<answer>{Answer}</answer>" 73 | 74 | # Template for QA tasks 75 | self.QA_TEMPLATE = "{Question}" 76 | 77 | # Initialize YOLO model 78 | self.use_hybrid_mode = False 79 | if yolo_model_path: 80 | self.use_hybrid_mode = True 81 | self.yolo_model = YOLOWorld(yolo_model_path) 82 | 83 | # Initialize generation model 84 | if generation_model_path: 85 | self.generation_model = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", generation_model_path)) 86 | 87 | 88 | def extract_bbox_points_think(self, output_text, x_factor, y_factor): 89 | """ 90 | Extract bounding boxes, points, and thinking process from model output 91 | 92 | Args: 93 | output_text (str): Raw output text from the model 94 | x_factor (float): Scaling factor for x coordinates 95 | y_factor (float): Scaling factor for y coordinates 96 | 97 | Returns: 98 | tuple: (pred_bboxes, pred_points, think_text, pred_answer) 99 | """ 100 | json_match = re.search(r'<answer>\s*(.*?)\s*</answer>', output_text, re.DOTALL) 101 | pred_bboxes = [] 102 | pred_points = [] 103 | pred_answer = None 104 | think_text = "" 105 | 106 | if json_match: 107 | try: 108 | data = json.loads(json_match.group(1)) 109 | pred_answer = data 110 | pred_bboxes = [[ 111 | int(item['bbox_2d'][0] * x_factor + 0.5), 112 | int(item['bbox_2d'][1] * y_factor + 0.5), 113 | int(item['bbox_2d'][2] * x_factor + 0.5), 114 | int(item['bbox_2d'][3] * y_factor + 0.5) 115 | ] for item in data] 116 | pred_points = [[ 117 | int(item['point_2d'][0] * x_factor + 0.5), 118 | int(item['point_2d'][1] * y_factor + 0.5) 119 | ] for item in data] 120 | except Exception as e: 121 | print(f"Error parsing JSON: {e}") 122 | 123 | think_pattern = r'<think>([^<]+)</think>' 124 | think_match = re.search(think_pattern, output_text) 125 | if think_match: 126 | think_text = think_match.group(1) 127 | 128 | return pred_bboxes, pred_points, think_text, pred_answer 129 | 130 | def extract_qa_answer(self, output_text): 131 | """ 132 | Extract answer for QA tasks 133 | 134 | Args: 135 | output_text (str): Raw output text from the model 136 | 137 | Returns: 138 | dict: Result dictionary with answer and thinking (if available) 139 | """ 140 | think_pattern = r'<think>([^<]+)</think>' 141 | think_match = re.search(think_pattern, output_text) 142 | thinking = think_match.group(1) if think_match else "" 143 | 144 | # Remove thinking tags from output to get cleaner answer 145 | clean_answer = re.sub(r'<think>.*?</think>', '', output_text, flags=re.DOTALL).strip() 146 | 147 | return { 148 | "answer": clean_answer, 149 | "thinking": thinking, 150 | "full_response": output_text 151 | } 152 | 
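# NOTE: the extract_* helpers above parse outputs in the <think>...</think> /
# <answer>...</answer> format requested by DETECTION_TEMPLATE, with coordinates
# expressed in the 840x840 resized frame. A sketch of the decoding flow, using
# made-up values for illustration only:
#
#   raw output:
#     <think>the airplane on the left matches the query</think>
#     <answer>[{"bbox_2d": [10, 100, 200, 210], "point_2d": [30, 110]}]</answer>
#
#   for an original 1680x1260 image, _generate_model_output() returns
#   x_factor = 1680/840 = 2.0 and y_factor = 1260/840 = 1.5, so
#   extract_bbox_points_think() maps the bbox to
#   [int(10*2.0+0.5), int(100*1.5+0.5), int(200*2.0+0.5), int(210*1.5+0.5)] = [20, 150, 400, 315]
#   and the point to [60, 165] in original-image coordinates.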
153 | def generate_masks(self, image, bboxes, points=None): 154 | """ 155 | Generate segmentation masks for given image, bounding boxes and points 156 | 157 | Args: 158 | image (PIL.Image): Input image 159 | bboxes (list): List of bounding boxes 160 | points (list): List of points 161 | 162 | Returns: 163 | numpy.ndarray: Combined segmentation mask 164 | """ 165 | img_height, img_width = image.height, image.width 166 | mask_all = np.zeros((img_height, img_width), dtype=bool) 167 | 168 | if not bboxes: 169 | return mask_all 170 | 171 | if points and len(points) != len(bboxes): 172 | return mask_all 173 | 174 | try: 175 | self.segmentation_model.set_image(image) 176 | if points: 177 | for bbox, point in zip(bboxes, points): 178 | masks, scores, _ = self.segmentation_model.predict( 179 | point_coords=[point], 180 | point_labels=[1], 181 | box=bbox 182 | ) 183 | sorted_ind = np.argsort(scores)[::-1] 184 | masks = masks[sorted_ind] 185 | mask = masks[0].astype(bool) 186 | mask_all = np.logical_or(mask_all, mask) 187 | else: 188 | for bbox in bboxes: 189 | masks, scores, _ = self.segmentation_model.predict( 190 | box=bbox 191 | ) 192 | sorted_ind = np.argsort(scores)[::-1] 193 | mask_all = np.logical_or(mask_all, masks[sorted_ind][0].astype(bool)) 194 | 195 | return mask_all 196 | except Exception as e: 197 | print(f"Error generating masks: {e}") 198 | return mask_all 199 | 200 | def _generate_model_output(self, images, instructions, template, batch_mode=False): 201 | """ 202 | Generate raw model output for images and instructions 203 | 204 | Args: 205 | images (PIL.Image or List[PIL.Image]): Input image(s) 206 | instructions (str or List[str]): Text instruction(s)/query(ies) 207 | template (str): Template to use for the prompt 208 | batch_mode (bool): Whether to process in batch mode 209 | 210 | Returns: 211 | tuple: (output_texts, scale_factors) 212 | """ 213 | if not batch_mode: 214 | images = [images] 215 | instructions = [instructions] 216 | 217 | batch_messages = [] 218 | scale_factors = [] 219 | 220 | for image, instruction in zip(images, instructions): 221 | # Prepare image 222 | original_width, original_height = image.size 223 | x_factor, y_factor = original_width/self.resize_size, original_height/self.resize_size 224 | scale_factors.append((x_factor, y_factor)) 225 | resized_image = image.resize((self.resize_size, self.resize_size), PILImage.BILINEAR) 226 | 227 | # Format text based on template 228 | if "{Question}" in template: 229 | formatted_text = template.format( 230 | Question=instruction.lower().strip(".\"?!"), 231 | Answer="[{\"bbox_2d\": [10,100,200,210], \"point_2d\": [30,110]}, {\"bbox_2d\": [225,296,706,786], \"point_2d\": [302,410]}]" 232 | ) 233 | else: 234 | formatted_text = template 235 | 236 | # Create message 237 | message = [{ 238 | "role": "user", 239 | "content": [ 240 | { 241 | "type": "image", 242 | "image": resized_image 243 | }, 244 | { 245 | "type": "text", 246 | "text": formatted_text 247 | } 248 | ] 249 | }] 250 | batch_messages.append(message) 251 | 252 | # Prepare for batch inference 253 | texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages] 254 | 255 | image_inputs, video_inputs = process_vision_info(batch_messages) 256 | 257 | inputs = self.processor( 258 | 
text=texts, 259 | images=image_inputs, 260 | padding=True, 261 | return_tensors="pt", 262 | ) 263 | inputs = inputs.to("cuda") 264 | 265 | # Generate output 266 | generated_ids = self.reasoning_model.generate(**inputs, use_cache=True, max_new_tokens=2048, do_sample=False) 267 | 268 | generated_ids_trimmed = [ 269 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 270 | ] 271 | 272 | output_texts = self.processor.batch_decode( 273 | generated_ids_trimmed, 274 | skip_special_tokens=True, 275 | clean_up_tokenization_spaces=False 276 | ) 277 | if not batch_mode: 278 | return output_texts[0], scale_factors[0] 279 | return output_texts, scale_factors 280 | 281 | # BaseVisionModel implementation 282 | def process_single_image(self, image, instruction, return_task_type=False): 283 | """ 284 | Process a single image with given instruction 285 | 286 | Args: 287 | image (PIL.Image): Input image 288 | instruction (str): Text instruction or query 289 | 290 | Returns: 291 | dict: Results dictionary 292 | """ 293 | # Determine task type based on instruction 294 | task_type = self.task_router.route_task(instruction) 295 | 296 | if task_type == "segmentation": 297 | result = self.segment_objects(image, instruction) 298 | elif task_type == "detection": 299 | result = self.detect_objects(image, instruction) 300 | elif task_type == "counting": 301 | result = self.count_objects(image, instruction) 302 | else: # Default to VQA 303 | result = self.answer_question(image, instruction) 304 | 305 | if return_task_type: 306 | return result, task_type 307 | else: 308 | return result 309 | 310 | def process_batch(self, batch_images, batch_instructions): 311 | """ 312 | Process a batch of images with given instructions 313 | 314 | Args: 315 | batch_images (list): List of PIL Images 316 | batch_instructions (list): List of text instructions or queries 317 | 318 | Returns: 319 | list: List of result dictionaries 320 | """ 321 | results = [] 322 | for image, instruction in zip(batch_images, batch_instructions): 323 | result = self.process_single_image(image, instruction) 324 | results.append(result) 325 | return results 326 | 327 | def detect_objects_yolo(self, image, query): 328 | """ 329 | Detect objects in an image based on a query using YOLO model 330 | 331 | Args: 332 | image: Input image 333 | query: Text query describing what to detect 334 | 335 | Returns: 336 | dict: Results with bounding boxes and scores 337 | """ 338 | # Initialize a YOLO model 339 | query = " ".join([word.strip(".\"?!'") for word in query.lower().strip(".\"?!").split() if word not in STOP_WORDS]) 340 | names = [query] 341 | self.yolo_model.set_classes(names) 342 | 343 | # Run detection on the given image 344 | results = self.yolo_model.predict(image) 345 | 346 | # Get original image dimensions 347 | img_height, img_width = image.height, image.width 348 | 349 | # Get YOLO's input size 350 | yolo_input_size = results[0].orig_shape 351 | 352 | # Calculate scaling factors 353 | x_scale = img_width / yolo_input_size[1] 354 | y_scale = img_height / yolo_input_size[0] 355 | 356 | # Scale the bounding boxes back to original image size 357 | bboxes = results[0].boxes.xyxy 358 | scaled_bboxes = [] 359 | for bbox in bboxes: 360 | scaled_bbox = [ 361 | int(bbox[0] * x_scale), 362 | int(bbox[1] * y_scale), 363 | int(bbox[2] * x_scale), 364 | int(bbox[3] * y_scale) 365 | ] 366 | scaled_bboxes.append(scaled_bbox) 367 | 368 | return scaled_bboxes 369 | 370 | def if_yolo_condition(self, query): 371 | """ 372 | Check if YOLO 
373 |
374 |         Args:
375 |             query (str): Text query describing what to detect
376 |
377 |         Returns:
378 |             bool: True if YOLO should be used, False otherwise
379 |         """
380 |
381 |         # Simple heuristic: short queries (after stop-word removal) can be handled by YOLO
382 |         query_words = [word for word in query.lower().strip(".\"?!").split() if word not in STOP_WORDS]
383 |         return len(query_words) <= MAX_QUERY_WORDS
384 |
385 |     # DetectionModel implementation
386 |     def detect_objects(self, image, query):
387 |         """
388 |         Detect objects in an image based on a query
389 |
390 |         Args:
391 |             image: Input image
392 |             query: Text query describing what to detect
393 |
394 |         Returns:
395 |             dict: Results with bounding boxes and scores
396 |         """
397 |         try:
398 |             if self.use_hybrid_mode and self.if_yolo_condition(query):
399 |                 bboxes = self.detect_objects_yolo(image, query)
400 |                 scores = [1.0] * len(bboxes)
401 |                 # use middle point of bbox as point
402 |                 points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
403 |                 output_text, thinking, pred_answer = "", "", str(bboxes)
404 |             else:
405 |                 output_text, (x_factor, y_factor) = self._generate_model_output(
406 |                     image,
407 |                     query,
408 |                     self.DETECTION_TEMPLATE
409 |                 )
410 |
411 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
412 |                     output_text,
413 |                     x_factor,
414 |                     y_factor
415 |                 )
416 |
417 |                 # Assign confidence scores (all 1.0 as the model doesn't provide them)
418 |                 scores = [1.0] * len(bboxes)
419 |
420 |             return {
421 |                 "bboxes": bboxes,
422 |                 "points": points,
423 |                 "scores": scores,
424 |                 "thinking": thinking,
425 |                 "full_response": output_text,
426 |                 "pred_answer": pred_answer
427 |             }
428 |         except Exception as e:
429 |             print(f"Error in detection: {e}")
430 |             return {
431 |                 "bboxes": [],
432 |                 "points": [],
433 |                 "scores": [],
434 |                 "thinking": "",
435 |                 "full_response": "",
436 |                 "pred_answer": None
437 |             }
438 |
439 |     def detect_objects_batch(self, images, queries):
440 |         """
441 |         Detect objects in a batch of images
442 |
443 |         Args:
444 |             images: List of input images
445 |             queries: List of text queries
446 |
447 |         Returns:
448 |             list: List of detection results
449 |         """
450 |         try:
451 |             # TODO: support yolo for batch
452 |
453 |             output_texts, scale_factors = self._generate_model_output(
454 |                 images,
455 |                 queries,
456 |                 self.DETECTION_TEMPLATE,
457 |                 batch_mode=True
458 |             )
459 |
460 |             results = []
461 |             for output_text, (x_factor, y_factor) in zip(output_texts, scale_factors):
462 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
463 |                     output_text,
464 |                     x_factor,
465 |                     y_factor
466 |                 )
467 |
468 |                 scores = [1.0] * len(bboxes)
469 |                 results.append({
470 |                     "bboxes": bboxes,
471 |                     "points": points,
472 |                     "scores": scores,
473 |                     "thinking": thinking,
474 |                     "full_response": output_text,
475 |                     "pred_answer": pred_answer
476 |                 })
477 |             return results
478 |         except Exception as e:
479 |             print(f"Error in batch detection: {e}")
480 |             return [{
481 |                 "bboxes": [],
482 |                 "points": [],
483 |                 "scores": [],
484 |                 "thinking": "",
485 |                 "full_response": "",
486 |                 "pred_answer": None
487 |             } for _ in range(len(images))]
488 |
489 |     # SegmentationModel implementation
490 |     def segment_objects(self, image, query):
491 |         """
492 |         Segment objects in an image based on a query
493 |
494 |         Args:
495 |             image: Input image
496 |             query: Text query describing what to segment
497 |
498 |         Returns:
499 |             dict: Results with masks and bounding boxes
500 |         """
501 |         try:
502 |             if self.use_hybrid_mode and self.if_yolo_condition(query):
503 |                 #bboxes, masks = self.segment_objects_yolo(image, query)
504 |                 bboxes = self.detect_objects_yolo(image, query)
505 |                 # use middle point of bbox as point
506 |                 points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
507 |                 output_text, thinking, pred_answer = "", "", str(bboxes)
508 |             else:
509 |                 output_text, (x_factor, y_factor) = self._generate_model_output(
510 |                     image,
511 |                     query,
512 |                     self.DETECTION_TEMPLATE
513 |                 )
514 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
515 |                     output_text,
516 |                     x_factor,
517 |                     y_factor
518 |                 )
519 |             masks = self.generate_masks(image, bboxes, points)
520 |
521 |             return {
522 |                 "masks": masks,
523 |                 "bboxes": bboxes,
524 |                 "points": points,
525 |                 "thinking": thinking,
526 |                 "full_response": output_text,
527 |                 "pred_answer": pred_answer
528 |             }
529 |         except Exception as e:
530 |             # Fall back to an empty result on failure
531 |             print(f"Error in segmentation: {e}")
532 |             img_height, img_width = image.height, image.width
533 |             return {
534 |                 "masks": np.zeros((img_height, img_width), dtype=bool),
535 |                 "bboxes": [],
536 |                 "points": [],
537 |                 "thinking": "",
538 |                 "full_response": "",
539 |                 "pred_answer": None
540 |             }
541 |
542 |     def segment_objects_batch(self, images, queries):
543 |         """
544 |         Segment objects in a batch of images
545 |
546 |         Args:
547 |             images: List of input images
548 |             queries: List of text queries
549 |
550 |         Returns:
551 |             list: List of segmentation results
552 |         """
553 |         try:
554 |             # TODO: support yolo for batch
555 |             output_texts, scale_factors = self._generate_model_output(
556 |                 images,
557 |                 queries,
558 |                 self.DETECTION_TEMPLATE,
559 |                 batch_mode=True
560 |             )
561 |
562 |             results = []
563 |             for image, output_text, (x_factor, y_factor) in zip(images, output_texts, scale_factors):
564 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
565 |                     output_text,
566 |                     x_factor,
567 |                     y_factor
568 |                 )
569 |
570 |                 masks = self.generate_masks(image, bboxes, points)
571 |                 results.append({
572 |                     "masks": masks,
573 |                     "bboxes": bboxes,
574 |                     "points": points,
575 |                     "thinking": thinking,
576 |                     "full_response": output_text,
577 |                     "pred_answer": pred_answer
578 |                 })
579 |             return results
580 |         except Exception as e:
581 |             print(f"Error in batch segmentation: {e}")
582 |             return [{
583 |                 "masks": np.zeros((img.height, img.width), dtype=bool),
584 |                 "bboxes": [],
585 |                 "points": [],
586 |                 "thinking": "",
587 |                 "full_response": "",
588 |                 "pred_answer": None
589 |             } for img in images]
590 |
591 |     # CountingModel implementation
592 |     def count_objects(self, image, query):
593 |         """
594 |         Count objects in an image based on a query
595 |
596 |         Args:
597 |             image: Input image
598 |             query: Text query describing what to count
599 |
600 |         Returns:
601 |             dict: Results with count and bounding boxes
602 |         """
603 |         try:
604 |             if self.use_hybrid_mode and self.if_yolo_condition(query):
605 |                 bboxes = self.detect_objects_yolo(image, query)
606 |                 # use middle point of bbox as point
607 |                 points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
608 |                 output_text, thinking, pred_answer = "", "", str(bboxes)
609 |             else:
610 |                 output_text, (x_factor, y_factor) = self._generate_model_output(
611 |                     image,
612 |                     query,
613 |                     self.DETECTION_TEMPLATE
614 |                 )
615 |
616 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
617 |                     output_text,
618 |                     x_factor,
619 |                     y_factor
620 |                 )
621 |
622 |             count = len(bboxes)
623 |
624 |             return {
625 |                 "count": count,
626 |                 "bboxes": bboxes,
627 |                 "points": points,
628 |                 "thinking": thinking,
629 |                 "full_response": output_text,
630 |                 "pred_answer": pred_answer
631 |             }
632 |         except Exception as e:
633 |             print(f"Error in counting: {e}")
634 |             return {
635 |                 "count": 0,
636 |                 "bboxes": [],
637 |                 "points": [],
638 |                 "thinking": "",
639 |                 "full_response": "",
640 |                 "pred_answer": None
641 |             }
642 |
643 |     def count_objects_batch(self, images, queries):
644 |         """
645 |         Count objects in a batch of images
646 |
647 |         Args:
648 |             images: List of input images
649 |             queries: List of text queries
650 |
651 |         Returns:
652 |             list: List of counting results
653 |         """
654 |         try:
655 |             # TODO: support yolo for batch
656 |             output_texts, scale_factors = self._generate_model_output(
657 |                 images,
658 |                 queries,
659 |                 self.DETECTION_TEMPLATE,
660 |                 batch_mode=True
661 |             )
662 |
663 |             results = []
664 |             for output_text, (x_factor, y_factor) in zip(output_texts, scale_factors):
665 |                 bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
666 |                     output_text,
667 |                     x_factor,
668 |                     y_factor
669 |                 )
670 |
671 |                 count = len(bboxes)
672 |                 results.append({
673 |                     "count": count,
674 |                     "bboxes": bboxes,
675 |                     "points": points,
676 |                     "thinking": thinking,
677 |                     "full_response": output_text,
678 |                     "pred_answer": pred_answer
679 |                 })
680 |             return results
681 |         except Exception as e:
682 |             print(f"Error in batch counting: {e}")
683 |             return [{
684 |                 "count": 0,
685 |                 "bboxes": [],
686 |                 "points": [],
687 |                 "thinking": "",
688 |                 "full_response": "",
689 |                 "pred_answer": None
690 |             } for _ in range(len(images))]
691 |
692 |     # QAModel implementation
693 |     def answer_question(self, image, question):
694 |         """
695 |         Answer a question about an image
696 |
697 |         Args:
698 |             image: Input image
699 |             question: Text question
700 |
701 |         Returns:
702 |             dict: Results with answer and thinking (if available)
703 |         """
704 |         try:
705 |             output_text, _ = self._generate_model_output(
706 |                 image,
707 |                 question,
708 |                 self.QA_TEMPLATE
709 |             )
710 |
711 |             result = self.extract_qa_answer(output_text)
712 |             return result
713 |         except Exception as e:
714 |             print(f"Error in QA: {e}")
715 |             return {
716 |                 "answer": "",
717 |                 "thinking": "",
718 |                 "full_response": ""
719 |             }
720 |
721 |     def answer_questions_batch(self, images, questions):
722 |         """
723 |         Answer questions about a batch of images
724 |
725 |         Args:
726 |             images: List of input images
727 |             questions: List of text questions
728 |
729 |         Returns:
730 |             list: List of QA results
731 |         """
732 |         try:
733 |             output_texts, _ = self._generate_model_output(
734 |                 images,
735 |                 questions,
736 |                 self.QA_TEMPLATE,
737 |                 batch_mode=True
738 |             )
739 |
740 |             results = []
741 |             for output_text in output_texts:
742 |                 result = self.extract_qa_answer(output_text)
743 |                 results.append(result)
744 |             return results
745 |         except Exception as e:
746 |             print(f"Error in batch QA: {e}")
747 |             return [{
748 |                 "answer": "",
749 |                 "thinking": "",
750 |                 "full_response": ""
751 |             } for _ in range(len(images))]
752 |
753 |     def generate_image(self, refer_image_path, image_prompt):
754 |         """
755 |         Generate an image from a text prompt, optionally editing a reference image
756 |
757 |         Args:
758 |             refer_image_path: Path to the reference image ("" to generate from scratch)
759 |             image_prompt: Text prompt describing what to generate
760 |
761 |         Returns:
762 |             PIL.Image: The generated image, or None if generation fails
763 |         """
764 |         if self.generation_model is None or image_prompt is None:
765 |             raise ValueError("Generation model is not initialized or image prompt is missing")
766 |
767 |         try:
768 |             if refer_image_path == "":
769 |                 # Generate the image from the prompt alone
770 |                 output = self.generation_model.images.generate(
771 |                     model="gpt-image-1",
772 |                     prompt=image_prompt,
773 |                 )
774 |                 image_base64 = output.data[0].b64_json
775 |
776 |             else:
777 |                 with open(refer_image_path, "rb") as refer_image:
778 |                     output = self.generation_model.images.edit(
779 |                         model="gpt-image-1",
780 |                         image=[refer_image],
781 |                         prompt=image_prompt,
782 |                     )
783 |                 image_base64 = output.data[0].b64_json
784 |             image = PILImage.open(BytesIO(base64.b64decode(image_base64)))
785 |             return image
786 |         except Exception as e:
787 |             print(f"Error in image generation: {e}")
788 |             return None
--------------------------------------------------------------------------------
/vision_reasoner/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | def visualize_results_enhanced(image, result, task_type, output_path):
5 |     """
6 |     Enhanced visualization with three-panel layout
7 |     """
8 |     # Create a figure with 3 subplots
9 |     plt.figure(figsize=(15, 5))
10 |
11 |     # First panel: Original image
12 |     plt.subplot(1, 3, 1)
13 |     plt.imshow(image)
14 |     plt.title('Original Image')
15 |     plt.axis('off')
16 |
17 |     # Second panel: Image with bounding boxes
18 |     plt.subplot(1, 3, 2)
19 |     plt.imshow(image)
20 |
21 |     if 'bboxes' in result and result['bboxes']:
22 |         for bbox in result['bboxes']:
23 |             x1, y1, x2, y2 = bbox
24 |             rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
25 |                                  fill=False, edgecolor='red', linewidth=2)
26 |             plt.gca().add_patch(rect)
27 |
28 |     if 'points' in result and result['points']:
29 |         for point in result['points']:
30 |             plt.plot(point[0], point[1], 'go', markersize=8)  # Green point
31 |
32 |     plt.title('Image with Bounding Boxes')
33 |     plt.axis('off')
34 |
35 |     # Third panel: Mask overlay (for segmentation tasks)
36 |     plt.subplot(1, 3, 3)
37 |     plt.imshow(image, alpha=0.6)
38 |
39 |     if task_type == 'segmentation' and 'masks' in result and result['masks'] is not None:
40 |         mask = result['masks']
41 |         if np.any(mask):
42 |             mask_overlay = np.zeros_like(np.array(image))
43 |             mask_overlay[mask] = [255, 0, 0]  # Red color for mask
44 |             plt.imshow(mask_overlay, alpha=0.4)
45 |
46 |     if task_type == 'detection' or task_type == 'counting':
47 |         # For non-segmentation tasks, just show bounding boxes again
48 |         if 'bboxes' in result and result['bboxes']:
49 |             for bbox in result['bboxes']:
50 |                 x1, y1, x2, y2 = bbox
51 |                 rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
52 |                                      fill=True, edgecolor='red', facecolor='red', alpha=0.3)
53 |                 plt.gca().add_patch(rect)
54 |
55 |     task_title = {
56 |         'detection': 'Detection Overlay',
57 |         'segmentation': 'Segmentation Mask',
58 |         'counting': 'Counting Results',
59 |         'qa': 'Visual QA'
60 |     }
61 |
62 |     plt.title(task_title.get(task_type, 'Results Overlay'))
63 |     plt.axis('off')
64 |
65 |     plt.tight_layout()
66 |     plt.savefig(output_path)
67 |     plt.close()
--------------------------------------------------------------------------------
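For orientation, a minimal usage sketch (not part of the repository) showing how the pieces above fit together. The signatures of process_single_image and visualize_results_enhanced match the code shown; the class name VisionReasonerModel and its no-argument constructor are assumptions, since the class definition and __init__ are not included in this excerpt, and assets/crowd.png is just one of the bundled sample images.

from PIL import Image

from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel  # class name assumed
from vision_reasoner.utils import visualize_results_enhanced

# Constructor arguments are an assumption; the real __init__ is not shown in this excerpt.
model = VisionReasonerModel()

image = Image.open("assets/crowd.png").convert("RGB")

# The task router picks detection / segmentation / counting / QA from the instruction text.
result, task_type = model.process_single_image(
    image, "how many people are in the image?", return_task_type=True
)

# Three-panel figure: original image, bounding boxes, task-specific overlay.
visualize_results_enhanced(image, result, task_type, "result_visualization.png")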