├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── assets
│   ├── airplanes.png
│   ├── company_name.png
│   ├── crowd.png
│   ├── crowd_output_1.png
│   ├── crowd_output_2.png
│   ├── dog.png
│   ├── donuts.png
│   ├── donuts_output.png
│   ├── overview.png
│   ├── pipeline.png
│   ├── stand_higher.png
│   └── stand_higher_output.png
├── evaluation
│   ├── calculate_coco_ap.py
│   ├── calculate_counting.py
│   ├── calculate_iou.py
│   ├── calculate_iou_with_bbox.py
│   ├── coco_gt
│   │   └── instances_val2017.json
│   ├── eval_coco.sh
│   ├── eval_count.sh
│   ├── eval_segmentation.sh
│   ├── evaluation_coco.py
│   ├── evaluation_count.py
│   └── evaluation_segmentation.py
├── requirements.txt
├── task_categorization.md
└── vision_reasoner
    ├── __init__.py
    ├── inference.py
    ├── models
    │   ├── base_model.py
    │   ├── qwen_vl.py
    │   ├── task_router.py
    │   └── vision_reasoner_model.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # outputs
174 | outputs/
175 | checkpoints/
176 | wandb/
177 | workdir
178 | reasonseg_eval_results
179 | prepare_dataset
180 | result_visualization.png
181 | detection_eval_results/
182 | pretrained_models/
183 | test_scripts.sh
184 | test_yoloe.py
185 | mobileclip_blt.ts
186 | yolov8l-world.pt
187 | yolov8x-worldv2.pt
188 | result_visualization_fail.png
189 | crowd_people.png
190 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-ast
6 | - id: check-added-large-files
7 | args: ['--maxkb=25000']
8 | - id: check-merge-conflict
9 | - id: check-yaml
10 | - id: debug-statements
11 | - id: end-of-file-fixer
12 | - id: requirements-txt-fixer
13 | - id: trailing-whitespace
14 | args: [--markdown-linebreak-ext=md]
15 | - id: no-commit-to-branch
16 | args: ['--branch', 'main']
17 |
18 | - repo: https://github.com/asottile/pyupgrade
19 | rev: v3.17.0
20 | hooks:
21 | - id: pyupgrade
22 | args: [--py38-plus]
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VisionReasoner: Unified Visual Perception and Reasoning via Reinforcement Learning
2 |
3 | > Current VLMs are primarily used for visual captioning or visual QA tasks. In this project, we take a step further by demonstrating the potential of a single VLM to solve diverse vision tasks. We hope this work will advance the frontier of VLM research and expand the boundaries of what these models can achieve.
4 |
5 | Paper: [📖 VisionReasoner](https://arxiv.org/pdf/2505.12081) [📖 Seg-Zero](https://arxiv.org/pdf/2503.06520)
6 | HuggingFace Daily: [🤗 VisionReasoner](https://huggingface.co/papers/2505.12081)
7 | Model: [🤗 VisionReasoner-7B](https://huggingface.co/Ricky06662/VisionReasoner-7B) [🤗 TaskRouter-1.5B](https://huggingface.co/Ricky06662/TaskRouter-1.5B)
8 | Related Link: [Seg-Zero ![code](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero)
9 |
10 | Overview of VisionReasoner:
11 |
12 |
13 |

14 |
15 |
16 | VisionReasoner demonstrates the following features:
17 | 1. **VisionReasoner** is a unified framework for visual perception tasks. Through carefully crafted rewards and training strategies, VisionReasoner acquires strong multi-task capability, addressing diverse visual perception tasks within a shared model.
18 | 2. We select several representative tasks to evaluate models' unified visual ability, including detection tasks (e.g., [COCO](https://cocodataset.org/#home), [RefCOCOg](https://github.com/lichengunc/refer)), segmentation tasks (e.g., [ReasonSeg](https://github.com/dvlab-research/LISA)), counting tasks (e.g., [CountBench](https://teaching-clip-to-count.github.io/)) and VQA tasks (e.g., [DocVQA](https://www.docvqa.org/)).
19 | 3. Experimental results show that VisionReasoner achieves superior performance across ten diverse visual perception tasks within a single unified framework, outperforming baseline models by a significant margin.
20 | 4. We have reformulated dozens of visual task types categorized on [Papers With Code](https://paperswithcode.com/datasets?mod=images&page=1). Please refer to [task categorization](task_categorization.md) for details. These task types are grouped into four fundamental task types: detection, segmentation, counting and VQA. More task types, and additional fundamental task types such as 3D or medical image processing, can be added to this framework.
21 |
22 |
23 | ## News
24 |
25 | [May 17th, 2025] 🔥 [📖 Paper](https://arxiv.org/pdf/2505.12081) is coming!
26 | [May 17th, 2025] 🔥 VisionReasoner is coming! VisionReasoner is based on our previous [Seg-Zero](https://github.com/dvlab-research/Seg-Zero).
27 |
28 |
29 | ## Contents
30 | - [Model](#model)
31 | - [Installation](#installation)
32 | - [Inference](#inference)
33 | - [Hybrid Mode](#hybrid-mode)
34 | - [Image Generation](#image-generation)
35 | - [Evaluation](#evaluation)
36 | - [Training](#training)
37 | - [Citation](#citation)
38 | - [Acknowledgement](#acknowledgement)
39 |
40 |
41 |
42 | ## Model
43 |
44 |

45 |
46 |
47 | The VisionReasoner model incorporates a reasoning module, which processes the image and locates the targeted objects, and a segmentation module that produces segmentation masks when needed.
48 | In addition, we train a task router that converts diverse vision tasks into the four fundamental task types.
49 |
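For reference, here is a minimal usage sketch of the detection path. It follows the `VisionReasonerModel` interface used by the evaluation scripts (`detect_objects_batch` returns, per image, a `thinking` string and a list of `bboxes`); the file paths and the printing code are illustrative.

```python
from PIL import Image
from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel

# Paths assume the download step in the Inference section below.
model = VisionReasonerModel(
    reasoning_model_path="pretrained_models/VisionReasoner-7B",
    task_router_model_path="pretrained_models/TaskRouter-1.5B",
    segmentation_model_path="facebook/sam2-hiera-large",
)

image = Image.open("assets/donuts.png").convert("RGB")
results = model.detect_objects_batch([image], ["please segment the donuts"])

for result in results:
    print("Thinking:", result["thinking"])
    # Predicted boxes are (x1, y1, x2, y2), as consumed by the evaluation code.
    print("Boxes:", result["bboxes"])
```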
50 |
51 |
56 |
57 |
58 | ## Installation
59 | > [!NOTE]
60 | > If you train VisionReasoner using the code in [Seg-Zero](https://github.com/dvlab-research/Seg-Zero), you can directly reuse that training environment.
61 |
62 | ```bash
63 | git clone https://github.com/dvlab-research/VisionReasoner.git
64 | cd VisionReasoner
65 | conda create -n visionreasoner_test python=3.12
66 | conda activate visionreasoner_test
67 | pip3 install torch torchvision
68 | pip install -r requirements.txt
69 | ```
70 |
71 |
72 | ## Inference
73 | Download the models using the following commands:
74 | ```bash
75 | mkdir pretrained_models
76 | cd pretrained_models
77 | git lfs install
78 | git clone https://huggingface.co/Ricky06662/VisionReasoner-7B
79 | git clone https://huggingface.co/Ricky06662/TaskRouter-1.5B
80 | ```
81 | > [!TIP]
82 | > If you encounter issues with connecting to Hugging Face, consider using `export HF_ENDPOINT=https://hf-mirror.com`.
83 |
84 |
85 | Then run inference using:
86 | ```bash
87 | python vision_reasoner/inference.py
88 | ```
89 | ### The default task is a counting task.
90 | > "How many airplanes are there in this image?"
91 |
92 |
93 |

94 |
95 |
96 |
97 | You will get the thinking process in the command line, like:
98 |
99 | > "The image shows a formation of airplanes flying in the sky. Each airplane is distinct and can be counted individually. The planes are arranged in a specific pattern, and there are visible trails of smoke behind them, which is typical for airshows or demonstrations."
100 |
101 | And you will get the final answer in the command line, like:
102 |
103 | > "Total number of interested objects is: 10"
104 |
105 |
106 | ### You can also try a detection / segmentation task by:
107 | ```bash
108 | python vision_reasoner/inference.py --image_path "assets/donuts.png" --query "please segment the donuts"
109 | ```
110 |
111 | You will get the thinking process in the command line, like:
112 |
113 | > "The task involves identifying and segmenting individual donuts in the image. Each donut is distinct in its color, glaze, and toppings, which helps in distinguishing them from one another. The goal is to identify each donut as a separate object and provide bounding boxes for them."
114 |
115 | And the result will be saved to `result_visualization.png`.
116 |
117 |
118 |

119 |
120 |
121 | ### Or some tasks that need reasoning:
122 |
123 | ```bash
124 | python vision_reasoner/inference.py --image_path "assets/stand_higher.png" --query "find what can make the woman stand higher?"
125 | ```
126 |
127 | You will get the thinking process in the command line, like:
128 |
129 | > "The question asks for objects that can make the woman stand higher. The woman is already standing on a ladder, which is the object that elevates her. The ladder is the most closely matched object to what can make her stand higher."
130 |
131 | And the result will be saved to `result_visualization.png`.
132 |
133 |
134 |

135 |
136 |
137 |
138 | ### We also support plain visual QA / captioning tasks:
139 | ```bash
140 | python vision_reasoner/inference.py --image_path "assets/company_name.png" --query "What is name of the company?"
141 | ```
142 |
143 |
144 |

145 |
146 |
147 | For VQA, there is no reasoning process, and you will get the final answer in the command line, like:
148 |
149 | > "The answer is: The name of the company is ITC (Indian Tobacco Company Limited)."
150 |
151 | ### You can also provide your own image_path and text by:
152 | ```bash
153 | python vision_reasoner/inference.py --image_path "your_image_path" --query "your question text"
154 | ```
155 |
156 | ## Hybrid Mode
157 | When hybrid reasoning mode is enabled, VisionReasoner intelligently switches between direct detection (using YOLO-World) and reasoning-based approaches based on the complexity of the query. This allows for faster responses on simple queries while maintaining detailed reasoning for complex tasks.
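
Conceptually, the dispatch looks like the sketch below. This is a schematic under stated assumptions: `fast_detector`, `reasoner`, and the word-count heuristic are illustrative stand-ins, not the repo's actual routing logic.

```python
from typing import Callable, List

def hybrid_answer(image,
                  query: str,
                  fast_detector: Callable,   # e.g. a YOLO-World wrapper
                  reasoner: Callable,        # e.g. VisionReasonerModel.detect_objects_batch
                  hybrid_mode: bool = True) -> List:
    """Route simple, category-like queries to the fast detector and
    everything else to the reasoning model (illustrative heuristic)."""
    if hybrid_mode and len(query.split()) <= 2:     # e.g. "person"
        return fast_detector(image, query)          # fast path, no chain-of-thought
    return reasoner([image], [query])[0]["bboxes"]  # reasoning path
```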
158 |
159 | ### Simple Query Example:
160 | For straightforward queries that can be directly answered by object detection:
161 |
162 | ```bash
163 | python vision_reasoner/inference.py --image_path "assets/crowd.png" --query "person" --hybrid_mode
164 | ```
165 |
166 | Output:
167 |
168 |
169 |

170 |
171 |
172 | In this case, the model directly uses YOLO-World for detection without going through the reasoning process, resulting in faster response times.
173 |
174 | ### Complex Query Example:
175 | For queries that require spatial reasoning or complex understanding:
176 |
177 | ```bash
178 | python vision_reasoner/inference.py --image_path "assets/crowd.png" --query "the person who is facing to the camera" --hybrid_mode
179 | ```
180 |
181 | Output:
182 | > Thinking process: The task involves identifying the person who is facing the camera and then finding the most closely matched object. In the image, there is a person in the center wearing a white shirt and a black vest, who appears to be facing the camera directly. The other individuals are walking away from the camera, so they are not the target. The person in the white shirt and black vest is the closest match to the description of facing the camera.
183 |
184 |
185 |
186 |

187 |
188 |
189 | In this case, the model switches to the reasoning-based approach because the query requires understanding spatial relationships and visual attributes.
190 |
191 |
192 | ## Image Generation
193 | Our framework can also incorporate generation tasks. We adopt [gpt-image-1](https://platform.openai.com/docs/guides/image-generation?image-generation-model=gpt-image-1) for generation in the current version.
194 |
195 | > [!NOTE]
196 | > Bugs might arise from API version mismatches. Please debug and customize based on your API key and version.
197 |
198 | ### Text-to-image generation
199 | For text-to-image generation, you only need to input a prompt:
200 | ```bash
201 | python vision_reasoner/inference.py --image_prompt "Draw an image of a cute dog." --generation_model_name [your openAI api key] --generation_mode
202 | ```
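
If you prefer to call gpt-image-1 directly, here is a minimal sketch using the OpenAI Python SDK (v1-style client). As the note above says, exact request and response fields may differ across API/SDK versions, so treat this as a starting point rather than the repo's implementation.

```python
import base64
from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment

result = client.images.generate(
    model="gpt-image-1",
    prompt="Draw an image of a cute dog.",
)

# gpt-image-1 returns base64-encoded image data
with open("generated_dog.png", "wb") as f:
    f.write(base64.b64decode(result.data[0].b64_json))
```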
203 |
204 | ### Image reference generation
205 | For image reference generation, you should input both a prompt and a reference image:
206 | ```bash
207 | python vision_reasoner/inference.py --refer_image_path "assets/dog.png" --image_prompt "Generate a cute dog in a forest" --generation_model_name [your openAI api key] --generation_mode
208 | ```
209 |
210 | ## Evaluation
211 |
212 | The evaluation scripts allow you to test VisionReasoner on various datasets. We provide scripts for evaluating segmentation, detection, and counting tasks.
213 |
214 | ### Using the Evaluation Scripts
215 |
216 | Each evaluation script accepts either a HuggingFace dataset path or a local dataset path:
217 |
218 | ```bash
219 | # Using HuggingFace dataset paths (default in examples)
220 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_val
221 |
222 | # Using local dataset paths
223 | bash evaluation/eval_segmentation.sh /path/to/your/local/refcoco_val
224 | ```
225 |
226 | Additionally, you can customize model paths with the following parameters:
227 |
228 | ```bash
229 | # Using local model paths (instead of downloading from HuggingFace)
230 | bash evaluation/eval_segmentation.sh [dataset_path] \
231 | --model_path /path/to/local/VisionReasoner-7B \
232 | --task_router_model_path /path/to/local/TaskRouter-1.5B
233 | ```
234 |
235 | ### Available Evaluation Scripts
236 |
237 | - `eval_segmentation.sh`: Evaluates segmentation performance on RefCOCO, RefCOCO+, RefCOCOg, and ReasonSeg datasets. When the dataset contains bounding box ground truth annotations, it will also output detection metrics.
238 | - `eval_coco.sh`: Evaluates detection performance on the COCO dataset
239 | - `eval_count.sh`: Evaluates counting performance on counting benchmarks
240 |
241 | ### Example Commands
242 |
243 | ```bash
244 | # Segmentation/Detection evaluation
245 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_val
246 | bash evaluation/eval_segmentation.sh Ricky06662/refcoco_testA
247 | bash evaluation/eval_segmentation.sh Ricky06662/refcocoplus_val
248 | bash evaluation/eval_segmentation.sh Ricky06662/refcocoplus_testA
249 | bash evaluation/eval_segmentation.sh Ricky06662/refcocog_val
250 | bash evaluation/eval_segmentation.sh Ricky06662/refcocog_test
251 | bash evaluation/eval_segmentation.sh Ricky06662/ReasonSeg_val
252 | bash evaluation/eval_segmentation.sh Ricky06662/ReasonSeg_test
253 |
254 | # COCO evaluation
255 | bash evaluation/eval_coco.sh Ricky06662/coco_val
256 |
257 | # Counting evaluation
258 | bash evaluation/eval_count.sh Ricky06662/counting_pixmo_validation
259 | bash evaluation/eval_count.sh Ricky06662/counting_pixmo_test
260 | bash evaluation/eval_count.sh Ricky06662/counting_countbench
261 | ```
262 |
263 | ## Training
264 |
265 | We recommend using [Seg-Zero](https://github.com/dvlab-research/Seg-Zero) to train VisionReasoner.
266 |
267 |
268 | ## Citation
269 |
270 | ```bibtex
271 | @article{liu2025segzero,
272 | title = {Seg-Zero: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement},
273 | author = {Liu, Yuqi and Peng, Bohao and Zhong, Zhisheng and Yue, Zihao and Lu, Fanbin and Yu, Bei and Jia, Jiaya},
274 | journal = {arXiv preprint arXiv:2503.06520},
275 | year = {2025}
276 | }
277 |
278 | @article{liu2025visionreasoner,
279 | title = {VisionReasoner: Unified Visual Perception and Reasoning via Reinforcement Learning},
280 | author = {Liu, Yuqi and Qu, Tianyuan and Zhong, Zhisheng and Peng, Bohao and Liu, Shu and Yu, Bei and Jia, Jiaya},
281 | journal = {arXiv preprint arXiv:2505.12081},
282 | year = {2025}
283 | }
284 | ```
285 |
286 | ## Acknowledgement
287 | We would like to thank the following repos for their great work:
288 |
289 | - This work is built upon [Seg-Zero](https://github.com/dvlab-research/Seg-Zero), [EasyR1](https://github.com/hiyouga/EasyR1) and [veRL](https://github.com/volcengine/verl).
290 | - This work utilizes models from [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct), [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), [SAM2](https://huggingface.co/facebook/sam2-hiera-large) and [YOLO-World](https://github.com/AILab-CVC/YOLO-World).
291 |
292 |
293 | ## Star History
294 |
295 | [](https://star-history.com/#dvlab-research/VisionReasoner&Date)
--------------------------------------------------------------------------------
/assets/airplanes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/airplanes.png
--------------------------------------------------------------------------------
/assets/company_name.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/company_name.png
--------------------------------------------------------------------------------
/assets/crowd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd.png
--------------------------------------------------------------------------------
/assets/crowd_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd_output_1.png
--------------------------------------------------------------------------------
/assets/crowd_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/crowd_output_2.png
--------------------------------------------------------------------------------
/assets/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/dog.png
--------------------------------------------------------------------------------
/assets/donuts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/donuts.png
--------------------------------------------------------------------------------
/assets/donuts_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/donuts_output.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/overview.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/stand_higher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/stand_higher.png
--------------------------------------------------------------------------------
/assets/stand_higher_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/assets/stand_higher_output.png
--------------------------------------------------------------------------------
/evaluation/calculate_coco_ap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from argparse import ArgumentParser
6 | from pycocotools.coco import COCO
7 | from pycocotools.cocoeval import COCOeval
8 |
9 |
10 | def parse_args():
11 | parser = ArgumentParser()
12 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files")
13 | parser.add_argument("--gt_json_path", type=str, required=True, help="path to COCO ground truth json file")
14 | return parser.parse_args()
15 |
16 | def calculate_metrics(output_dir, gt_json_path):
17 | # get all output files
18 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json")))
19 |
20 | if not output_files:
21 | print(f"cannot find output files in {output_dir}")
22 | return
23 |
24 | # for accumulating all data
25 | pred_results = []
26 |
27 | pred_results_constant_score = []
28 |
29 | pred_results_constant_exist_score = []
30 |
31 | # for calculating think text length
32 | think_text_lengths = []
33 |
34 | # read and process all files
35 | for file_path in output_files:
36 | with open(file_path, 'r', encoding='utf-8') as f:
37 | results = json.load(f)
38 |
39 | # process all items in each file
40 | for item in results:
41 | # Calculate think text length if available
42 | if 'think' in item and item['think']:
43 | think_text_lengths.append(len(item['think']))
44 |
45 | # Original bbox processing code
46 | bbox = item['bbox']
47 | bbox = [bbox[0], bbox[1], bbox[2]-bbox[0], bbox[3]-bbox[1]]
48 |
49 | pred_results.append({
50 | "image_id": item['image_id'],
51 | "category_id": item['category_id'],
52 | "bbox": bbox,
53 | "score": item['score']
54 | })
55 |
56 | pred_results_constant_score.append({
57 | "image_id": item['image_id'],
58 | "category_id": item['category_id'],
59 | "bbox": bbox,
60 | "score": 1.0
61 | })
62 |
63 | pred_results_constant_exist_score.append({
64 | "image_id": item['image_id'],
65 | "category_id": item['category_id'],
66 | "bbox": bbox,
67 | "score": 1.0 if item['score'] <= 0.2 else 0.0
68 | })
69 |
70 | # Calculate think text metrics
71 | if think_text_lengths:
72 | avg_think_length = sum(think_text_lengths) / len(think_text_lengths)
73 | min_think_length = min(think_text_lengths)
74 | max_think_length = max(think_text_lengths)
75 | print(f"\n-----------------Think Text Statistics----------------------------------")
76 | print(f"Number of think texts: {len(think_text_lengths)}")
77 | print(f"Average think text length: {avg_think_length:.2f} characters")
78 | print(f"Minimum think text length: {min_think_length} characters")
79 | print(f"Maximum think text length: {max_think_length} characters")
80 | print(f"------------------------------------------------------------------\n")
81 |
82 | coco_gt = COCO(gt_json_path) # load ground truth
83 | coco_dt = coco_gt.loadRes(pred_results) # load prediction results
84 |
85 | # initialize evaluation object (task type: bbox/keypoints/segmentation)
86 | coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') # select task type
87 |
88 | # run evaluation
89 | coco_eval.evaluate() # calculate matches
90 | coco_eval.accumulate() # accumulate metrics
91 | coco_eval.summarize() # output results
92 |
93 |
94 | if __name__ == "__main__":
95 | args = parse_args()
96 | calculate_metrics(args.output_dir, args.gt_json_path)
97 |
--------------------------------------------------------------------------------
/evaluation/calculate_counting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from argparse import ArgumentParser
6 |
7 | def parse_args():
8 | parser = ArgumentParser()
9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files")
10 | return parser.parse_args()
11 |
12 | def calculate_metrics(output_dir):
13 | # get all output files
14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json")))
15 |
16 | if not output_files:
17 | print(f"cannot find output files in {output_dir}")
18 | return
19 |
20 | # for accumulating all data
21 | all_counts = []
22 |
23 | # read and process all files
24 | for file_path in output_files:
25 | with open(file_path, 'r', encoding='utf-8') as f:
26 | results = json.load(f)
27 |
28 | # process all items in each file
29 | for item in results:
30 | pred_count = item['pred_count']
31 | gt_count = item['gt_count']
32 |
33 | all_counts.append({
34 | 'image_id': item['image_id'],
35 | 'mae': abs(pred_count - gt_count),
36 | 'rmse': (pred_count - gt_count) ** 2,
37 | 'correct_count': pred_count == gt_count
38 | })
39 |
40 | # calculate mae and rmse
41 | mae = np.mean([item['mae'] for item in all_counts])
42 | rmse = np.sqrt(np.mean([item['rmse'] for item in all_counts]))
43 | correct_count = np.mean([item['correct_count'] for item in all_counts])
44 | # print the results
45 | print(f"test len: {len(all_counts)}")
46 | print(f"mae: {mae:.4f}")
47 | print(f"rmse: {rmse:.4f}")
48 | print(f"correct_count: {correct_count:.4f}")
49 |
50 |
51 | if __name__ == "__main__":
52 | args = parse_args()
53 | calculate_metrics(args.output_dir)
54 |
--------------------------------------------------------------------------------
/evaluation/calculate_iou.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from argparse import ArgumentParser
6 |
7 | def parse_args():
8 | parser = ArgumentParser()
9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files")
10 | return parser.parse_args()
11 |
12 | def calculate_metrics(output_dir):
13 | # get all output files
14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json")))
15 |
16 | if not output_files:
17 | print(f"cannot find output files in {output_dir}")
18 | return
19 |
20 | # for accumulating all data
21 | total_intersection = 0
22 | total_union = 0
23 | all_ious = []
24 |
25 | # read and process all files
26 | for file_path in output_files:
27 | with open(file_path, 'r', encoding='utf-8') as f:
28 | results = json.load(f)
29 |
30 | # process all items in each file
31 | for item in results:
32 | intersection = item['intersection']
33 | union = item['union']
34 |
35 | # calculate IoU of each item
36 | iou = intersection / union if union > 0 else 0
37 | all_ious.append({
38 | 'image_id': item['image_id'],
39 | 'iou': iou
40 | })
41 |
42 | # accumulate total intersection and union
43 | total_intersection += intersection
44 | total_union += union
45 |
46 | # calculate gIoU
47 | gIoU = np.mean([item['iou'] for item in all_ious])
48 | # calculate cIoU
49 | cIoU = total_intersection / total_union if total_union > 0 else 0
50 |
51 | # print the results
52 | print(f"gIoU (average of per image IoU): {gIoU:.4f}")
53 | print(f"cIoU (total_intersection / total_union): {cIoU:.4f}")
54 |
55 |
56 | if __name__ == "__main__":
57 | args = parse_args()
58 | calculate_metrics(args.output_dir)
59 |
--------------------------------------------------------------------------------
/evaluation/calculate_iou_with_bbox.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from argparse import ArgumentParser
6 |
7 | def parse_args():
8 | parser = ArgumentParser()
9 | parser.add_argument("--output_dir", type=str, required=True, help="folder path of output files")
10 | return parser.parse_args()
11 |
12 | def calculate_metrics(output_dir):
13 | # get all output files
14 | output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json")))
15 |
16 | if not output_files:
17 | print(f"cannot find output files in {output_dir}")
18 | return
19 |
20 | # for accumulating all data
21 | total_intersection = 0
22 | total_union = 0
23 | total_bbox_iou = 0
24 | all_ious = []
25 | cnt = 0
26 |
27 | # for calculating think text length
28 | think_text_lengths = []
29 |
30 | # read and process all files
31 | for file_path in output_files:
32 | with open(file_path, 'r', encoding='utf-8') as f:
33 | results = json.load(f)
34 |
35 | # process all items in each file
36 | for item in results:
37 | # Calculate think text length if available
38 | if 'think' in item and item['think']:
39 | think_text_lengths.append(len(item['think']))
40 |
41 | intersection = item['intersection']
42 | union = item['union']
43 |
44 | # calculate IoU of each item
45 | iou = intersection / union if union > 0 else 0
46 | all_ious.append({
47 | 'image_id': item['image_id'],
48 | 'iou': iou
49 | })
50 |
51 | # accumulate total intersection and union
52 | total_intersection += intersection
53 | total_union += union
54 | total_bbox_iou += item['bbox_iou']
55 | cnt += 1
56 |
57 | # Calculate think text metrics
58 | if think_text_lengths:
59 | avg_think_length = sum(think_text_lengths) / len(think_text_lengths)
60 | min_think_length = min(think_text_lengths)
61 | max_think_length = max(think_text_lengths)
62 | print(f"\n-----------------Think Text Statistics----------------------------------")
63 | print(f"Number of think texts: {len(think_text_lengths)}")
64 | print(f"Average think text length: {avg_think_length:.2f} characters")
65 | print(f"Minimum think text length: {min_think_length} characters")
66 | print(f"Maximum think text length: {max_think_length} characters")
67 | print(f"------------------------------------------------------------------\n")
68 |
69 | # calculate gIoU
70 | gIoU = np.mean([item['iou'] for item in all_ious])
71 | # calculate cIoU
72 | cIoU = total_intersection / total_union if total_union > 0 else 0
73 | # calculate bbox_iou
74 | bbox_iou = total_bbox_iou / cnt if cnt > 0 else 0
75 |
76 | # print the results
77 | print(f"gIoU (average of per image IoU): {gIoU:.4f}")
78 | print(f"cIoU (total_intersection / total_union): {cIoU:.4f}")
79 | print(f"bbox_AP (average of per image bbox_AP): {bbox_iou:.4f}")
80 |
81 | if __name__ == "__main__":
82 | args = parse_args()
83 | calculate_metrics(args.output_dir)
84 |
--------------------------------------------------------------------------------
/evaluation/eval_coco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner or qwen2
5 | TEST_DATA_PATH=${1:-"Ricky06662/coco_val"}
6 |
7 | # Extract model name and test dataset name for output directory
8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 's/.*\/([^\/]+)$/\1/')
9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}"
10 |
11 | # Customize GPU array here - specify which GPUs to use
12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Using GPUs 0-7; edit this list to match your machine
13 | NUM_PARTS=${#GPU_ARRAY[@]}
14 |
15 | # Create output directory
16 | mkdir -p $OUTPUT_PATH
17 |
18 | # Run processes in parallel
19 | for i in $(seq 0 $((NUM_PARTS-1))); do
20 | gpu_id=${GPU_ARRAY[$i]}
21 | process_idx=$i # 0-based indexing for process
22 |
23 | export CUDA_VISIBLE_DEVICES=$gpu_id
24 | (
25 | python evaluation/evaluation_coco.py \
26 | --model $MODEL_TYPE \
27 | --output_path $OUTPUT_PATH \
28 | --test_data_path $TEST_DATA_PATH \
29 | --idx $process_idx \
30 | --num_parts $NUM_PARTS \
31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; }
32 | ) &
33 | done
34 |
35 | # Wait for all processes to complete
36 | wait
37 |
38 | COCO_GT_JSON_PATH="evaluation/coco_gt/instances_val2017.json"
39 |
40 | # Calculate COCO AP metrics
41 | python evaluation/calculate_coco_ap.py --output_dir $OUTPUT_PATH --gt_json_path $COCO_GT_JSON_PATH
42 |
--------------------------------------------------------------------------------
/evaluation/eval_count.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner or qwen2
5 | TEST_DATA_PATH=${1:-"Ricky06662/counting_pixmo_test"}
6 |
7 | # Extract model name and test dataset name for output directory
8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 's/.*\/([^\/]+)$/\1/')
9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}"
10 |
11 | # Customize GPU array here - specify which GPUs to use
12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Using GPUs 0-7; edit this list to match your machine
13 | NUM_PARTS=${#GPU_ARRAY[@]}
14 |
15 | # Create output directory
16 | mkdir -p $OUTPUT_PATH
17 |
18 | # Run processes in parallel
19 | for i in $(seq 0 $((NUM_PARTS-1))); do
20 | gpu_id=${GPU_ARRAY[$i]}
21 | process_idx=$i # 0-based indexing for process
22 |
23 | export CUDA_VISIBLE_DEVICES=$gpu_id
24 | (
25 | python evaluation/evaluation_count.py \
26 | --model $MODEL_TYPE \
27 | --output_path $OUTPUT_PATH \
28 | --test_data_path $TEST_DATA_PATH \
29 | --idx $process_idx \
30 | --num_parts $NUM_PARTS \
31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; }
32 | ) &
33 | done
34 |
35 | # Wait for all processes to complete
36 | wait
37 |
38 | python evaluation/calculate_counting.py --output_dir $OUTPUT_PATH
--------------------------------------------------------------------------------
/evaluation/eval_segmentation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | MODEL_TYPE="vision_reasoner" # Model type: qwen or vision_reasoner or qwen2
5 | TEST_DATA_PATH=${1:-"Ricky06662/refcocog_test"}
6 |
7 | # Extract model name and test dataset name for output directory
8 | TEST_NAME=$(echo $TEST_DATA_PATH | sed -E 's/.*\/([^\/]+)$/\1/')
9 | OUTPUT_PATH="./detection_eval_results/${MODEL_TYPE}/${TEST_NAME}"
10 |
11 | # Customize GPU array here - specify which GPUs to use
12 | GPU_ARRAY=(0 1 2 3 4 5 6 7) # Using GPUs 0-7; edit this list to match your machine
13 | NUM_PARTS=${#GPU_ARRAY[@]}
14 |
15 | # Create output directory
16 | mkdir -p $OUTPUT_PATH
17 |
18 | # Run processes in parallel
19 | for i in $(seq 0 $((NUM_PARTS-1))); do
20 | gpu_id=${GPU_ARRAY[$i]}
21 | process_idx=$i # 0-based indexing for process
22 |
23 | export CUDA_VISIBLE_DEVICES=$gpu_id
24 | (
25 | python evaluation/evaluation_segmentation.py \
26 | --model $MODEL_TYPE \
27 | --output_path $OUTPUT_PATH \
28 | --test_data_path $TEST_DATA_PATH \
29 | --idx $process_idx \
30 | --num_parts $NUM_PARTS \
31 | --batch_size 16 || { echo "1" > /tmp/process_status.$$; kill -TERM -$$; }
32 | ) &
33 | done
34 |
35 | # Wait for all processes to complete
36 | wait
37 |
38 | python evaluation/calculate_iou_with_bbox.py --output_dir $OUTPUT_PATH
--------------------------------------------------------------------------------
/evaluation/evaluation_coco.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | import numpy as np
5 | import os
6 | from datasets import load_from_disk, load_dataset
7 | from PIL import Image as PILImage
8 | from tqdm import tqdm
9 | import sys
10 | from scipy.optimize import linear_sum_assignment
11 |
12 | # Add the parent directory to the Python path to import model module
13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel
15 | from vision_reasoner.models.qwen_vl import QwenVLModel
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model", type=str, default="vision_reasoner")
19 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"])
20 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B")
21 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large")
22 | parser.add_argument("--output_path", type=str, required=True)
23 | parser.add_argument("--test_data_path", type=str, required=True)
24 | parser.add_argument("--batch_size", type=int, default=1)
25 |
26 | # for parallel evaluation
27 | parser.add_argument("--idx", type=int, required=True)
28 | parser.add_argument("--num_parts", type=int, required=True)
29 | return parser.parse_args()
30 |
31 | def compute_bbox_iou(bboxes1, bboxes2):
32 | """
33 | Calculate IOU matrix between two sets of bounding boxes
34 | bboxes1: shape (N, 4) prediction boxes
35 | bboxes2: shape (M, 4) ground truth boxes
36 | Returns: shape (N, M) IOU matrix
37 | """
38 | # Expand dimensions to support broadcasting
39 | bboxes1 = np.array(bboxes1)[:, None, :] # (N, 1, 4)
40 | bboxes2 = np.array(bboxes2)[None, :, :] # (1, M, 4)
41 |
42 | # Calculate intersection area
43 | x1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
44 | y1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
45 | x2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
46 | y2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
47 |
48 | # Calculate intersection area
49 | intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
50 |
51 | # Calculate the areas of the two sets of bboxes
52 | area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
53 | area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
54 |
55 | # Calculate union area
56 | union = area1 + area2 - intersection
57 |
58 | # Avoid division by zero
59 | iou = np.where(union > 0, intersection / union, 0)
60 |
61 | return iou
62 |
63 | def main():
64 | args = parse_args()
65 |
66 | # Initialize model
67 | if args.model == "qwen":
68 | model = QwenVLModel(model_path=args.model_path)
69 | elif args.model == "qwen2":
70 | model = QwenVLModel(model_path=args.model_path)
71 | elif args.model == "vision_reasoner":
72 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
73 | task_router_model_path=args.task_router_model_path,
74 | segmentation_model_path=args.segmentation_model_path)
75 |
76 | # Load dataset
77 | dataset = load_dataset(args.test_data_path, split="test")
78 | total_len = len(dataset)
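# split the dataset into num_parts contiguous shards; this process (args.idx) evaluates only its shard, and the last shard absorbs the remainder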
79 | part_size = total_len // args.num_parts
80 | start_idx = args.idx * part_size
81 | end_idx = min(start_idx + part_size if args.idx < args.num_parts - 1 else total_len, total_len)
82 | range_list = range(start_idx, end_idx)
83 |
84 | dataset = dataset.select(range_list)
85 |
86 | all_outputs = []
87 |
88 | # Prepare batches
89 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"):
90 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))]
91 |
92 | batch_images = [item["image"].convert("RGB") for item in batch_data]
93 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data]
94 | id_list = [{
95 | "image_id": item["image_id"],
96 | "ann_id": item["ann_id"],
97 | "img_height": item["img_height"],
98 | "img_width": item["img_width"],
99 | "bbox": item["bbox"],
100 | "cat_id": item["cat_id"]
101 | } for item in batch_data]
102 |
103 | process_batch(model, batch_images, batch_questions, id_list, all_outputs)
104 |
105 | # Save results
106 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json")
107 | with open(output_file, "w") as f:
108 | json.dump(all_outputs, f, indent=2, ensure_ascii=False)
109 |
110 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs):
111 | """Process a batch of images and questions"""
112 | batch_results = model.detect_objects_batch(batch_images, batch_questions)
113 | for i, result in enumerate(batch_results):
114 | try:
115 | thinking = result["thinking"]
116 | bboxes = result["bboxes"]
117 |
118 | gt_bboxes = id_list[i]["bbox"]
119 |
120 | if gt_bboxes and len(bboxes) > 0:
121 | # # Use vectorized calculation of IOU matrix
122 | # cost_matrix = -compute_bbox_iou(bboxes, gt_bboxes) # Use negative IOU as cost
123 |
124 | # # Use Hungarian algorithm for matching
125 | # pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
126 |
127 | # # Assign scores to each predicted box
128 | # scores = np.zeros(len(bboxes))
129 | # for pred_idx, gt_idx in zip(pred_indices, gt_indices):
130 | # scores[pred_idx] = -cost_matrix[pred_idx, gt_idx] # Convert back to positive IOU value
131 |
132 | # Add results
133 | for pred_idx, pred_bbox in enumerate(bboxes):
134 | all_outputs.append({
135 | "image_id": int(id_list[i]["image_id"]),
136 | "ann_id": int(id_list[i]["ann_id"]),
137 | "think": thinking,
138 | "category_id": int(id_list[i]["cat_id"]),
139 | "bbox": pred_bbox,
140 | #"score": float(max(scores[pred_idx],0.0)) # Use the match score
141 | "score": float((pred_bbox[2]-pred_bbox[0])*(pred_bbox[3]-pred_bbox[1])/(id_list[i]["img_width"]*id_list[i]["img_height"]))
142 | })
143 | else:
144 | # If there are no ground truth boxes or predicted boxes, score is 0
145 | for pred_bbox in bboxes:
146 | all_outputs.append({
147 | "image_id": int(id_list[i]["image_id"]),
148 | "ann_id": int(id_list[i]["ann_id"]),
149 | "think": thinking,
150 | "category_id": int(id_list[i]["cat_id"]),
151 | "bbox": pred_bbox,
152 | "score": 0.0
153 | })
154 |
155 | except Exception as e:
156 | # raise
157 | print(f"Error processing result: {e}")
158 | # Skip this because the implementation is different from the original
159 | continue
160 |
161 | if __name__ == "__main__":
162 | main()
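163 | 
164 | # Example invocation (paths are illustrative; flags follow the argparse definitions used above).
165 | # A single process over the whole test split uses --idx 0 --num_parts 1:
166 | #   python evaluation/evaluation_coco.py --model vision_reasoner \
167 | #       --test_data_path <hf_dataset_name_or_path> --output_path results/coco \
168 | #       --idx 0 --num_parts 1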
--------------------------------------------------------------------------------
/evaluation/evaluation_count.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | import numpy as np
5 | import os
6 | from datasets import load_from_disk, load_dataset
7 | from tqdm import tqdm
8 | import sys
9 |
10 | # Add the parent directory to the Python path to import model module
11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel
13 | from vision_reasoner.models.qwen_vl import QwenVLModel
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model", type=str, default="vision_reasoner")
17 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"])
18 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B")
19 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large")
20 | parser.add_argument("--output_path", type=str, required=True)
21 | parser.add_argument("--test_data_path", type=str, required=True)
22 | parser.add_argument("--batch_size", type=int, default=1)
23 |
24 | # for parallel evaluation
25 | parser.add_argument("--idx", type=int, required=True)
26 | parser.add_argument("--num_parts", type=int, required=True)
27 | return parser.parse_args()
28 |
29 | def main():
30 | args = parse_args()
31 |
32 | # Initialize model
33 | if args.model == "qwen":
34 | model = QwenVLModel(model_path=args.model_path)
35 | elif args.model == "qwen2":
36 | model = QwenVLModel(model_path=args.model_path)
37 | elif args.model == "vision_reasoner":
38 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
39 | task_router_model_path=args.task_router_model_path,
40 | segmentation_model_path=args.segmentation_model_path)
41 |
42 | # Load dataset
43 | dataset = load_dataset(args.test_data_path, split="test")
44 | total_len = len(dataset)
45 | part_size = total_len // args.num_parts
46 | start_idx = args.idx * part_size
47 | end_idx = start_idx + part_size if args.idx < args.num_parts - 1 else total_len
48 |
49 | dataset = dataset.select(range(start_idx, end_idx))
50 | all_outputs = []
51 |
52 | # Process in batches
53 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"):
54 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))]
55 |
56 | batch_images = [item["image"].convert("RGB") for item in batch_data]
57 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data]
58 | id_list = [{
59 | "image_id": item["image_id"],
60 | "ann_id": item["ann_id"],
61 | "img_height": item["img_height"],
62 | "img_width": item["img_width"],
63 | "gt_count": item["count"]
64 | } for item in batch_data]
65 | process_batch(model, batch_images, batch_questions, id_list, all_outputs)
66 |
67 | # Save results
68 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json")
69 | with open(output_file, "w") as f:
70 | json.dump(all_outputs, f, indent=2, ensure_ascii=False)
71 |
72 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs):
73 | """Process a batch of images and questions"""
74 | batch_results = model.count_objects_batch(batch_images, batch_questions)
75 |
76 | for i, result in enumerate(batch_results):
77 | try:
78 | thinking = result["thinking"]
79 | bboxes = result["bboxes"]
80 | pred_count = result["count"]
81 |
82 | all_outputs.append({
83 | "image_id": id_list[i]["image_id"],
84 | "ann_id": id_list[i]["ann_id"],
85 | "think": thinking,
86 | "pred_count": pred_count,
87 | "gt_count": id_list[i]["gt_count"]
88 | })
89 |
90 | except Exception as e:
91 | # raise
92 | print(f"Error processing result: {e}")
93 | # Add penalty in this situation
94 | all_outputs.append({
95 | "image_id": id_list[i]["image_id"],
96 | "ann_id": id_list[i]["ann_id"],
97 | "think": "",
98 | "pred_count": 1,
99 | "gt_count": id_list[i]["gt_count"]
100 | })
101 |
102 | print(f"Processed batch of {len(batch_images)} images")
103 |
104 |
105 | if __name__ == "__main__":
106 | main()
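107 | 
108 | # The --idx/--num_parts flags shard the test split so several processes can run in parallel,
109 | # each writing its own output_<idx>.json. A hypothetical 2-way split (paths illustrative):
110 | #   CUDA_VISIBLE_DEVICES=0 python evaluation/evaluation_count.py --test_data_path <dataset> \
111 | #       --output_path results/count --idx 0 --num_parts 2 &
112 | #   CUDA_VISIBLE_DEVICES=1 python evaluation/evaluation_count.py --test_data_path <dataset> \
113 | #       --output_path results/count --idx 1 --num_parts 2 &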
--------------------------------------------------------------------------------
/evaluation/evaluation_segmentation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | import numpy as np
5 | import os
6 | from datasets import load_from_disk, load_dataset
7 | from PIL import Image as PILImage
8 | from tqdm import tqdm
9 | import sys
10 |
11 | # Add the parent directory to the Python path to import model module
12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13 | from vision_reasoner.models.vision_reasoner_model import VisionReasonerModel
14 | from vision_reasoner.models.qwen_vl import QwenVLModel
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model", type=str, default="vision_reasoner")
19 | parser.add_argument("--model_path", type=str, default="pretrained_models/VisionReasoner-7B", choices=["Ricky06662/VisionReasoner-7B", "Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"])
20 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B")
21 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large")
22 | parser.add_argument("--output_path", type=str, required=True)
23 | parser.add_argument("--test_data_path", type=str, required=True)
24 | parser.add_argument("--batch_size", type=int, default=1)
25 |
26 | # for parallel evaluation
27 | parser.add_argument("--idx", type=int, required=True)
28 | parser.add_argument("--num_parts", type=int, required=True)
29 | return parser.parse_args()
30 |
31 | def compute_iou(mask1, mask2):
32 | intersection = np.logical_and(mask1, mask2).sum()
33 | union = np.logical_or(mask1, mask2).sum()
34 | if union == 0:
35 | return 0, 0
36 | return intersection, union
37 |
38 | def compute_bbox_iou(bbox1, bbox2):
39 | # Calculate the intersection area of two bboxes
40 | x1 = max(bbox1[0], bbox2[0])
41 | y1 = max(bbox1[1], bbox2[1])
42 | x2 = min(bbox1[2], bbox2[2])
43 | y2 = min(bbox1[3], bbox2[3])
44 |
45 | # Calculate intersection area
46 | intersection = max(0, x2 - x1) * max(0, y2 - y1)
47 |
48 | # Calculate areas of the two bboxes
49 | area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
50 | area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
51 |
52 | # Calculate union area
53 | union = area1 + area2 - intersection
54 |
55 | # Avoid division by zero
56 | if union == 0:
57 | return 0
58 |
59 | return intersection / union
60 |
61 | def main():
62 | args = parse_args()
63 |
64 | # Initialize model
65 | if args.model == "qwen":
66 | model = QwenVLModel(model_path=args.model_path)
67 | elif args.model == "qwen2":
68 | model = QwenVLModel(model_path=args.model_path)
69 | elif args.model == "vision_reasoner":
70 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
71 | task_router_model_path=args.task_router_model_path,
72 | segmentation_model_path=args.segmentation_model_path)
73 |
74 | # Load dataset
75 | dataset = load_dataset(args.test_data_path, split="test")
76 | total_len = len(dataset)
77 | part_size = total_len // args.num_parts
78 | start_idx = args.idx * part_size
79 | end_idx = start_idx + part_size if args.idx < args.num_parts - 1 else total_len
80 |
81 | dataset = dataset.select(range(start_idx, end_idx))
82 |
83 | # Check if dataset has bbox information
84 | has_bbox = 'bbox' in dataset[0]
85 |
86 | all_outputs = []
87 |
88 | # Prepare batches
89 | for i in tqdm(range(0, len(dataset), args.batch_size), desc="Processing batches"):
90 | batch_data = [dataset[j] for j in range(i, min(i + args.batch_size, len(dataset)))]
91 |
92 | batch_images = [item["image"].convert("RGB") for item in batch_data]
93 | batch_questions = [item["text"].lower().strip(".\"?!") for item in batch_data]
94 | id_list = [{
95 | "image_id": item["image_id"],
96 | "ann_id": item["ann_id"],
97 | "mask": item["mask"],
98 | "img_height": item["img_height"],
99 | "img_width": item["img_width"],
100 | "bbox": item["bbox"] if has_bbox else None
101 | } for item in batch_data]
102 |
103 | process_batch(model, batch_images, batch_questions, id_list, all_outputs, has_bbox)
104 |
105 | # Save results
106 | output_file = os.path.join(args.output_path, f"output_{args.idx}.json")
107 | with open(output_file, "w") as f:
108 | json.dump(all_outputs, f, indent=2, ensure_ascii=False)
109 |
110 | def process_batch(model, batch_images, batch_questions, id_list, all_outputs, has_bbox):
111 | """Process a batch of images and questions"""
112 | batch_results = model.segment_objects_batch(batch_images, batch_questions)
113 |
114 | for i, result in enumerate(batch_results):
115 | try:
116 | thinking = result["thinking"]
117 | bboxes = result["bboxes"]
118 | mask_all = result["masks"]
119 | gt_mask = np.array(id_list[i]["mask"])
120 |
121 | intersection, union = compute_iou(mask_all, gt_mask)
122 |
123 | bbox_iou = 0.0
124 | if has_bbox:
125 | try:
126 | gt_bbox = id_list[i]["bbox"]
127 | for pred_bbox in bboxes:
128 | if compute_bbox_iou(pred_bbox, gt_bbox) > 0.5:
129 | bbox_iou = 1.0
130 | break
131 | except Exception as e:
132 | print(f"Bbox error: {e}, Image ID: {id_list[i]['image_id']}, Ann ID: {id_list[i]['ann_id']}")
133 | bbox_iou = 0.0
134 |
135 | all_outputs.append({
136 | "image_id": id_list[i]["image_id"],
137 | "ann_id": id_list[i]["ann_id"],
138 | "think": thinking,
139 | "intersection": int(intersection),
140 | "union": int(union),
141 | "bbox_iou": bbox_iou
142 | })
143 |
144 | except Exception as e:
145 | print(f"Error processing result: {e}")
146 | # Add penalty in this situation
147 | all_outputs.append({
148 | "image_id": id_list[i]["image_id"],
149 | "ann_id": id_list[i]["ann_id"],
150 | "think": "",
151 | "intersection": 0,
152 |                 "union": int(np.array(id_list[i]["mask"]).sum()),  # cast to int so json.dump can serialize it
153 | "bbox_iou": 0.0
154 | })
155 |
156 |
157 | if __name__ == "__main__":
158 | main()
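159 | 
160 | # Sanity check for compute_bbox_iou above (boxes in [x1, y1, x2, y2] pixel coordinates):
161 | #   compute_bbox_iou([0, 0, 10, 10], [5, 5, 15, 15])
162 | #   intersection = 5 * 5 = 25, union = 100 + 100 - 25 = 175, IoU = 25 / 175 ≈ 0.143,
163 | #   which falls below the 0.5 threshold applied to bbox_iou in process_batch.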
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | codetiming
3 | datasets
4 | flash-attn
5 | liger-kernel
6 | mathruler
7 | numpy
8 | omegaconf
9 | pandas
10 | peft
11 | pillow
12 | pyarrow>=15.0.0
13 | pylatexenc
14 | qwen-vl-utils
15 | scipy
16 | matplotlib
17 | ipykernel
18 | pycocotools
19 | scikit-image
20 | ray
21 | tensordict
22 | transformers
23 | vllm
24 | wandb
25 | sam2
26 | ultralytics
27 | openai
--------------------------------------------------------------------------------
/task_categorization.md:
--------------------------------------------------------------------------------
1 | # Current task types
2 | > This list shows how we reformulate existing visual perception tasks. We are working on supporting more task types.
3 |
4 | ## Detection
5 |
6 | - [Visual Grounding](https://paperswithcode.com/datasets?mod=images&task=visual-grounding)
7 | - [Object Detection](https://paperswithcode.com/task/2d-object-detection/latest)
8 | - [2D Object Detection](https://paperswithcode.com/task/2d-object-detection/latest)
9 | - [Small Object Detection](https://paperswithcode.com/task/object-detection)
10 | - [Defect Detection](https://paperswithcode.com/task/defect-detection)
11 | - [Face Detection](https://astro.paperswithcode.com/task/occluded-face-detection)
12 | - [License Plate Detection](https://paperswithcode.com/task/license-plate-detection)
13 | - [Anomaly Detection](https://paperswithcode.com/task/anomaly-detection)
14 | - [Human Detection](https://paperswithcode.com/task/human-detection)
15 | - [Surgical Tool Detection](https://paperswithcode.com/task/surgical-tool-detection)
16 | - [Dense Object Detection](https://paperswithcode.com/task/dense-object-detection)
17 | - [Open World Object Detection](https://paperswithcode.com/task/open-world-object-detection)
18 | - [Zero-Shot Object Detection](https://paperswithcode.com/task/zero-shot-object-detection)
19 | - [Animal Action Recognition](https://paperswithcode.com/task/animal-action-recognition)
20 | - [Robotic Grasping](https://paperswithcode.com/task/robotic-grasping)
21 | - [Object Localization](https://paperswithcode.com/task/object-localization)
22 | - [Hand Detection](https://paperswithcode.com/task/hand-detection)
23 | - [Visual Relationship Detection](https://paperswithcode.com/task/visual-relationship-detection)
24 | - [Open Vocabulary Object Detection](https://paperswithcode.com/task/open-vocabulary-object-detection)
25 | - [Oriented Object Detection](https://paperswithcode.com/task/oriented-object-detection)
26 | - [Object Detection in Indoor Scenes](https://paperswithcode.com/task/object-detection-in-indoor-scenes)
27 | - [Object Detection in Aerial Images](https://paperswithcode.com/task/object-detection-in-aerial-images)
28 | - [Person Search](https://paperswithcode.com/task/person-search)
29 | - [Object Recognition](https://paperswithcode.com/datasets?mod=images&page=1)
30 |
31 |
32 | ## Segmentation
33 |
34 |
35 | - [Semantic Segmentation](https://paperswithcode.com/task/semantic-segmentation)
36 | - [Instance Segmentation](https://paperswithcode.com/task/3d-instance-segmentation-1)
37 | - [Lane Detection](https://paperswithcode.com/task/lane-detection/social)
38 | - [2D Semantic Segmentation](https://paperswithcode.com/task/2d-semantic-segmentation)
39 | - [Medical Image Segmentation](https://paperswithcode.com/task/medical-image-segmentation/latest)
40 | - [Human Part Segmentation](https://paperswithcode.com/task/human-part-segmentation)
41 | - [Action Segmentation](https://paperswithcode.com/task/action-segmentation)
42 | - [Video Object Segmentation](https://paperswithcode.com/task/interactive-video-object-segmentation)
43 | - [Referring Expression Segmentation](https://paperswithcode.com/task/referring-expression-segmentation)
44 | - [Saliency Detection](https://paperswithcode.com/task/saliency-detection/latest)
45 | - [Salient Object Detection](https://paperswithcode.com/task/salient-object-detection)
46 | - [The Semantic Segmentation of Remote Sensing Imagery](https://paperswithcode.com/task/the-semantic-segmentation-of-remote-sensing)
47 | - [Crack Segmentation](https://paperswithcode.com/task/crack-segmentation)
48 | - [Action Unit Detection](https://paperswithcode.com/task/action-unit-detection)
49 | - [RGB Salient Object Detection](https://paperswithcode.com/task/salient-object-detection)
50 | - [Boundary Detection](https://paperswithcode.com/task/boundary-detection)
51 | - [Crack Segmentation for Infrastructure](https://paperswithcode.com/task/crack-segmentation)
52 | - [Surgical Tool Segmentation](https://paperswithcode.com/task/surgical-tool-detection)
53 |
54 | ## Counting
55 |
56 | - [Object Counting](https://paperswithcode.com/task/object-counting)
57 | - [Crowd Counting](https://paperswithcode.com/task/crowd-counting)
58 | - [Density Estimation](https://paperswithcode.com/task/density-estimation)
59 | - [Pedestrian Detection](https://paperswithcode.com/task/pedestrian-detection)
60 | - [Crowd Estimation in Dense Scenes](https://paperswithcode.com/task/crowd-counting/codeless)
61 | - [Traffic Counting in Surveillance](https://paperswithcode.com/task/crowd-counting)
62 |
63 | ## VQA
64 |
65 | - [Visual Question Answering (VQA)](https://paperswithcode.com/datasets?mod=images&task=visual-question-answering)
66 | - [Classification](https://paperswithcode.com/datasets?mod=images&task=classification-1)
67 | - [Image Captioning](https://paperswithcode.com/datasets?mod=images&task=image-captioning)
68 | - [Question Answering](https://paperswithcode.com/datasets?mod=images&task=question-answering)
69 | - [Visual Reasoning](https://paperswithcode.com/datasets?mod=images&task=visual-reasoning)
70 | - [Visual Question Answering](https://paperswithcode.com/datasets?mod=images&task=visual-question-answering-1)
71 | - [Relational Reasoning](https://paperswithcode.com/datasets?mod=images&task=relational-reasoning)
72 |
73 |
--------------------------------------------------------------------------------
/vision_reasoner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/VisionReasoner/a455092a8d57d96f5d6e77597fb5a8741b8f3880/vision_reasoner/__init__.py
--------------------------------------------------------------------------------
/vision_reasoner/inference.py:
--------------------------------------------------------------------------------
1 | # inference.py
2 | import argparse
3 | from PIL import Image
4 | import os
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6 | from models.vision_reasoner_model import VisionReasonerModel
7 | from utils import visualize_results_enhanced
8 |
9 | def main():
10 | parser = argparse.ArgumentParser(description="Test unified vision model on a single image")
11 | parser.add_argument("--model_path", type=str, default='pretrained_models/VisionReasoner-7B', help="Path to the model")
12 | parser.add_argument("--task_router_model_path", type=str, default="pretrained_models/TaskRouter-1.5B")
13 | parser.add_argument("--segmentation_model_path", type=str, default="facebook/sam2-hiera-large")
14 | parser.add_argument("--image_path", type=str, default="assets/airplanes.png", help="Path to the input image")
15 | parser.add_argument("--query", type=str, default="How many airplanes are there in this image?", help="Query/instruction for the model")
16 | parser.add_argument("--task", type=str, choices=["auto", "detection", "segmentation", "counting", "vqa", "generation"],
17 | default="auto", help="Task type (default: auto)")
18 | parser.add_argument("--output_path", type=str, default="result_visualization.png", help="Path to save the output visualization")
19 | parser.add_argument("--hybrid_mode", action="store_true", help="Whether to use YOLO for object detection")
20 | parser.add_argument("--yolo_model_path", type=str, default="yolov8x-worldv2.pt", help="Path to the YOLO model")
21 | parser.add_argument("--refer_image_path", type=str, default="", help="Path to the reference image")
22 | parser.add_argument("--image_prompt", type=str, default="", help="Prompt for image generation")
23 | parser.add_argument("--generation_mode", action="store_true", help="Whether to use generation model")
24 | parser.add_argument("--generation_model_name", type=str, default="gpt-image-1", help="Name of the generation model")
25 | args = parser.parse_args()
26 |
27 | # Determine task type
28 | if args.image_prompt != "":
29 | assert args.generation_mode, "Please set --generation_mode to True when using image prompt"
30 | task_type = "generation"
31 | else:
32 | task_type = args.task
33 |
34 | # Load model
35 | print(f"Loading model from {args.model_path}...")
36 | if args.generation_mode:
37 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
38 | task_router_model_path=args.task_router_model_path,
39 | segmentation_model_path=args.segmentation_model_path,
40 | generation_model_path=args.generation_model_name)
41 | elif args.hybrid_mode:
42 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
43 | task_router_model_path=args.task_router_model_path,
44 | segmentation_model_path=args.segmentation_model_path,
45 | yolo_model_path=args.yolo_model_path)
46 | else:
47 | model = VisionReasonerModel(reasoning_model_path=args.model_path,
48 | task_router_model_path=args.task_router_model_path,
49 | segmentation_model_path=args.segmentation_model_path)
50 |
51 |
52 | if task_type != "generation":
53 | # Load image
54 | print(f"Loading image from {args.image_path}...")
55 | image = Image.open(args.image_path).convert("RGB")
56 |
57 | if task_type == "auto":
58 | result, task_type = model.process_single_image(image, args.query, return_task_type=True)
59 |         elif task_type == "detection":
60 |             result = model.detect_objects(image, args.query)
61 |         elif task_type == "segmentation":
62 |             result = model.segment_objects(image, args.query)
63 |         elif task_type == "counting":
64 |             result = model.count_objects(image, args.query)
65 |         else: # VQA
66 |             result = model.answer_question(image, args.query)
67 |     else:
68 |         result = model.generate_image(args.refer_image_path, args.image_prompt)
69 |
70 | # Print results
71 | print("\n===== Results =====")
72 | print("Task type: ", task_type)
73 | if args.image_prompt != "":
74 | print("User prompt: ", args.image_prompt)
75 | else:
76 | print("User question: ", args.query)
77 |     if task_type != "generation" and 'thinking' in result and result['thinking'].strip() != "":
78 | print("Thinking process: ", result['thinking'])
79 |
80 | # print("Response: ", result)
81 |
82 | if task_type == "detection":
83 | print(f"Total number of detected objects: {len(result['bboxes'])}")
84 | elif task_type == "segmentation":
85 | print(f"Total number of segmented objects: {len(result['bboxes'])}")
86 | elif task_type == "counting":
87 | print(f"Total number of interested objects is: {result['count']}")
88 | elif task_type == "generation":
89 | result.save(args.output_path, format="PNG")
90 | print(f"The generated image is saved as '{args.output_path}'")
91 | else: # QA
92 | print(f"The answer is: {result['answer']}")
93 |
94 | if task_type != "generation" and task_type != "vqa" and task_type != "counting":
95 | # Visualize results with the new three-panel layout
96 | visualize_results_enhanced(image, result, task_type, args.output_path)
97 | print(f"\nResult visualization saved as '{args.output_path}'")
98 |
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
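103 | 
104 | # Example run (defaults shown in the argparse setup above; assumes the default checkpoints are
105 | # available). The plain "from models..." import assumes the script is launched from inside the
106 | # vision_reasoner/ directory:
107 | #   cd vision_reasoner
108 | #   python inference.py --image_path ../assets/airplanes.png \
109 | #       --query "How many airplanes are there in this image?" --task counting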
--------------------------------------------------------------------------------
/vision_reasoner/models/base_model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | class BaseVisionModel(ABC):
4 | """Abstract base class for vision models that process images and instructions"""
5 |
6 | @abstractmethod
7 | def process_single_image(self, image, instruction):
8 | """
9 | Process a single image and instruction
10 |
11 | Args:
12 | image: Input image
13 | instruction: Text instruction/query
14 |
15 | Returns:
16 | dict: Results dictionary
17 | """
18 | pass
19 |
20 | @abstractmethod
21 | def process_batch(self, batch_images, batch_instructions):
22 | """
23 | Process a batch of images and instructions
24 |
25 | Args:
26 | batch_images: List of input images
27 | batch_instructions: List of text instructions/queries
28 |
29 | Returns:
30 | list: List of result dictionaries
31 | """
32 | pass
33 |
34 |
35 | class DetectionModel(ABC):
36 | """Interface for object detection tasks"""
37 |
38 | @abstractmethod
39 | def detect_objects(self, image, query):
40 | """
41 | Detect objects in an image based on a query
42 |
43 | Args:
44 | image: Input image
45 | query: Text query describing what to detect
46 |
47 | Returns:
48 | dict: Results containing at least:
49 | - bboxes: List of bounding boxes [x1, y1, x2, y2]
50 | - scores: List of confidence scores
51 | - thinking: Reasoning process (if available)
52 | """
53 | pass
54 |
55 | @abstractmethod
56 | def detect_objects_batch(self, images, queries):
57 | """
58 | Detect objects in a batch of images
59 |
60 | Args:
61 | images: List of input images
62 | queries: List of text queries
63 |
64 | Returns:
65 | list: List of result dictionaries
66 | """
67 | pass
68 |
69 |
70 | class SegmentationModel(ABC):
71 | """Interface for segmentation tasks"""
72 |
73 | @abstractmethod
74 | def segment_objects(self, image, query):
75 | """
76 | Segment objects in an image based on a query
77 |
78 | Args:
79 | image: Input image
80 | query: Text query describing what to segment
81 |
82 | Returns:
83 | dict: Results containing at least:
84 | - masks: Segmentation masks
85 | - bboxes: List of bounding boxes
86 | - thinking: Reasoning process (if available)
87 | """
88 | pass
89 |
90 | @abstractmethod
91 | def segment_objects_batch(self, images, queries):
92 | """
93 | Segment objects in a batch of images
94 |
95 | Args:
96 | images: List of input images
97 | queries: List of text queries
98 |
99 | Returns:
100 | list: List of result dictionaries
101 | """
102 | pass
103 |
104 |
105 | class CountingModel(ABC):
106 | """Interface for counting tasks"""
107 |
108 | @abstractmethod
109 | def count_objects(self, image, query):
110 | """
111 | Count objects in an image based on a query
112 |
113 | Args:
114 | image: Input image
115 | query: Text query describing what to count
116 |
117 | Returns:
118 | dict: Results containing at least:
119 | - count: Number of objects
120 | - bboxes: List of bounding boxes (optional)
121 | - thinking: Reasoning process (if available)
122 | """
123 | pass
124 |
125 | @abstractmethod
126 | def count_objects_batch(self, images, queries):
127 | """
128 | Count objects in a batch of images
129 |
130 | Args:
131 | images: List of input images
132 | queries: List of text queries
133 |
134 | Returns:
135 | list: List of result dictionaries
136 | """
137 | pass
138 |
139 |
140 | class QAModel(ABC):
141 | """Interface for visual question answering tasks"""
142 |
143 | @abstractmethod
144 | def answer_question(self, image, question):
145 | """
146 | Answer a question about an image
147 |
148 | Args:
149 | image: Input image
150 | question: Text question
151 |
152 | Returns:
153 | dict: Results containing at least:
154 | - answer: Text answer
155 | - thinking: Reasoning process (if available)
156 | """
157 | pass
158 |
159 | @abstractmethod
160 | def answer_questions_batch(self, images, questions):
161 | """
162 | Answer questions about a batch of images
163 |
164 | Args:
165 | images: List of input images
166 | questions: List of text questions
167 |
168 | Returns:
169 | list: List of result dictionaries
170 | """
171 | pass
--------------------------------------------------------------------------------
/vision_reasoner/models/qwen_vl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 | import json
5 | from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
6 | from PIL import Image as PILImage
7 | from sam2.sam2_image_predictor import SAM2ImagePredictor
8 | from qwen_vl_utils import process_vision_info
9 | from .base_model import (
10 | BaseVisionModel,
11 | DetectionModel,
12 | SegmentationModel,
13 | CountingModel,
14 | QAModel
15 | )
16 |
17 | class QwenVLModel(BaseVisionModel, DetectionModel, SegmentationModel, CountingModel, QAModel):
18 | """
19 | QwenVL model implementing all task interfaces with custom prompts
20 | """
21 |
22 | def __init__(self, model_path='Qwen/Qwen2.5-VL-7B-Instruct',
23 | segmentation_model_path="facebook/sam2-hiera-large"):
24 | """
25 | Initialize the QwenVL model
26 |
27 | Args:
28 | model_path (str): Path to the model
29 | """
30 | self.model_path = model_path
31 |
32 | # Initialize model
33 | if 'Qwen2.5' in model_path:
34 | self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
35 | model_path,
36 | torch_dtype=torch.float16,
37 | device_map="auto",
38 | trust_remote_code=True
39 | )
40 | else:
41 | self.model = Qwen2VLForConditionalGeneration.from_pretrained(
42 | model_path,
43 | torch_dtype=torch.float16,
44 | device_map="auto",
45 | trust_remote_code=True
46 | )
47 | self.model.eval()
48 |
49 | # Initialize processor
50 | self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
51 |
52 | # Initialize segmentation model
53 | self.segmentation_model = SAM2ImagePredictor.from_pretrained(segmentation_model_path)
54 |
55 | # Task-specific prompts
56 | if 'Qwen2.5' in model_path:
57 | self.DETECTION_PROMPT = """
58 | Locate "{query}", report the bboxes coordinates in JSON format.
59 | """
60 | else:
61 | self.DETECTION_PROMPT = """
62 | Please identify "{query}" in the image.
63 | Return a JSON array where each object is represented as:
64 | {{"bbox": [x1, y1, x2, y2], "label": label}}
65 |
66 | Your response should be formatted as:
67 | ```json
68 | [
69 | {{"bbox": [x1, y1, x2, y2], "label": label}},
70 | ...
71 | ]
72 | ```
73 | """
74 | self.COUNTING_PROMPT = """
75 | Locate "{query}", report the bboxes coordinates in JSON format.
76 | """
77 |
78 | self.QA_PROMPT = """{query}"""
79 |
80 | def extract_json_from_response(self, response_text):
81 | """
82 |         Extract JSON content from a model response, preferring json code blocks fenced with triple backticks
83 |
84 | Args:
85 | response_text (str): Model response text
86 |
87 | Returns:
88 | dict or list: Parsed JSON content
89 | """
90 | json_pattern = r"```json\s*(.*?)\s*```"
91 | match = re.search(json_pattern, response_text, re.DOTALL)
92 |
93 | if match:
94 | json_str = match.group(1).strip()
95 | try:
96 | return json.loads(json_str)
97 | except json.JSONDecodeError:
98 | print(f"Error parsing JSON: {json_str}")
99 | return None
100 |
101 | # Fallback: try to find any JSON-like structure
102 | try:
103 | # Look for arrays
104 | array_pattern = r'\[(.*?)\]'
105 | array_match = re.search(array_pattern, response_text, re.DOTALL)
106 | if array_match:
107 | return json.loads(f"[{array_match.group(1)}]")
108 |
109 | # Look for objects
110 | object_pattern = r'\{(.*?)\}'
111 | object_match = re.search(object_pattern, response_text, re.DOTALL)
112 | if object_match:
113 | return json.loads(f"{{{object_match.group(1)}}}")
114 | except:
115 | pass
116 |
117 | return None
118 |
119 | def generate_response(self, image, prompt):
120 | """
121 | Generate response from the model
122 |
123 | Args:
124 | image (PIL.Image): Input image
125 | prompt (str): Text prompt
126 |
127 | Returns:
128 | str: Model response
129 | """
130 |         # Build the chat messages containing the image and the text prompt
131 | messages = [
132 | {
133 | "role": "system",
134 | "content": "You are a helpful assistant"
135 | },
136 | {
137 | "role": "user",
138 | "content": [
139 | {
140 | "type": "image",
141 | "image": image
142 | },
143 | {
144 | "type": "text",
145 | "text": prompt
146 | }
147 | ]
148 | }
149 | ]
150 | text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
151 |
152 | image_inputs, video_inputs = process_vision_info(messages)
153 | inputs = self.processor(
154 | text=[text],
155 | images=image_inputs,
156 | videos=video_inputs,
157 | padding=True,
158 | return_tensors="pt",
159 | ).to('cuda')
160 |
161 |
162 | # Generate response
163 | with torch.inference_mode():
164 | output_ids = self.model.generate(
165 | **inputs,
166 | max_new_tokens=1024,
167 | do_sample=False
168 | )
169 |
170 | generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
171 | output_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
172 |
173 |         input_height = inputs['image_grid_thw'][0][1]*14  # patch-grid height x 14-pixel patch size
174 |         input_width = inputs['image_grid_thw'][0][2]*14   # patch-grid width x 14-pixel patch size
175 |
176 | return output_text[0], (input_height, input_width)
177 |
178 | def _determine_task_type(self, instruction):
179 | """
180 | Determine the task type based on the instruction
181 |
182 | Args:
183 | instruction (str): Text instruction or query
184 |
185 | Returns:
186 | str: Task type ("detection", "segmentation", "counting", or "qa")
187 | """
188 | instruction_lower = instruction.lower()
189 |
190 | if "how many" in instruction_lower:
191 | return "counting"
192 | elif any(word in instruction_lower for word in ["segment", "mask", "outline"]):
193 | return "segmentation"
194 | elif any(word in instruction_lower for word in ["find", "locate", "detect", "identify"]):
195 | return "detection"
196 | else:
197 | return "vqa"
198 |
199 | def process_single_image(self, image, instruction, return_task_type=False):
200 | """
201 | Process a single image with given instruction
202 |
203 | Args:
204 | image (PIL.Image): Input image
205 | instruction (str): Text instruction or query
206 |
207 | Returns:
208 | dict: Results dictionary
209 | """
210 | task_type = self._determine_task_type(instruction)
211 |
212 |
213 |
214 | if task_type == "detection":
215 | result = self.detect_objects(image, instruction)
216 | elif task_type == "segmentation":
217 | result = self.segment_objects(image, instruction)
218 | elif task_type == "counting":
219 | result = self.count_objects(image, instruction)
220 | else: # Default to QA
221 | result = self.answer_question(image, instruction)
222 |
223 | if return_task_type:
224 | return result, task_type
225 | else:
226 | return result
227 |
228 | def process_batch(self, batch_images, batch_instructions):
229 | """
230 | Process a batch of images with given instructions
231 |
232 | Args:
233 | batch_images (list): List of PIL Images
234 | batch_instructions (list): List of text instructions or queries
235 |
236 | Returns:
237 | list: List of result dictionaries
238 | """
239 | results = []
240 | for image, instruction in zip(batch_images, batch_instructions):
241 | result = self.process_single_image(image, instruction)
242 | results.append(result)
243 | return results
244 |
245 | def detect_objects(self, image, query):
246 | """
247 | Detect objects in an image based on a query
248 |
249 | Args:
250 | image: Input image
251 | query: Text query describing what to detect
252 |
253 | Returns:
254 | dict: Results with bounding boxes and scores
255 | """
256 | prompt = self.DETECTION_PROMPT.format(query=query)
257 | response, scale_factors = self.generate_response(image, prompt)
258 |
259 | # Get original image dimensions
260 | img_height, img_width = image.size[1], image.size[0]
261 |
262 | json_data = self.extract_json_from_response(response)
263 |
264 | bboxes = []
265 | points = []
266 | scores = []
267 |
268 | if json_data and isinstance(json_data, list):
269 | for item in json_data:
270 | for key in ['bbox', 'bbox_2d']:
271 | if key in item:
272 | # For Qwen2-VL, convert from normalized coordinates (0-1000) to actual image coordinates
273 | if 'Qwen2.5' not in self.model_path:
274 | bbox = [
275 | int(item[key][0] * img_width / 1000),
276 | int(item[key][1] * img_height / 1000),
277 | int(item[key][2] * img_width / 1000),
278 | int(item[key][3] * img_height / 1000)
279 | ]
280 | else:
281 |                             # Qwen2.5-VL reports absolute pixel coordinates, so no rescaling is applied
282 | bbox = [
283 | int(item[key][0]),
284 | int(item[key][1]),
285 | int(item[key][2]),
286 | int(item[key][3])
287 | ]
288 | bboxes.append(bbox)
289 |
290 | for key in ['point', 'point_2d']:
291 | if key in item:
292 | # Similarly handle points
293 | if 'Qwen2.5' not in self.model_path:
294 | point = [
295 | int(item[key][0] * img_width / 1000),
296 | int(item[key][1] * img_height / 1000)
297 | ]
298 | else:
299 | point = [
300 | int(item[key][0]),
301 | int(item[key][1])
302 | ]
303 | points.append(point)
304 |
305 | for key in ['score', 'score_2d']:
306 | if key in item:
307 | scores.append(item[key])
308 |
309 | if len(scores) == 0:
310 | scores.append(0.0)
311 |
312 | return {
313 | "bboxes": bboxes,
314 | "points": points,
315 | "scores": scores,
316 | "thinking": "",
317 | "full_response": response,
318 | "json_data": json_data
319 | }
320 |
321 | def detect_objects_batch(self, images, queries):
322 | """
323 | Detect objects in a batch of images
324 |
325 | Args:
326 | images: List of input images
327 | queries: List of text queries
328 |
329 | Returns:
330 | list: List of detection results
331 | """
332 | results = []
333 | for image, query in zip(images, queries):
334 | result = self.detect_objects(image, query)
335 | results.append(result)
336 | return results
337 |
338 |
339 | def generate_masks(self, image, bboxes, points):
340 | """
341 | Generate segmentation masks for given image, bounding boxes and points
342 |
343 | Args:
344 | image (PIL.Image): Input image
345 | bboxes (list): List of bounding boxes
346 | points (list): List of points
347 |
348 | Returns:
349 | numpy.ndarray: Combined segmentation mask
350 | """
351 | img_height, img_width = image.height, image.width
352 | mask_all = np.zeros((img_height, img_width), dtype=bool)
353 |
354 | if not bboxes:
355 | return mask_all
356 |
357 | try:
358 | self.segmentation_model.set_image(image)
359 | if not points:
360 | points = []
361 | if len(points) != len(bboxes):
362 | points.extend([None] * (len(bboxes) - len(points)))
363 |
364 | for bbox, point in zip(bboxes, points):
365 | masks, scores, _ = self.segmentation_model.predict(
366 | box=bbox
367 | )
368 | sorted_ind = np.argsort(scores)[::-1]
369 | masks = masks[sorted_ind]
370 | mask = masks[0].astype(bool)
371 | mask_all = np.logical_or(mask_all, mask)
372 |
373 | return mask_all
374 | except Exception as e:
375 | print(f"Error generating masks: {e}")
376 | return mask_all
377 |
378 | def segment_objects(self, image, query):
379 | """
380 | Segment objects in an image based on a query
381 |
382 | Args:
383 | image: Input image
384 | query: Text query describing what to segment
385 |
386 | Returns:
387 | dict: Results with masks and bounding boxes
388 | """
389 | try:
390 | prompt = self.DETECTION_PROMPT.format(query=query)
391 | response, scale_factors = self.generate_response(image, prompt)
392 | img_height, img_width = image.size[1], image.size[0]
393 | x_factor, y_factor = (1, 1) if 'Qwen2.5' in self.model_path \
394 | else (img_width / 1000, img_height / 1000)
395 |
396 | json_data = self.extract_json_from_response(response)
397 |
398 | bboxes = []
399 | points = []
400 |
401 | if json_data and isinstance(json_data, list):
402 | for item in json_data:
403 | for key in ['bbox', 'bbox_2d']:
404 | if key in item and len(item[key]) == 4:
405 | bbox = [
406 | int(item[key][0] * x_factor),
407 | int(item[key][1] * y_factor),
408 | int(item[key][2] * x_factor),
409 | int(item[key][3] * y_factor)
410 | ]
411 | bboxes.append(bbox)
412 |
413 | for key in ['point', 'point_2d']:
414 | if key in item and len(item[key]) == 2:
415 | point = [
416 | int(item[key][0] * x_factor),
417 | int(item[key][1] * y_factor)
418 | ]
419 | points.append(point)
420 |
421 | masks = self.generate_masks(image, bboxes, points)
422 |
423 | return {
424 | "masks": masks,
425 | "bboxes": bboxes,
426 | "points": points,
427 | "thinking": "",
428 | "full_response": response,
429 | "json_data": json_data
430 | }
431 | except Exception as e:
432 |             # raise  # keep the fallback result below reachable
433 | print(f"Error in segmentation: {e}")
434 | return {
435 | "masks": np.zeros((image.height, image.width), dtype=bool),
436 | "bboxes": [],
437 | "points": [],
438 | "thinking": "",
439 | "full_response": "",
440 | "json_data": None
441 | }
442 |
443 | def segment_objects_batch(self, images, queries):
444 | """
445 | Segment objects in a batch of images
446 |
447 | Args:
448 | images: List of input images
449 | queries: List of text queries
450 |
451 | Returns:
452 | list: List of segmentation results
453 | """
454 | results = []
455 | for image, query in zip(images, queries):
456 | result = self.segment_objects(image, query)
457 | results.append(result)
458 | return results
459 |
460 | def count_objects(self, image, query):
461 | """
462 | Count objects in an image based on a query
463 |
464 | Args:
465 | image: Input image
466 | query: Text query describing what to count
467 |
468 | Returns:
469 | dict: Results with count and bounding boxes
470 | """
471 | try:
472 | prompt = self.COUNTING_PROMPT.format(query=query)
473 | response, scale_factors = self.generate_response(image, prompt)
474 |
475 | # Get original image dimensions
476 | img_height, img_width = image.size[1], image.size[0]
477 | x_factor, y_factor = (1, 1) if 'Qwen2.5' in self.model_path \
478 | else (img_width / 1000, img_height / 1000)
479 |
480 | json_data = self.extract_json_from_response(response)
481 |
482 | bboxes = []
483 | points = []
484 |
485 | if json_data and isinstance(json_data, list):
486 | for item in json_data:
487 | for key in ['bbox', 'bbox_2d']:
488 | if key in item:
489 | bbox = [
490 | int(item[key][0] * x_factor),
491 | int(item[key][1] * y_factor),
492 | int(item[key][2] * x_factor),
493 | int(item[key][3] * y_factor)
494 | ]
495 | bboxes.append(bbox)
496 |
497 | for key in ['point', 'point_2d']:
498 | if key in item:
499 | point = [
500 | int(item[key][0] * x_factor),
501 | int(item[key][1] * y_factor)
502 | ]
503 | points.append(point)
504 |
505 | return {
506 | "count": len(bboxes),
507 | "bboxes": bboxes,
508 | "thinking": "",
509 | "points": points,
510 | "full_response": response,
511 | "json_data": json_data
512 | }
513 |
514 | # If JSON extraction fails, extract count from text using regex
515 | # Match digits, number words and Roman numerals
516 | count = 0
517 |
518 | # Match cardinal numbers (e.g. "5", "five", "10")
519 | number_words = {
520 | 'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
521 | 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10, 'eleven': 11,
522 | 'twelve': 12, 'thirteen': 13, 'fourteen': 14, 'fifteen': 15,
523 | 'sixteen': 16, 'seventeen': 17, 'eighteen': 18, 'nineteen': 19, 'twenty': 20
524 | }
525 |
526 | # Look for phrases like "there are X" or "I count X" or "X objects"
527 | count_patterns = [
528 | r'there (?:are|is) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)',
529 | r'i (?:can |)(?:see|count|find) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)',
530 | r'(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty) (?:objects|items)',
531 | r'count(?:ing)? (?:is|of) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)',
532 | r'total (?:of|is) (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)',
533 | r'(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty) in total',
534 | r'found (\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty)',
535 | # Roman numerals
536 | r'there (?:are|is) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)',
537 | r'count(?:ing)? (?:is|of) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)',
538 | r'total (?:of|is) (I|II|III|IV|V|VI|VII|VIII|IX|X|XI|XII|XIII|XIV|XV|XVI|XVII|XVIII|XIX|XX)',
539 | ]
540 |
541 | roman_to_int = {
542 | 'I': 1, 'II': 2, 'III': 3, 'IV': 4, 'V': 5,
543 | 'VI': 6, 'VII': 7, 'VIII': 8, 'IX': 9, 'X': 10,
544 | 'XI': 11, 'XII': 12, 'XIII': 13, 'XIV': 14, 'XV': 15,
545 | 'XVI': 16, 'XVII': 17, 'XVIII': 18, 'XIX': 19, 'XX': 20
546 | }
547 |
548 | import re
549 | response_lower = response.lower()
550 |
551 | for pattern in count_patterns:
552 | match = re.search(pattern, response_lower)
553 | if match:
554 | match_text = match.group(1)
555 | if match_text.isdigit():
556 | count = int(match_text)
557 | break
558 | elif match_text in number_words:
559 | count = number_words[match_text]
560 | break
561 | elif match_text.upper() in roman_to_int:
562 | count = roman_to_int[match_text.upper()]
563 | break
564 |
565 | return {
566 | "count": count,
567 | "bboxes": bboxes,
568 | "thinking": "Extracted count from text response",
569 | "points": points,
570 | "full_response": response,
571 | "json_data": json_data
572 | }
573 |
574 | except Exception as e:
575 | print(f"Error in counting: {e}")
576 | return {
577 | "count": 0,
578 | "bboxes": [],
579 | "thinking": "",
580 | "points": [],
581 | "full_response": "",
582 | "json_data": None
583 | }
584 |
585 |
586 | def count_objects_batch(self, images, queries):
587 | """
588 | Count objects in a batch of images
589 |
590 | Args:
591 | images: List of input images
592 | queries: List of text queries
593 |
594 | Returns:
595 | list: List of counting results
596 | """
597 | results = []
598 | for image, query in zip(images, queries):
599 | result = self.count_objects(image, query)
600 | results.append(result)
601 | return results
602 |
603 | def answer_question(self, image, question):
604 | """
605 | Answer a question about an image
606 |
607 | Args:
608 | image: Input image
609 | question: Text question
610 |
611 | Returns:
612 | dict: Results with answer
613 | """
614 | try:
615 | prompt = self.QA_PROMPT.format(query=question)
616 | response, _ = self.generate_response(image, prompt)
617 |
618 | answer = response
619 |
620 | return {
621 | "answer": answer,
622 | "thinking": "",
623 | "full_response": response,
624 | "json_data": None
625 | }
626 | except Exception as e:
627 | print(f"Error in QA: {e}")
628 | return {
629 | "answer": "",
630 | "thinking": "",
631 | "full_response": "",
632 | "json_data": None
633 | }
634 |
635 | def answer_questions_batch(self, images, questions):
636 | """
637 | Answer questions about a batch of images
638 |
639 | Args:
640 | images: List of input images
641 | questions: List of text questions
642 |
643 | Returns:
644 | list: List of QA results
645 | """
646 | results = []
647 | for image, question in zip(images, questions):
648 | result = self.answer_question(image, question)
649 | results.append(result)
650 | return results
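651 | 
652 | # Minimal usage sketch (model and image paths are illustrative):
653 | #   from PIL import Image
654 | #   model = QwenVLModel(model_path="Qwen/Qwen2.5-VL-7B-Instruct")
655 | #   result = model.count_objects(Image.open("assets/donuts.png").convert("RGB"),
656 | #                                "how many donuts are there")
657 | #   print(result["count"], result["bboxes"])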
--------------------------------------------------------------------------------
/vision_reasoner/models/task_router.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | import re
4 |
5 | class TaskRouter:
6 | def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct"):
7 | """Initialize task router"""
8 | self.model = AutoModelForCausalLM.from_pretrained(
9 | model_name,
10 | torch_dtype=torch.bfloat16,
11 | attn_implementation="flash_attention_2",
12 | device_map="auto"
13 | )
14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
15 |         self.prompt_template = "Given a user instruction, please classify which type of task it is and output the final answer after \"####\". The types are: 1) Segmentation/detection, 2) Counting, 3) Editing, 4) Caption/QA. The user instruction is: "
16 |
17 | def route_task(self, instruction):
18 | """Route input instruction to corresponding task category
19 |
20 | Args:
21 | instruction: User input instruction
22 |
23 | Returns:
24 | dict: Dictionary containing predicted category and confidence
25 | """
26 | # Get model response
27 | response = self._get_model_response(instruction)
28 |
29 | # Extract category
30 | predicted_category = self._extract_category(response)
31 |
32 | return predicted_category
33 |
34 | def route_batch(self, instructions):
35 | """Batch route tasks
36 |
37 | Args:
38 | instructions: List of instructions
39 |
40 | Returns:
41 | list: List of result dictionaries
42 | """
43 | # Get batch responses
44 | responses = self._get_model_responses(instructions)
45 |
46 | results = []
47 | for instruction, response in zip(instructions, responses):
48 | category = self._extract_category(response)
49 | results.append(category)
50 |
51 | return results
52 |
53 | def _get_model_response(self, instruction):
54 | """Get model response for a single instruction"""
55 | return self._get_model_responses([instruction])[0]
56 |
57 | def _get_model_responses(self, instructions):
58 | """Get batch model responses
59 |
60 | Args:
61 | instructions: List of instructions
62 |
63 | Returns:
64 | list: List of responses
65 | """
66 | # Build batch messages
67 | message_batch = [
68 | [
69 | {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
70 | {"role": "user", "content": self.prompt_template + instruction}
71 | ]
72 | for instruction in instructions
73 | ]
74 |
75 | # Process batch template
76 | text_batch = self.tokenizer.apply_chat_template(
77 | message_batch,
78 | tokenize=False,
79 | add_generation_prompt=True
80 | )
81 |
82 | # Batch tokenize
83 | model_inputs = self.tokenizer(
84 | text_batch,
85 | return_tensors="pt",
86 | padding=True
87 | ).to(self.model.device)
88 |
89 | with torch.no_grad():
90 | generated_ids = self.model.generate(
91 | **model_inputs,
92 | max_new_tokens=512
93 | )
94 | # Only keep newly generated tokens
95 | generated_ids = generated_ids[:, model_inputs.input_ids.shape[1]:]
96 | responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
97 |
98 | return responses
99 |
100 | def _extract_category(self, response):
101 | """Extract predicted category from model response"""
102 | response = response.lower()
103 |
104 | # 1. Find all number matches
105 | number_pattern = r'(?:type|category|task)?\s*(?:is|:)?\s*([1-4])'
106 | number_matches = re.findall(number_pattern, response)
107 | if number_matches:
108 | number = int(number_matches[-1]) # Take the last matched number
109 | if number == 1:
110 | return 'segmentation'
111 | elif number == 2:
112 | return 'counting'
113 | elif number == 3:
114 | return 'editing'
115 | elif number == 4:
116 | return 'vqa'
117 |
118 | # 2. Keyword matching - find all matching keywords
119 | matches = []
120 | if any(word in response for word in ['segmentation', 'detection', 'segment', 'detect', 'grounding']):
121 | matches.append('segmentation')
122 | if any(word in response for word in ['counting', 'count', 'number']):
123 | matches.append('counting')
124 | if any(word in response for word in ['editing', 'edit', 'modify']):
125 | matches.append('editing')
126 | if any(word in response for word in ['caption', 'qa', 'question', 'answer', 'describe']):
127 | matches.append('vqa')
128 |
129 | # Return the last matched category
130 | return matches[-1] if matches else "vqa"
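131 | 
132 | # Minimal usage sketch; the checkpoint name matches the default used elsewhere in this repo,
133 | # and the expected categories only illustrate typical routing behaviour:
134 | #   router = TaskRouter("Ricky06662/TaskRouter-1.5B")
135 | #   router.route_task("how many people are in the crowd")           # e.g. "counting"
136 | #   router.route_batch(["segment the dog", "describe this image"])  # e.g. ["segmentation", "vqa"]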
--------------------------------------------------------------------------------
/vision_reasoner/models/vision_reasoner_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 | import json
5 | import os
6 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
7 | from sam2.sam2_image_predictor import SAM2ImagePredictor
8 | from PIL import Image as PILImage
9 | from ultralytics import YOLOWorld
10 | from openai import OpenAI
11 | from io import BytesIO
12 | import base64
13 |
14 | from .base_model import (
15 | BaseVisionModel,
16 | DetectionModel,
17 | SegmentationModel,
18 | CountingModel,
19 | QAModel
20 | )
21 | from qwen_vl_utils import process_vision_info
22 | from .task_router import TaskRouter
23 |
24 | STOP_WORDS = {"is", "are", "find", "the", "segment", "all", "in", "image",
25 | "how", "many", "there", "locate", "please"}
26 | MAX_QUERY_WORDS = 2
27 |
28 |
29 | class VisionReasonerModel(BaseVisionModel, DetectionModel, SegmentationModel, CountingModel, QAModel):
30 | """
31 | VisionReasoner model implementing all task interfaces
32 | """
33 | def __init__(self,
34 | reasoning_model_path="Ricky06662/VisionReasoner-7B",
35 | segmentation_model_path="facebook/sam2-hiera-large",
36 | task_router_model_path="Ricky06662/TaskRouter-1.5B",
37 | yolo_model_path=None,
38 | generation_model_path=None):
39 | """
40 | Initialize the VisionReasoner model with reasoning and segmentation components
41 |
42 | Args:
43 | reasoning_model_path (str): Path to the reasoning model
44 | segmentation_model_path (str): Path to the segmentation model
45 | """
46 | self.resize_size = 840
47 |
48 | # Initialize reasoning model
49 | self.reasoning_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
50 | reasoning_model_path,
51 | torch_dtype=torch.bfloat16,
52 | attn_implementation="flash_attention_2",
53 | device_map="auto",
54 | )
55 | self.reasoning_model.eval()
56 |
57 | # Initialize processor
58 | self.processor = AutoProcessor.from_pretrained(reasoning_model_path, padding_side="left")
59 |
60 | # Initialize segmentation model
61 | self.segmentation_model = SAM2ImagePredictor.from_pretrained(segmentation_model_path)
62 |
63 | self.task_router = TaskRouter(task_router_model_path)
64 |
65 | # Template for detection/segmentation tasks
66 | self.DETECTION_TEMPLATE = \
67 | "Please find \"{Question}\" with bboxs and points." \
68 | "Compare the difference between object(s) and find the most closely matched object(s)." \
69 |             "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags." \
70 | "Output the bbox(es) and point(s) inside the interested object(s) in JSON format." \
71 |             "i.e., <think> thinking process here </think>" \
72 |             "<answer>{Answer}</answer>"
73 |
74 | # Template for QA tasks
75 | self.QA_TEMPLATE = "{Question}"
76 |
77 | # Initialize YOLO model
78 | self.use_hybrid_mode = False
79 | if yolo_model_path:
80 | self.use_hybrid_mode = True
81 | self.yolo_model = YOLOWorld(yolo_model_path)
82 |
83 | # Initialize generation model
84 | if generation_model_path:
85 | self.generation_model = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", generation_model_path))
86 |
87 |
88 | def extract_bbox_points_think(self, output_text, x_factor, y_factor):
89 | """
90 | Extract bounding boxes, points, and thinking process from model output
91 |
92 | Args:
93 | output_text (str): Raw output text from the model
94 | x_factor (float): Scaling factor for x coordinates
95 | y_factor (float): Scaling factor for y coordinates
96 |
97 | Returns:
98 | tuple: (pred_bboxes, pred_points, think_text, pred_answer)
99 | """
100 |         json_match = re.search(r'<answer>\s*(.*?)\s*</answer>', output_text, re.DOTALL)
101 | pred_bboxes = []
102 | pred_points = []
103 | pred_answer = None
104 | think_text = ""
105 |
106 | if json_match:
107 | try:
108 | data = json.loads(json_match.group(1))
109 | pred_answer = data
110 | pred_bboxes = [[
111 | int(item['bbox_2d'][0] * x_factor + 0.5),
112 | int(item['bbox_2d'][1] * y_factor + 0.5),
113 | int(item['bbox_2d'][2] * x_factor + 0.5),
114 | int(item['bbox_2d'][3] * y_factor + 0.5)
115 | ] for item in data]
116 | pred_points = [[
117 | int(item['point_2d'][0] * x_factor + 0.5),
118 | int(item['point_2d'][1] * y_factor + 0.5)
119 | ] for item in data]
120 | except Exception as e:
121 | print(f"Error parsing JSON: {e}")
122 |
123 |         think_pattern = r'<think>([^<]+)</think>'
124 | think_match = re.search(think_pattern, output_text)
125 | if think_match:
126 | think_text = think_match.group(1)
127 |
128 | return pred_bboxes, pred_points, think_text, pred_answer
129 |
130 | def extract_qa_answer(self, output_text):
131 | """
132 | Extract answer for QA tasks
133 |
134 | Args:
135 | output_text (str): Raw output text from the model
136 |
137 | Returns:
138 | dict: Result dictionary with answer and thinking (if available)
139 | """
140 |         think_pattern = r'<think>([^<]+)</think>'
141 | think_match = re.search(think_pattern, output_text)
142 | thinking = think_match.group(1) if think_match else ""
143 |
144 | # Remove thinking tags from output to get cleaner answer
145 |         clean_answer = re.sub(r'<think>.*?</think>', '', output_text, flags=re.DOTALL).strip()
146 |
147 | return {
148 | "answer": clean_answer,
149 | "thinking": thinking,
150 | "full_response": output_text
151 | }
152 |
153 | def generate_masks(self, image, bboxes, points=None):
154 | """
155 | Generate segmentation masks for given image, bounding boxes and points
156 |
157 | Args:
158 | image (PIL.Image): Input image
159 | bboxes (list): List of bounding boxes
160 | points (list): List of points
161 |
162 | Returns:
163 | numpy.ndarray: Combined segmentation mask
164 | """
165 | img_height, img_width = image.height, image.width
166 | mask_all = np.zeros((img_height, img_width), dtype=bool)
167 |
168 | if not bboxes:
169 | return mask_all
170 |
171 | if points and len(points) != len(bboxes):
172 | return mask_all
173 |
174 | try:
175 | self.segmentation_model.set_image(image)
176 | if points:
177 | for bbox, point in zip(bboxes, points):
178 | masks, scores, _ = self.segmentation_model.predict(
179 | point_coords=[point],
180 | point_labels=[1],
181 | box=bbox
182 | )
183 | sorted_ind = np.argsort(scores)[::-1]
184 | masks = masks[sorted_ind]
185 | mask = masks[0].astype(bool)
186 | mask_all = np.logical_or(mask_all, mask)
187 | else:
188 | for bbox in bboxes:
189 | masks, scores, _ = self.segmentation_model.predict(
190 | box=bbox
191 | )
192 | sorted_ind = np.argsort(scores)[::-1]
193 |                     masks = masks[sorted_ind]
194 |                     mask_all = np.logical_or(mask_all, masks[0].astype(bool))
195 |             return mask_all
196 | except Exception as e:
197 | print(f"Error generating masks: {e}")
198 | return mask_all
199 |
200 | def _generate_model_output(self, images, instructions, template, batch_mode=False):
201 | """
202 | Generate raw model output for images and instructions
203 |
204 | Args:
205 | images (PIL.Image or List[PIL.Image]): Input image(s)
206 | instructions (str or List[str]): Text instruction(s)/query(ies)
207 | template (str): Template to use for the prompt
208 | batch_mode (bool): Whether to process in batch mode
209 |
210 | Returns:
211 | tuple: (output_texts, scale_factors)
212 | """
213 | if not batch_mode:
214 | images = [images]
215 | instructions = [instructions]
216 |
217 | batch_messages = []
218 | scale_factors = []
219 |
220 | for image, instruction in zip(images, instructions):
221 | # Prepare image
222 | original_width, original_height = image.size
223 | x_factor, y_factor = original_width/self.resize_size, original_height/self.resize_size
224 | scale_factors.append((x_factor, y_factor))
225 | resized_image = image.resize((self.resize_size, self.resize_size), PILImage.BILINEAR)
226 |
227 | # Format text based on template
228 | if "{Question}" in template:
229 | formatted_text = template.format(
230 | Question=instruction.lower().strip(".\"?!"),
231 | Answer="[{\"bbox_2d\": [10,100,200,210], \"point_2d\": [30,110]}, {\"bbox_2d\": [225,296,706,786], \"point_2d\": [302,410]}]"
232 | )
233 | else:
234 | formatted_text = template
235 |
236 | # Create message
237 | message = [{
238 | "role": "user",
239 | "content": [
240 | {
241 | "type": "image",
242 | "image": resized_image
243 | },
244 | {
245 | "type": "text",
246 | "text": formatted_text
247 | }
248 | ]
249 | }]
250 | batch_messages.append(message)
251 |
252 | # Prepare for batch inference
253 | texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
254 |
255 | image_inputs, video_inputs = process_vision_info(batch_messages)
256 |
257 | inputs = self.processor(
258 | text=texts,
259 | images=image_inputs,
260 | padding=True,
261 | return_tensors="pt",
262 | )
263 | inputs = inputs.to("cuda")
264 |
265 | # Generate output
266 | generated_ids = self.reasoning_model.generate(**inputs, use_cache=True, max_new_tokens=2048, do_sample=False)
267 |
268 | generated_ids_trimmed = [
269 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
270 | ]
271 |
272 | output_texts = self.processor.batch_decode(
273 | generated_ids_trimmed,
274 | skip_special_tokens=True,
275 | clean_up_tokenization_spaces=False
276 | )
277 | if not batch_mode:
278 | return output_texts[0], scale_factors[0]
279 | return output_texts, scale_factors
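    # Worked example of the scale factors returned above (hypothetical numbers, assuming
    # resize_size=840): a 1680x420 input gives x_factor=2.0 and y_factor=0.5, so a bbox of
    # [100, 200, 300, 400] predicted in resized-image coordinates maps back to
    # [200, 100, 600, 200] in the original image via extract_bbox_points_think.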
280 |
281 | # BaseVisionModel implementation
282 | def process_single_image(self, image, instruction, return_task_type=False):
283 | """
284 | Process a single image with given instruction
285 |
286 | Args:
287 | image (PIL.Image): Input image
288 |             instruction (str): Text instruction or query
289 |             return_task_type (bool): If True, also return the routed task type
290 |         Returns:
291 |             dict: Results dictionary (returned together with the task type when return_task_type is True)
292 | """
293 | # Determine task type based on instruction
294 | task_type = self.task_router.route_task(instruction)
295 |
296 | if task_type == "segmentation":
297 | result = self.segment_objects(image, instruction)
298 | elif task_type == "detection":
299 | result = self.detect_objects(image, instruction)
300 | elif task_type == "counting":
301 | result = self.count_objects(image, instruction)
302 | else: # Default to VQA
303 | result = self.answer_question(image, instruction)
304 |
305 | if return_task_type:
306 | return result, task_type
307 | else:
308 | return result
309 |
310 | def process_batch(self, batch_images, batch_instructions):
311 | """
312 | Process a batch of images with given instructions
313 |
314 | Args:
315 | batch_images (list): List of PIL Images
316 | batch_instructions (list): List of text instructions or queries
317 |
318 | Returns:
319 | list: List of result dictionaries
320 | """
321 | results = []
322 | for image, instruction in zip(batch_images, batch_instructions):
323 | result = self.process_single_image(image, instruction)
324 | results.append(result)
325 | return results
326 |
327 | def detect_objects_yolo(self, image, query):
328 | """
329 | Detect objects in an image based on a query using YOLO model
330 |
331 | Args:
332 | image: Input image
333 | query: Text query describing what to detect
334 |
335 | Returns:
336 |             list: Bounding boxes scaled to the original image size
337 | """
338 |         # Prepare the query as a class prompt for the open-vocabulary YOLO model
339 | query = " ".join([word.strip(".\"?!'") for word in query.lower().strip(".\"?!").split() if word not in STOP_WORDS])
340 | names = [query]
341 | self.yolo_model.set_classes(names)
342 |
343 | # Run detection on the given image
344 | results = self.yolo_model.predict(image)
345 |
346 | # Get original image dimensions
347 | img_height, img_width = image.height, image.width
348 |
349 |         # Shape of the image as seen by YOLO (detections are returned in this coordinate frame)
350 | yolo_input_size = results[0].orig_shape
351 |
352 | # Calculate scaling factors
353 | x_scale = img_width / yolo_input_size[1]
354 | y_scale = img_height / yolo_input_size[0]
355 |
356 | # Scale the bounding boxes back to original image size
357 | bboxes = results[0].boxes.xyxy
358 | scaled_bboxes = []
359 | for bbox in bboxes:
360 | scaled_bbox = [
361 | int(bbox[0] * x_scale),
362 | int(bbox[1] * y_scale),
363 | int(bbox[2] * x_scale),
364 | int(bbox[3] * y_scale)
365 | ]
366 | scaled_bboxes.append(scaled_bbox)
367 |
368 | return scaled_bboxes
369 |
370 | def if_yolo_condition(self, query):
371 | """
372 | Check if YOLO should be used for the given query
373 |
374 | Args:
375 | query (str): Text query describing what to detect
376 |
377 | Returns:
378 | bool: True if YOLO should be used, False otherwise
379 | """
380 |
381 |         # Simple heuristic: use YOLO only for short, concrete queries (after stop-word removal)
382 | query_words = [word for word in query.lower().strip(".\"?!").split() if word not in STOP_WORDS]
383 | return len(query_words) <= MAX_QUERY_WORDS
384 |
385 | # DetectionModel implementation
386 | def detect_objects(self, image, query):
387 | """
388 | Detect objects in an image based on a query
389 |
390 | Args:
391 | image: Input image
392 | query: Text query describing what to detect
393 |
394 | Returns:
395 | dict: Results with bounding boxes and scores
396 | """
397 | try:
398 | if self.use_hybrid_mode and self.if_yolo_condition(query):
399 | bboxes = self.detect_objects_yolo(image, query)
400 | scores = [1.0] * len(bboxes)
401 | # use middle point of bbox as point
402 | points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
403 | output_text, thinking, pred_answer = "", "", str(bboxes)
404 | else:
405 | output_text, (x_factor, y_factor) = self._generate_model_output(
406 | image,
407 | query,
408 | self.DETECTION_TEMPLATE
409 | )
410 |
411 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
412 | output_text,
413 | x_factor,
414 | y_factor
415 | )
416 |
417 | # Assign confidence scores (all 1.0 as the model doesn't provide them)
418 | scores = [1.0] * len(bboxes)
419 |
420 | return {
421 | "bboxes": bboxes,
422 | "points": points,
423 | "scores": scores,
424 | "thinking": thinking,
425 | "full_response": output_text,
426 | "pred_answer": pred_answer
427 | }
428 | except Exception as e:
429 | print(f"Error in detection: {e}")
430 | return {
431 | "bboxes": [],
432 | "points": [],
433 | "scores": [],
434 | "thinking": "",
435 | "full_response": "",
436 | "pred_answer": None
437 | }
438 |
439 | def detect_objects_batch(self, images, queries):
440 | """
441 | Detect objects in a batch of images
442 |
443 | Args:
444 | images: List of input images
445 | queries: List of text queries
446 |
447 | Returns:
448 | list: List of detection results
449 | """
450 | try:
451 | # TODO: support yolo for batch
452 |
453 | output_texts, scale_factors = self._generate_model_output(
454 | images,
455 | queries,
456 | self.DETECTION_TEMPLATE,
457 | batch_mode=True
458 | )
459 |
460 | results = []
461 | for output_text, (x_factor, y_factor) in zip(output_texts, scale_factors):
462 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
463 | output_text,
464 | x_factor,
465 | y_factor
466 | )
467 |
468 | scores = [1.0] * len(bboxes)
469 | results.append({
470 | "bboxes": bboxes,
471 | "points": points,
472 | "scores": scores,
473 | "thinking": thinking,
474 | "full_response": output_text,
475 | "pred_answer": pred_answer
476 | })
477 | return results
478 | except Exception as e:
479 | print(f"Error in batch detection: {e}")
480 | return [{
481 | "bboxes": [],
482 | "points": [],
483 | "scores": [],
484 | "thinking": "",
485 | "full_response": "",
486 | "pred_answer": None
487 | } for _ in range(len(images))]
488 |
489 | # SegmentationModel implementation
490 | def segment_objects(self, image, query):
491 | """
492 | Segment objects in an image based on a query
493 |
494 | Args:
495 | image: Input image
496 | query: Text query describing what to segment
497 |
498 | Returns:
499 | dict: Results with masks and bounding boxes
500 | """
501 | try:
502 | if self.use_hybrid_mode and self.if_yolo_condition(query):
503 | #bboxes, masks = self.segment_objects_yolo(image, query)
504 | bboxes = self.detect_objects_yolo(image, query)
505 | # use middle point of bbox as point
506 | points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
507 | output_text, thinking, pred_answer = "", "", str(bboxes)
508 | else:
509 | output_text, (x_factor, y_factor) = self._generate_model_output(
510 | image,
511 | query,
512 | self.DETECTION_TEMPLATE
513 | )
514 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
515 | output_text,
516 | x_factor,
517 | y_factor
518 | )
519 | masks = self.generate_masks(image, bboxes, points)
520 |
521 | return {
522 | "masks": masks,
523 | "bboxes": bboxes,
524 | "points": points,
525 | "thinking": thinking,
526 | "full_response": output_text,
527 | "pred_answer": pred_answer
528 | }
529 | except Exception as e:
530 |             print(f"Error in segmentation: {e}")
532 | img_height, img_width = image.height, image.width
533 | return {
534 | "masks": np.zeros((img_height, img_width), dtype=bool),
535 | "bboxes": [],
536 | "points": [],
537 | "thinking": "",
538 | "full_response": "",
539 | "pred_answer": None
540 | }
541 |
542 | def segment_objects_batch(self, images, queries):
543 | """
544 | Segment objects in a batch of images
545 |
546 | Args:
547 | images: List of input images
548 | queries: List of text queries
549 |
550 | Returns:
551 | list: List of segmentation results
552 | """
553 | try:
554 | # TODO: support yolo for batch
555 | output_texts, scale_factors = self._generate_model_output(
556 | images,
557 | queries,
558 | self.DETECTION_TEMPLATE,
559 | batch_mode=True
560 | )
561 |
562 | results = []
563 | for image, output_text, (x_factor, y_factor) in zip(images, output_texts, scale_factors):
564 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
565 | output_text,
566 | x_factor,
567 | y_factor
568 | )
569 |
570 | masks = self.generate_masks(image, bboxes, points)
571 | results.append({
572 | "masks": masks,
573 | "bboxes": bboxes,
574 | "points": points,
575 | "thinking": thinking,
576 | "full_response": output_text,
577 | "pred_answer": pred_answer
578 | })
579 | return results
580 | except Exception as e:
581 | print(f"Error in batch segmentation: {e}")
582 | return [{
583 | "masks": np.zeros((img.height, img.width), dtype=bool),
584 | "bboxes": [],
585 | "points": [],
586 | "thinking": "",
587 | "full_response": "",
588 | "pred_answer": None
589 | } for img in images]
590 |
591 | # CountingModel implementation
592 | def count_objects(self, image, query):
593 | """
594 | Count objects in an image based on a query
595 |
596 | Args:
597 | image: Input image
598 | query: Text query describing what to count
599 |
600 | Returns:
601 | dict: Results with count and bounding boxes
602 | """
603 | try:
604 | if self.use_hybrid_mode and self.if_yolo_condition(query):
605 | bboxes = self.detect_objects_yolo(image, query)
606 | # use middle point of bbox as point
607 | points = [[int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)] for bbox in bboxes]
608 | output_text, thinking, pred_answer = "", "", str(bboxes)
609 | else:
610 | output_text, (x_factor, y_factor) = self._generate_model_output(
611 | image,
612 | query,
613 | self.DETECTION_TEMPLATE
614 | )
615 |
616 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
617 | output_text,
618 | x_factor,
619 | y_factor
620 | )
621 |
622 | count = len(bboxes)
623 |
624 | return {
625 | "count": count,
626 | "bboxes": bboxes,
627 | "points": points,
628 | "thinking": thinking,
629 | "full_response": output_text,
630 | "pred_answer": pred_answer
631 | }
632 | except Exception as e:
633 | print(f"Error in counting: {e}")
634 | return {
635 | "count": 0,
636 | "bboxes": [],
637 | "points": [],
638 | "thinking": "",
639 | "full_response": "",
640 | "pred_answer": None
641 | }
642 |
643 | def count_objects_batch(self, images, queries):
644 | """
645 | Count objects in a batch of images
646 |
647 | Args:
648 | images: List of input images
649 | queries: List of text queries
650 |
651 | Returns:
652 | list: List of counting results
653 | """
654 | try:
655 | # TODO: support yolo for batch
656 | output_texts, scale_factors = self._generate_model_output(
657 | images,
658 | queries,
659 | self.DETECTION_TEMPLATE,
660 | batch_mode=True
661 | )
662 |
663 | results = []
664 | for output_text, (x_factor, y_factor) in zip(output_texts, scale_factors):
665 | bboxes, points, thinking, pred_answer = self.extract_bbox_points_think(
666 | output_text,
667 | x_factor,
668 | y_factor
669 | )
670 |
671 | count = len(bboxes)
672 | results.append({
673 | "count": count,
674 | "bboxes": bboxes,
675 | "points": points,
676 | "thinking": thinking,
677 | "full_response": output_text,
678 | "pred_answer": pred_answer
679 | })
680 | return results
681 | except Exception as e:
682 | print(f"Error in batch counting: {e}")
683 | return [{
684 | "count": 0,
685 | "bboxes": [],
686 | "points": [],
687 | "thinking": "",
688 | "full_response": "",
689 | "pred_answer": None
690 | } for _ in range(len(images))]
691 |
692 | # QAModel implementation
693 | def answer_question(self, image, question):
694 | """
695 | Answer a question about an image
696 |
697 | Args:
698 | image: Input image
699 | question: Text question
700 |
701 | Returns:
702 | dict: Results with answer and thinking (if available)
703 | """
704 | try:
705 | output_text, _ = self._generate_model_output(
706 | image,
707 | question,
708 | self.QA_TEMPLATE
709 | )
710 |
711 | result = self.extract_qa_answer(output_text)
712 | return result
713 | except Exception as e:
714 | print(f"Error in QA: {e}")
715 | return {
716 | "answer": "",
717 | "thinking": "",
718 | "full_response": ""
719 | }
720 |
721 | def answer_questions_batch(self, images, questions):
722 | """
723 | Answer questions about a batch of images
724 |
725 | Args:
726 | images: List of input images
727 | questions: List of text questions
728 |
729 | Returns:
730 | list: List of QA results
731 | """
732 | try:
733 | output_texts, _ = self._generate_model_output(
734 | images,
735 | questions,
736 | self.QA_TEMPLATE,
737 | batch_mode=True
738 | )
739 |
740 | results = []
741 | for output_text in output_texts:
742 | result = self.extract_qa_answer(output_text)
743 | results.append(result)
744 | return results
745 | except Exception as e:
746 | print(f"Error in batch QA: {e}")
747 | return [{
748 | "answer": "",
749 | "thinking": "",
750 | "full_response": ""
751 | } for _ in range(len(images))]
752 |
753 | def generate_image(self, refer_image_path, image_prompt):
754 | """
755 | Generate an image based on a query
756 |
757 | Args:
758 | refer_image_path: Path to the reference image
759 | image_prompt: Text prompt describing what to generate
760 |
761 | Returns:
762 |             PIL.Image or None: The generated image, or None if generation failed
763 | """
764 | if self.generation_model is None or image_prompt is None:
765 |             raise ValueError("Generation model or image prompt is not available")
766 |
767 | try:
768 | if refer_image_path == "":
769 | # Generate the image
770 | output = self.generation_model.images.generate(
771 | model="gpt-image-1",
772 | prompt=image_prompt,
773 | )
774 | image_base64 = output.data[0].b64_json
775 |
776 | else:
777 | output = self.generation_model.images.edit(
778 | model="gpt-image-1",
779 | image=[open(refer_image_path, "rb")],
780 | prompt=image_prompt,
781 | )
782 | image_base64 = output.data[0].b64_json
783 |
784 | image = PILImage.open(BytesIO(base64.b64decode(image_base64)))
785 | return image
786 | except Exception as e:
787 | print(f"Error in image generation: {e}")
788 | return None
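# Example (hypothetical) end-to-end usage, assuming `model` is an instance of the class
# defined in this file and `img` is a PIL.Image:
#   result, task_type = model.process_single_image(
#       img, "how many donuts are on the plate?", return_task_type=True)
#   print(task_type, result.get("count", result.get("answer")))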
--------------------------------------------------------------------------------
/vision_reasoner/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | def visualize_results_enhanced(image, result, task_type, output_path):
5 | """
6 | Enhanced visualization with three-panel layout
7 | """
8 | # Create a figure with 3 subplots
9 | plt.figure(figsize=(15, 5))
10 |
11 | # First panel: Original image
12 | plt.subplot(1, 3, 1)
13 | plt.imshow(image)
14 | plt.title('Original Image')
15 | plt.axis('off')
16 |
17 | # Second panel: Image with bounding boxes
18 | plt.subplot(1, 3, 2)
19 | plt.imshow(image)
20 |
21 | if 'bboxes' in result and result['bboxes']:
22 | for bbox in result['bboxes']:
23 | x1, y1, x2, y2 = bbox
24 | rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
25 | fill=False, edgecolor='red', linewidth=2)
26 | plt.gca().add_patch(rect)
27 |
28 | if 'points' in result and result['points']:
29 | for point in result['points']:
30 | plt.plot(point[0], point[1], 'go', markersize=8) # Green point
31 |
32 | plt.title('Image with Bounding Boxes')
33 | plt.axis('off')
34 |
35 | # Third panel: Mask overlay (for segmentation tasks)
36 | plt.subplot(1, 3, 3)
37 | plt.imshow(image, alpha=0.6)
38 |
39 | if task_type == 'segmentation' and 'masks' in result and result['masks'] is not None:
40 | mask = result['masks']
41 | if np.any(mask):
42 | mask_overlay = np.zeros_like(np.array(image))
43 | mask_overlay[mask] = [255, 0, 0] # Red color for mask
44 | plt.imshow(mask_overlay, alpha=0.4)
45 |
46 | if task_type == 'detection' or task_type == 'counting':
47 | # For non-segmentation tasks, just show bounding boxes again
48 | if 'bboxes' in result and result['bboxes']:
49 | for bbox in result['bboxes']:
50 | x1, y1, x2, y2 = bbox
51 | rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
52 | fill=True, edgecolor='red', facecolor='red', alpha=0.3)
53 | plt.gca().add_patch(rect)
54 |
55 | task_title = {
56 | 'detection': 'Detection Overlay',
57 | 'segmentation': 'Segmentation Mask',
58 | 'counting': 'Counting Results',
59 | 'qa': 'Visual QA'
60 | }
61 |
62 | plt.title(task_title.get(task_type, 'Results Overlay'))
63 | plt.axis('off')
64 |
65 | plt.tight_layout()
66 | plt.savefig(output_path)
67 | plt.close()
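# Example (hypothetical) usage, assuming `image` is a PIL.Image and `result` comes from a
# detection call elsewhere in this repo:
#   visualize_results_enhanced(image,
#                              {"bboxes": [[10, 20, 110, 220]], "points": [[60, 120]]},
#                              "detection", "detection_vis.png")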
--------------------------------------------------------------------------------