├── .gitignore
├── LICENSE
├── README.md
├── deepperception
│   └── eval
│       ├── eval.sh
│       ├── evaluate.py
│       └── inference.py
├── figs
│   └── header.png
└── requirements.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
deepperception/data

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Maxy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepPerception: Advancing R1-like Cognitive Visual Perception in MLLMs for Knowledge-Intensive Visual Grounding

Xinyu Ma, Ziyang Ding, Zhicong Luo, Chi Chen, Zonghao Guo, Derek F. Wong, Xiaoyi Feng, Maosong Sun

-----

This is the official repository of **DeepPerception**, an MLLM enhanced with cognitive visual perception capabilities.

## Release

- [x] **`2025.03.18`** 🔥 Released the DeepPerception evaluation code and model on [`🤗HuggingFace`](https://huggingface.co/MaxyLee/DeepPerception).
- [x] **`2025.03.18`** 🔥 The DeepPerception paper has been released on [`📕arXiv`](https://arxiv.org/abs/2503.12797).

## Overview

Figure 1: (a) DeepPerception employs knowledge-driven reasoning to derive answers, while the baseline model directly outputs predictions without cognitive processing. (b) DeepPerception demonstrates superior cognitive visual perception capabilities that cannot be elicited in the foundation model through simplistic zero-shot CoT prompting.

#### Abstract

Human experts excel at fine-grained visual discrimination by leveraging domain knowledge to refine perceptual features, a capability that remains underdeveloped in current Multimodal Large Language Models (MLLMs). Despite possessing vast expert-level knowledge, MLLMs struggle to integrate reasoning into visual perception, often generating direct responses without deeper analysis.

To bridge this gap, we introduce knowledge-intensive visual grounding (KVG), a novel visual grounding task that requires both fine-grained perception and domain-specific knowledge integration. To address the challenges of KVG, we propose **DeepPerception**, an MLLM enhanced with cognitive visual perception capabilities. Our approach consists of (1) an automated data synthesis pipeline that generates high-quality, knowledge-aligned training samples, and (2) a two-stage training framework combining supervised fine-tuning for cognitive reasoning scaffolding and reinforcement learning to optimize perception-cognition synergy. To benchmark performance, we introduce KVG-Bench, a comprehensive dataset spanning 10 domains with 1.3K manually curated test cases.

Experimental results demonstrate that DeepPerception significantly outperforms direct fine-tuning, achieving +8.08% accuracy improvements on KVG-Bench and exhibiting +4.60% superior cross-domain generalization over baseline approaches. Our findings highlight the importance of integrating cognitive processes into MLLMs for human-like visual perception and open new directions for multimodal reasoning research.

#### Key Contributions

- We introduce the task of **Knowledge-intensive Visual Grounding (KVG)** to explore the concept of cognitive visual perception for MLLMs, aiming to integrate their inherent knowledge and reasoning capabilities into visual perception.
- We propose **[DeepPerception](https://huggingface.co/MaxyLee/DeepPerception)**, an MLLM with enhanced cognitive visual perception capabilities. To achieve this, we develop an automated dataset creation pipeline and a two-stage framework integrating supervised cognitive capability enhancement with perception-oriented reinforcement learning.
- We introduce **[KVG-Bench](https://huggingface.co/datasets/MaxyLee/KVG-Bench)**, a manually curated benchmark for the KVG task involving diverse knowledge domains and entities. Experiments on KVG-Bench and other fine-grained visual recognition tasks demonstrate DeepPerception's exceptional cognitive visual perception capabilities and superior cross-domain generalization performance.

## Get Started

### Contents:

- [Environment](#environment)
- [Data Preparation](#data-preparation)
- [Checkpoints](#checkpoints)
- [Evaluation](#evaluation)
- [Training](#training)

### Environment

1. Clone this repository and navigate to the DeepPerception folder
```bash
git clone https://github.com/MaxyLee/DeepPerception.git
cd DeepPerception
```
2. Install Packages

For evaluation:
```bash
conda create -n deepperception python=3.9
conda activate deepperception

pip install -r requirements.txt
```

### Data Preparation

| Dataset | Links |
|--------- |---------------------------------------|
| KVG-Bench | [`🤗HuggingFace`](https://huggingface.co/datasets/MaxyLee/KVG-Bench) |
| KVG Training | [`🤗HuggingFace`](https://huggingface.co/datasets/MaxyLee/KVG) |
---

### Checkpoints

| Model | Links |
|--------- |---------------------------------------|
| DeepPerception | [`🤗HuggingFace`](https://huggingface.co/MaxyLee/DeepPerception) |
| DeepPerception-FGVR | [`🤗HuggingFace`](https://huggingface.co/MaxyLee/DeepPerception-FGVR) |
---

### Evaluation

```bash
# Evaluate on KVG-Bench
bash eval.sh [CUDA_IDS] [KVG_BENCH_PATH] [CKPT_PATH]
```

Note: Please modify the script if you want to evaluate the Qwen2-VL baseline.

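For example, a run on four GPUs might look like this (the data and checkpoint paths below are illustrative placeholders):

```bash
bash eval.sh 0,1,2,3 /path/to/kvg-bench/test.parquet /path/to/DeepPerception
```

`eval.sh` selects the KVG-Bench branch whenever the data path contains `kvg-bench`, and `evaluate.py` expects the benchmark as a `.parquet` file; per-category results and `metrics.json` are written under `[CKPT_PATH]/kvg-bench-eval`.
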
### Training

TODO

## Citation

If you find DeepPerception useful for your research or applications, please cite using this BibTeX:

```bibtex
@misc{ma2025deepperception,
  title={DeepPerception: Advancing R1-like Cognitive Visual Perception in MLLMs for Knowledge-Intensive Visual Grounding},
  author={Xinyu Ma and Ziyang Ding and Zhicong Luo and Chi Chen and Zonghao Guo and Derek F. Wong and Xiaoyi Feng and Maosong Sun},
  year={2025},
  url={https://arxiv.org/abs/2503.12797},
}
```

## Acknowledgement

- [Qwen2-VL](https://github.com/QwenLM/Qwen2.5-VL)
- [vLLM](https://github.com/vllm-project/vllm)
- [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)
- [R1-V](https://github.com/Deep-Agent/R1-V)

## License

[![Code License](https://img.shields.io/badge/Code%20License-MIT-Green.svg)](https://github.com/twbs/bootstrap/blob/main/LICENSE)
[![Data License](https://img.shields.io/badge/Data%20License-Apache_2.0-Green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)

--------------------------------------------------------------------------------
/deepperception/eval/eval.sh:
--------------------------------------------------------------------------------
GPU_IDs=$1
DATA_PATH=$2
CKPT=$3


if [[ $DATA_PATH == *"kvg-bench"* ]]; then
    # KVG-Bench
    OUT_DIR=$CKPT/kvg-bench-eval

    # Evaluate DeepPerception
    # To ensure precise reproduction of the KVG-Bench results presented in the paper, please strictly adhere to the package versions specified in requirements.txt and DO NOT use the --vllm flag.

    python evaluate.py \
        --data_path $DATA_PATH \
        --ckpt_path $CKPT \
        --gpu_ids $GPU_IDs \
        --output_path $OUT_DIR \
        --prompt r1

    # Evaluate Qwen2-VL
    # DO NOT use --prompt r1, which requires the model to first output its thinking process

    # python evaluate.py \
    #     --data_path $DATA_PATH \
    #     --ckpt_path $CKPT \
    #     --gpu_ids $GPU_IDs \
    #     --output_path $OUT_DIR
else
    # TODO
    # FGVR
    OUT_DIR=$CKPT/fgvr-eval

    # Evaluate DeepPerception-FGVR

    python evaluate.py \
        --data_path $DATA_PATH \
        --ckpt_path $CKPT \
        --gpu_ids $GPU_IDs \
        --output_path $OUT_DIR \
        --vllm \
        --prompt r1

    # Evaluate Qwen2-VL

    # python evaluate.py \
    #     --data_path $DATA_PATH \
    #     --ckpt_path $CKPT \
    #     --gpu_ids $GPU_IDs \
    #     --output_path $OUT_DIR \
    #     --vllm
fi

--------------------------------------------------------------------------------
/deepperception/eval/evaluate.py:
--------------------------------------------------------------------------------
import os
import re
import json
import argparse
import subprocess
import time
import torch

from datasets import load_dataset
from torchvision.ops.boxes import box_area
from multiprocessing import Process
from tqdm import tqdm


bbox_patterns = [
    re.compile(r'.*?\((\d*?),.*?(\d*?)\),\((\d*?),(\d*?)\)'),
    re.compile(r'So the answer is.*?\((\d*?),.*?(\d*?)\),\((\d*?),(\d*?)\)'),
    re.compile(r'\((\d*?),.*?(\d*?)\),\((\d*?),(\d*?)\)'),
    re.compile(r'\((.*?),.*?(.*?)\).*?\((.*?),.*?(.*?)\)'),
    re.compile(r'\[(\d*?), (\d*?), (\d*?), (\d*?)\]'),
    re.compile(r'\[(.*?), (.*?), (.*?), (.*?)\]'),
    re.compile(r'\((\d*?), (\d*?), (\d*?), (\d*?)\)'),
    re.compile(r'\((\d*?), (\d*?)\)\n?.*?\((\d*?), (\d*?)\)')
]

REF_PATTERN = re.compile(r'<\|object_ref_start\|>(.*?)<\|object_ref_end\|>')
ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>')

def get_choice(ans):
    match = re.findall(ANSWER_PATTERN, ans)
    if len(match) > 0:
        choice = match[0].strip()
        if len(choice) > 1:
            choice = choice.split('.')[0]
        return choice
    else:
        return None

def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union

def get_bbox(ans):
    for i, pattern in enumerate(bbox_patterns):
        predict_bbox = re.findall(pattern, ans)
        if len(predict_bbox) != 0:
            try:
                predict_bbox = (float(predict_bbox[-1][0].replace('[', '').replace('x', '')), float(predict_bbox[-1][1]), float(predict_bbox[-1][2]), float(predict_bbox[-1][3]))
            except:
                predict_bbox = [0, 0, 0, 0]
            if sum(predict_bbox) < 4:
                predict_bbox = [c*1000 for c in predict_bbox]

            return predict_bbox, i+1

    return (0., 0., 0., 0.), 0

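# Illustrative example of the parsing above (sample response invented for clarity): with the "r1"
# prompt, a model answer typically ends with something like
#     <answer><|box_start|>(112,205),(583,742)<|box_end|></answer>
# for which get_bbox() returns ((112.0, 205.0, 583.0, 742.0), 1), i.e. the box plus the index of the
# first matching pattern. Boxes whose coordinates sum to less than 4 are treated as normalized to
# [0, 1] and rescaled by 1000. calculate_ious() below then counts a prediction as correct when its
# IoU with any ground-truth box is at least 0.5 (reported as "Acc @ 0.5").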
def calculate_ious(category, results):
    ious = []
    correct = 0
    match_patterns_cnt = [0] * (len(bbox_patterns) + 1)
    for r in results:
        answer = r['answer']

        predict_bbox, i = get_bbox(answer)
        r['pred_bbox'] = predict_bbox
        predict_bbox = torch.tensor(predict_bbox, dtype=torch.float32).view(-1, 4)

        max_iou = 0
        for gt_bbox in r['gt_bbox']:
            target_bbox = torch.tensor(gt_bbox, dtype=torch.float32).view(-1, 4)
            iou, _ = box_iou(predict_bbox, target_bbox)
            iou = iou.item()
            if iou > max_iou:
                max_iou = iou

        ious.append(max_iou)
        r['iou'] = max_iou
        r['match pattern'] = i
        match_patterns_cnt[i] += 1
        if max_iou >= 0.5:
            correct += 1

    metrics = dict()
    acc = correct / len(ious)
    avg_iou = sum(ious)/len(ious)

    print(category)
    print(f'unmatch: {match_patterns_cnt[0]}, ' + ', '.join([f'match {i+1}: {cnt}' for i, cnt in enumerate(match_patterns_cnt[1:])]))
    print(f'Acc @ 0.5: {acc}, IoU: {avg_iou}')

    metrics['all'] = {
        'Acc': acc,
        'IoU': avg_iou,
        'Num': len(ious)
    }

    return results, metrics

def eval(task, args, test_data):
    output_path = args.output_path

    all_metrics = dict()
    if task == 'grounding':
        seen_categories = args.seen_categories.split(',')
        all_categories = args.all_categories.split(',')


        results = {d: [] for d in all_categories}

        all_res = []
        seen_res = []
        unseen_res = []
        for data in tqdm(test_data):
            with open(f'{output_path}/temp/{data["question_id"]}.json', 'r') as f:
                r = json.load(f)
            results[data["category"]].append(r)
            all_res.append(r)
            if data["category"] in seen_categories:
                seen_res.append(r)
            else:
                unseen_res.append(r)

        all_res, metrics = calculate_ious('all', all_res)
        all_metrics['all'] = metrics

        seen_res, metrics = calculate_ious('seen domain', seen_res)
        all_metrics['seen domain'] = metrics

        unseen_res, metrics = calculate_ious('unseen domain', unseen_res)
        all_metrics['unseen domain'] = metrics

        for dataset, res in results.items():
            res, metrics = calculate_ious(dataset, res)
            all_metrics[dataset] = metrics
            with open(f'{args.output_path}/{dataset}.json', 'w') as f:
                json.dump(res, f, indent=4)

        with open(f'{args.output_path}/metrics.json', 'w') as f:
            json.dump(all_metrics, f)

    elif task == 'classification':
        correct = 0
        match_cnt = 0
        results = []
        for data in tqdm(test_data):
            with open(f'{output_path}/temp/{data["question_id"]}.json', 'r') as f:
                r = json.load(f)
            answer = r['answer']
            gt = data['messages'][1]['content']

            pred = get_choice(answer)

            if pred:
                match_cnt += 1
            if gt == pred:
                correct += 1
                r['correct'] = True
            else:
                r['correct'] = False
            results.append(r)

        category = test_data[0]['question_id'].split('/')[0]

        acc = correct / len(test_data)
        print(f'Acc ({category}): {acc}')
        print(f'Match rate: {match_cnt/len(test_data)}')

        with open(f'{args.output_path}/{category}.json', 'w') as f:
            json.dump(results, f)
        with open(f'{args.output_path}/{category}-metrics.json', 'w') as f:
            json.dump(all_metrics, f)


def infer(args, json_path, gpu_id):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    if args.vllm:
        subprocess.run(["python", 'inference.py',
                        "--data_path", args.data_path,
                        "--prompt", str(args.prompt),
                        "--vllm",
                        "--id_path", json_path,
                        "--model_path", args.ckpt_path,
                        "--output_path", args.output_path])
    else:
        subprocess.run(["python", 'inference.py',
                        "--data_path", args.data_path,
                        "--prompt", str(args.prompt),
                        "--id_path", json_path,
                        "--model_path", args.ckpt_path,
                        "--output_path", args.output_path])

def launch_subprocesses(args, temp):
    processes = []
    temp_files = []

    if len(temp) > 0:
        if '72B' in args.ckpt_path:
            nprocs = args.num_processes
            if nprocs == 2:
                gpu_ids = ['0,1,2,3', '4,5,6,7']
            elif nprocs == 1:
                gpu_ids = [args.gpu_ids]
        else: # 7B-scale models
            gpu_ids = list(map(int, args.gpu_ids.split(',')))
            nprocs = len(gpu_ids)

        num_data_per_group = len(temp) // len(gpu_ids)

        for i, gpu_id in enumerate(gpu_ids):
            start_idx = i * num_data_per_group
            end_idx = start_idx + num_data_per_group if i != (nprocs-1) else None

            timestamp = time.strftime("%Y%m%d%H%M%S")
            json_path = f'{args.output_path}/temp/{timestamp}_{gpu_id}.json'
            temp_files.append(json_path)
            with open(json_path, "w") as f:
                json.dump(temp[start_idx:end_idx], f)

            p = Process(target=infer, args=(args, json_path, gpu_id))
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        for temp_file in temp_files:
            os.remove(temp_file)

def get_data(args):
    output_path = args.output_path

    if args.data_path.endswith('.parquet'): # KVG-Bench
        task = 'grounding'
        all_categories = args.all_categories.split(',')
        for c in all_categories:
            os.makedirs(f'{output_path}/temp/{c}', exist_ok=True)

        dataset = load_dataset("parquet", data_files={"test": args.data_path})
        test_data = dataset['test']
    elif args.data_path.endswith('.json'): # FGVR
        task = 'classification'
        with open(args.data_path, 'r') as f:
            test_data = json.load(f)

        category = args.data_path.split('/')[-1].split('.')[0]
        os.makedirs(f'{output_path}/temp/{category}', exist_ok=True)
        for i, d in enumerate(test_data):
            d['question_id'] = f'{category}/{str(i).zfill(5)}'
    else:
        print(f'Unsupported file type: {args.data_path}')

    qids = []
    for d in test_data:
        if not os.path.isfile(f'{output_path}/temp/{d["question_id"]}.json'):
            qids.append(d['question_id'])
    print(f'# Test data: {len(qids)}')

    return task, test_data, qids


def parse_arguments():
    parser = argparse.ArgumentParser(description="Process images across multiple GPUs.")
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--prompt", required=False, default=None)
    parser.add_argument("--vllm", action='store_true')
    parser.add_argument("--seen_categories", required=False, default='aircraft,car,reptilia,bird,food')
    parser.add_argument("--all_categories", required=False, default='aircraft,car,reptilia,bird,food,dog,mollusca,mammal,flower,landmark')
    parser.add_argument("--ckpt_path", required=True)
    parser.add_argument("--num_processes", type=int, required=False, default=8)
    parser.add_argument("--gpu_ids", type=str, required=True, help="Comma-separated GPU IDs.")
    parser.add_argument("--output_path", required=True, help="Path to the output dir")

    return parser.parse_args()


def main():
    args = parse_arguments()

    print(f"Evaluating {args.ckpt_path}. Prompt: {args.prompt}. Results will be saved in {args.output_path}.")

    task, test_data, qids = get_data(args)
    launch_subprocesses(args, qids)
    eval(task, args, test_data)



if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/deepperception/eval/inference.py:
--------------------------------------------------------------------------------
import argparse
import io
import os
import re
import json
import time
import torch
import base64

from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

PATTERN = re.compile(r'<\|box_start\|>\(([0-9]*?),([0-9]*?)\),\(([0-9]*?),([0-9]*?)\)<\|box_end\|>')
REF_PATTERN = re.compile(r'<\|object_ref_start\|>(.*?)<\|object_ref_end\|>')

GROUNDING_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (bounding box) in <answer> </answer> tags."
# QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (bounding box in (x1,y1),(x2,y2) format) in <answer> </answer> tags."
CLASSIFICATION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer in <answer> </answer> tags."
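
# Overview of the prompt modes handled below (see inference_grounding / inference_classification):
#   --prompt r1         -> wrap the query with the <think>/<answer> templates above (DeepPerception)
#   --prompt cot-kvg    -> two-turn prompting: first ask for a discriminative description of the
#                          referred object, then ground it based on that description
#   --prompt cot-normal -> append "Let's think step by step" to the query
#   (no --prompt)       -> send the raw query unchanged (plain Qwen2-VL baseline)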

def parse_args():
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--prompt", required=False, default=None)
    parser.add_argument("--id_path", required=True)
    parser.add_argument("--model_path", required=True, help="Path to qwen.")
    parser.add_argument("--output_path", required=True)
    parser.add_argument('--vllm', action='store_true')

    parser.add_argument("--batch_size", required=False, type=int, default=1)

    args = parser.parse_args()

    return args

def inference_classification(model, processor, sampling_params, prompt, query, image):
    messages = []

    if prompt == 'r1':
        query = CLASSIFICATION_TEMPLATE.format(Question=query)
        messages.append({"role": "user", "content": [dict(type='image', image=image), dict(type='text', text=query)]})
    else:
        messages.append({"role": "user", "content": [dict(type='image', image=image), dict(type='text', text=query)]})

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)
    if sampling_params:
        llm_inputs = {
            "prompt": text,
            "multi_modal_data": {
                "image": image_inputs
            }
        }
        outputs = model.generate([llm_inputs], sampling_params=sampling_params)
        generated_text = outputs[0].outputs[0].text
    else:
        inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=1500)
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        response = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        generated_text = response[0]

    return {
        'answer': generated_text,
    }

def inference_grounding(model, processor, sampling_params, prompt, query, image_bytes):
    encoded_string = 'data:image:base64,' + str(base64.b64encode(image_bytes).decode("utf-8"))
    messages = []
    cot_response = None
    # CoT
    if prompt == 'cot-kvg':
        match = re.search(REF_PATTERN, query)
        ref = match[1]

        cot_text = (
            f'Which object in this image is {ref}? Give a detailed and discriminative description of the appearance of it'
        )

        messages.append({"role": "user", "content": [dict(type='image_url', image_url=encoded_string), dict(type='text', text=cot_text)]})
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        if sampling_params:
            llm_inputs = {
                "prompt": text,
                "multi_modal_data": {
                    "image": image_inputs
                }
            }
            outputs = model.generate([llm_inputs], sampling_params=sampling_params)
            cot_response = outputs[0].outputs[0].text
        else:
            inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)

            generated_ids = model.generate(**inputs, max_new_tokens=1500)
            generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
            cot_response = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        messages.append({"role": "assistant", "content": cot_response})
        grounding_text = f'Based on the description, find and give the bounding box of <|object_ref_start|>{ref}<|object_ref_end|>'
        messages.append({"role": "user", "content": [dict(type='text', text=grounding_text)]})
    elif prompt == 'cot-normal':
        query += ". Let's think step by step"
        messages.append({"role": "user", "content": [dict(type='image_url', image_url=encoded_string), dict(type='text', text=query)]})
    elif prompt == 'r1':
        query = GROUNDING_TEMPLATE.format(Question=query)
        messages.append({"role": "user", "content": [dict(type='image_url', image_url=encoded_string), dict(type='text', text=query)]})
    else:
        messages.append({"role": "user", "content": [dict(type='image_url', image_url=encoded_string), dict(type='text', text=query)]})

    # Grounding
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)
    if sampling_params:
        llm_inputs = {
            "prompt": text,
            "multi_modal_data": {
                "image": image_inputs
            }
        }
        outputs = model.generate([llm_inputs], sampling_params=sampling_params)
        generated_text = outputs[0].outputs[0].text
    else:
        inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=1500)
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        response = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        generated_text = response[0]

    return {
        'cot': cot_response,
        'answer': generated_text,
    }

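# Note: infer() below writes one JSON file per sample to {output_path}/temp/{question_id}.json.
# For grounding, each record stores 'cot' (the optional first-turn description), 'answer' (the raw
# model output), 'gt_bbox', and 'hw'; for classification, only 'answer' is stored. evaluate.py later
# reads these files to compute the metrics.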
def infer(model, processor, sampling_params, args):
    prompt = args.prompt
    output_path = args.output_path

    if args.data_path.endswith('.parquet'): # KVG-Bench
        task = 'grounding'
        dataset = load_dataset("parquet", data_files={"test": args.data_path})
    elif args.data_path.endswith('.json'): # FGVR
        task = 'classification'
        with open(args.data_path, 'r') as f:
            data_list = json.load(f)
        # assign question ids the same way as evaluate.py:get_data() so the two scripts line up
        category = args.data_path.split('/')[-1].split('.')[0]
        for i, d in enumerate(data_list):
            d['question_id'] = f'{category}/{str(i).zfill(5)}'
        dataset = {'test': data_list}
    else:
        print(f'Unsupported file type: {args.data_path}')

    with open(args.id_path, 'r') as f:
        qids = json.load(f)

    test_data = []
    for d in dataset['test']:
        if d['question_id'] in qids:
            test_data.append(d)

    for data in tqdm(test_data):
        if task == 'grounding':
            query = data['question']
            image_bytes = data['image']['bytes']

            gt = data['answer']
            match = re.search(PATTERN, gt)
            bbox = [[float(match[1]), float(match[2]), float(match[3]), float(match[4])]]

            image = Image.open(io.BytesIO(image_bytes))
            w, h = image.size

            out_filename = f"{output_path}/temp/{data['question_id']}.json"

            response = inference_grounding(model, processor, sampling_params, prompt, query, image_bytes)
            response['gt_bbox'] = bbox
            response['hw'] = (h, w)
        elif task == 'classification':
            query = data['messages'][0]['content'].replace('<image>', '')
            image = data['images'][0]


            out_filename = f"{output_path}/temp/{data['question_id']}.json"
            response = inference_classification(model, processor, sampling_params, prompt, query, image)


        with open(out_filename, 'w') as f:
            json.dump(response, f)


def main():
    args = parse_args()
    os.makedirs(args.output_path, exist_ok=True)

    if args.vllm:
        from vllm import LLM, SamplingParams
        model = LLM(args.model_path, max_model_len=17920, tensor_parallel_size=1)
        sampling_params = SamplingParams(n=1, temperature=0, max_tokens=1536)
    else:
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )
        sampling_params = None

    processor = AutoProcessor.from_pretrained(args.model_path)

    start_time = time.time()

    with torch.no_grad():
        infer(model, processor, sampling_params, args)

    end_time = time.time()
    elapsed_time = end_time - start_time

    print('\033[92m' + "---- Evaluate Time taken: {} seconds ----".format(elapsed_time) + '\033[0m')


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/figs/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/DeepPerception/971d92f67a21d9aca53f2f565b6899a6cf11dd5f/figs/header.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
datasets==2.16.1
Pillow==11.1.0
qwen_vl_utils==0.0.10
torch==2.2.2
torchvision==0.17.2
tqdm==4.66.5
transformers==4.45.1

--------------------------------------------------------------------------------