├── prepare_meta_info.py
├── README.md
├── PicaEval_consistency.py
├── PicaEval_gpt.py
└── PicaEval_qwen.py
/prepare_meta_info.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Generate standard meta_info.json from HuggingFace PICABench dataset + model output images.
4 |
5 | Solves the problem:
6 | - Users get a Dataset object from load_dataset but don't know how to organize it into the JSON format required by evaluation scripts
7 | - Automatically saves input images to filesystem, maps output image paths, and generates complete metadata
8 |
9 | Usage:
10 | pip install datasets pillow tqdm
11 |
12 | # Basic usage: assuming output images are in outputs/ directory with filenames matching dataset index
13 | python prepare_meta_info.py --output_image_dir outputs --save_dir PICABench_data
14 |
15 | # Specify output image naming pattern (if not named by index)
16 | python prepare_meta_info.py --output_image_dir my_results --save_dir data \
17 | --output_name_pattern "{index:05d}.png"
18 |
19 | Output:
20 | - save_dir/input_img/ # Input images (saved from HF dataset)
21 | - save_dir/meta_info.json # Standard format JSON, ready for evaluation
22 | """
23 |
24 | import argparse
25 | import json
26 | import os
27 | from pathlib import Path, PurePosixPath
28 | from typing import Any, Dict, List, Optional
29 |
30 | from datasets import load_dataset
31 | from PIL import Image
32 | from tqdm import tqdm
33 |
34 |
35 | def find_output_image(output_dir: Path, index: int, pattern: str) -> Optional[str]:
36 | """
37 | Find output image based on index and naming pattern.
38 | pattern supports {index} placeholder, e.g. "{index:05d}.jpg"
39 | """
40 | # Try user-specified pattern
41 | filename = pattern.format(index=index)
42 | candidate = output_dir / filename
43 | if candidate.exists():
44 | return filename
45 |
46 | # Try common extensions
47 | for ext in [".jpg", ".jpeg", ".png", ".webp"]:
48 | for fmt in [f"{index:05d}{ext}", f"{index:04d}{ext}", f"{index}{ext}"]:
49 | candidate = output_dir / fmt
50 | if candidate.exists():
51 | return fmt
52 |
53 | return None
54 |
55 |
56 | def save_input_image(img: Image.Image, path: Path) -> None:
57 | """Save input image to specified path"""
58 | path.parent.mkdir(parents=True, exist_ok=True)
59 | img.convert("RGB").save(path, quality=95)
60 |
61 |
62 | def build_meta_item(
63 | idx: int,
64 | example: Dict[str, Any],
65 | input_filename: str,
66 | output_filename: Optional[str],
67 | output_dir_for_json: str,
68 | ) -> Dict[str, Any]:
69 | """Build a single meta_info record"""
70 | if output_filename:
71 | output_path = (
72 | str(PurePosixPath(output_dir_for_json) / output_filename)
73 | if output_dir_for_json
74 | else output_filename
75 | )
76 | else:
77 | output_path = None
78 |
79 | item = {
80 | "index": idx,
81 | "input_path": f"input_img/{input_filename}",
82 | "output_path": output_path,
83 | "edit_instruction": example.get("edit_instruction", ""),
84 | "physics_category": example.get("physics_category", "unknown"),
85 | "physics_law": example.get("physics_law", "unknown"),
86 | "edit_operation": example.get("edit_operation", "unknown"),
87 | "difficulty": example.get("difficulty", "unknown"),
88 | "annotated_qa_pairs": example.get("annotated_qa_pairs", []),
89 | "edit_area": example.get("edit_area", "unknown"),
90 | }
91 | return item
92 |
93 |
94 | def main() -> None:
95 | parser = argparse.ArgumentParser(
96 | description="Generate meta_info.json from HF PICABench dataset + output image directory"
97 | )
98 | parser.add_argument(
99 | "--hf_repo",
100 | type=str,
101 |         default="Andrew613/PICABench",
102 | help="HuggingFace dataset repository name",
103 | )
104 | parser.add_argument(
105 | "--split",
106 | type=str,
107 | default="picabench",
108 | help="Dataset split name",
109 | )
110 | parser.add_argument(
111 | "--output_image_dir",
112 | type=str,
113 | required=True,
114 | help="Directory containing model-generated output images (relative to save_dir or absolute path)",
115 | )
116 | parser.add_argument(
117 | "--save_dir",
118 | type=str,
119 | required=True,
120 | help="Root directory to save meta_info.json and input_img/",
121 | )
122 | parser.add_argument(
123 | "--output_name_pattern",
124 | type=str,
125 | default="{index:05d}.jpg",
126 | help="Output image filename pattern, supports {index} placeholder, default {index:05d}.jpg",
127 | )
128 | parser.add_argument(
129 | "--allow_missing",
130 | action="store_true",
131 | help="Allow missing output images, still generate JSON (output_path will be null)",
132 | )
133 | parser.add_argument(
134 | "--force_input_save",
135 | action="store_true",
136 | help="Overwrite input images even if they already exist under save_dir/input_img",
137 | )
138 | args = parser.parse_args()
139 |
140 | save_dir = Path(args.save_dir).resolve()
141 | input_dir = save_dir / "input_img"
142 | input_dir.mkdir(parents=True, exist_ok=True)
143 |
144 | # Output image directory: supports relative path (relative to save_dir) or absolute path
145 | output_dir_arg = Path(args.output_image_dir)
146 | output_dir = output_dir_arg if output_dir_arg.is_absolute() else save_dir / output_dir_arg
147 | output_dir = output_dir.resolve()
148 |
149 | try:
150 | output_dir_for_json = output_dir.relative_to(save_dir).as_posix()
151 | except ValueError:
152 | output_dir_for_json = output_dir.as_posix()
153 |
154 | if output_dir_for_json in ("", "."):
155 | output_dir_for_json = ""
156 |
157 | print(f"Loading dataset: {args.hf_repo} (split={args.split})")
158 | dataset = load_dataset(args.hf_repo, split=args.split)
159 | print(f"Dataset size: {len(dataset)}")
160 |
161 | meta_info: List[Dict[str, Any]] = []
162 | missing_outputs: List[int] = []
163 |
164 | for idx, example in tqdm(enumerate(dataset), total=len(dataset), desc="Processing samples"):
165 | # 1. Save input image
166 | input_img = example.get("input_image")
167 | if input_img is None:
168 | # Fallback: try loading from image_path
169 | img_path = example.get("image_path")
170 | if img_path and os.path.exists(img_path):
171 | input_img = Image.open(img_path)
172 | else:
173 | print(f"Warning: sample {idx} has no input image, skipping")
174 | continue
175 |
176 | input_filename = f"{idx:05d}.jpg"
177 | input_path = input_dir / input_filename
178 |         # Save if missing; overwrite an existing file only when --force_input_save is set
179 |         if args.force_input_save or not input_path.exists():
180 |             save_input_image(input_img, input_path)
183 |
184 | # 2. Find corresponding output image
185 | output_filename = find_output_image(output_dir, idx, args.output_name_pattern)
186 | if output_filename is None:
187 | missing_outputs.append(idx)
188 | if not args.allow_missing:
189 | continue # Skip samples without output images
190 |
191 | # 3. Build meta_info entry
192 | item = build_meta_item(idx, example, input_filename, output_filename, output_dir_for_json)
193 | meta_info.append(item)
194 |
195 | # Save JSON
196 | json_path = save_dir / "meta_info.json"
197 | with open(json_path, "w", encoding="utf-8") as f:
198 | json.dump(meta_info, f, ensure_ascii=False, indent=2)
199 |
200 | print(f"\n✓ Successfully generated meta_info.json: {json_path}")
201 | print(f" - Total samples: {len(meta_info)}")
202 | print(f" - Input images saved to: {input_dir}")
203 |
204 | if missing_outputs:
205 | print(f"\n⚠ Warning: {len(missing_outputs)} samples missing output images")
206 | if len(missing_outputs) <= 10:
207 | print(f" Missing indices: {missing_outputs}")
208 | else:
209 | print(f" Missing indices (first 10): {missing_outputs[:10]}")
210 |
211 | if not args.allow_missing:
212 | print(" These samples are excluded from JSON")
213 | else:
214 | print(" These samples have output_path set to null")
215 |
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PICABench: How Far Are We from Physically Realistic Image Editing?
2 |
3 | Benchmark, evaluator, and data suite for physically realistic image editing.
4 |
5 | [🤗 Paper](https://huggingface.co/papers/2510.17681) · [📄 arXiv](https://arxiv.org/pdf/2510.17681) · [🌐 Project Page](https://picabench.github.io)
6 |
21 | ## 📋 Table of Contents
22 |
23 | - [Overview](#overview)
24 | - [Quick Start](#quick-start)
25 | - [Installation](#installation)
26 | - [Data Preparation](#data-preparation)
27 | - [Evaluation Pipelines](#evaluation-pipelines)
28 | - [PICA-100K Training Data](#pica-100k-training-data)
29 | - [Leaderboard & Qualitative Explorer](#leaderboard--qualitative-explorer)
30 | - [Leaderboard Submission](#leaderboard-submission)
31 | - [Citation](#citation)
32 |
33 | ## Overview
34 |
35 | PICABench probes how far current editing models are from physically realistic image manipulation. It ties together:
36 |
37 | - **PICABench benchmark** – physics-aware editing cases spanning eight laws across *Optics*, *Mechanics*, and *State Transition*, each labeled with superficial/intermediate/explicit difficulty tiers.
38 | - **PICAEval metric** – region-grounded, QA-based verification with human-annotated regions of interest (ROIs) and spatially anchored yes/no questions.
39 | - **PICA-100K dataset** – synthetic, video-derived training data that boosts physics consistency when used for fine-tuning.
40 |
41 | The leaderboard shows that even top proprietary systems only reach ~60% accuracy, indicating a significant physics-awareness gap.
42 |
43 | ## ⚡ Quick Start
44 |
45 | Evaluate your model's physics-aware editing from a folder of output images in 3 steps:
46 |
47 | ```bash
48 | # 1) Install dependencies (choose GPT or Qwen)
49 | pip install openai Pillow tqdm datasets huggingface_hub # GPT
50 | # or
51 | pip install vllm transformers Pillow tqdm datasets huggingface_hub # Qwen/vLLM
52 |
53 | # 2) Build meta_info.json from HF dataset + your outputs
54 | # (Assume your edited images are in PICABench_data/outputs/ as 00000.jpg, 00001.jpg, ...; a relative --output_image_dir resolves under --save_dir)
55 | python prepare_meta_info.py \
56 | --hf_repo Andrew613/PICABench \
57 | --output_image_dir outputs \
58 | --save_dir PICABench_data
59 |
60 | # 3) Run evaluation (multi-threaded)
61 | export OPENAI_API_KEY="sk-..."
62 | python PicaEval_gpt.py \
63 | --input_json_path PICABench_data/meta_info.json \
64 | --gpt_model gpt-4o \
65 | --num_workers 16
66 | ```
67 |
68 | Notes:
69 | - When `meta_info.json` lives in `PICABench_data/`, you can omit `--image_base_dir` (defaults to the JSON directory).
70 | - If your output images are outside `PICABench_data/`, `prepare_meta_info.py` will write absolute paths and the evaluators will resolve them automatically.
71 |
72 | Results are saved as `PICABench_data/meta_info_gpt_output_1024_crop_box_and_resize.json` and the corresponding `_analysis_...json`.
73 |
74 | ## Installation
75 |
76 | We recommend using a Python 3.10+ virtual environment:
77 |
78 | ```bash
79 | conda create -n picabench python=3.10
80 | conda activate picabench
81 | ```
82 |
83 | Install dependencies based on your evaluation needs:
84 |
85 | ```bash
86 | # For GPT evaluation (multi-threaded with OpenAI SDK)
87 | pip install openai Pillow tqdm huggingface_hub
88 |
89 | # For Qwen evaluation (with vLLM acceleration)
90 | pip install vllm transformers Pillow tqdm
91 | ```
92 |
93 | ## Data Preparation
94 |
95 | ### Generating `meta_info.json` from HuggingFace Dataset
96 |
97 | If you have already generated edited images with your model, use the provided conversion script to organize them into the `meta_info.json` format required by the evaluation scripts:
98 |
99 | ```bash
100 | # 1. Install dependencies
101 | pip install datasets pillow tqdm
102 |
103 | # 2. Assuming your model outputs are in PICABench_data/outputs/ with filenames 00000.jpg, 00001.jpg, ...
104 | python prepare_meta_info.py \
105 | --hf_repo Andrew613/PICABench \
106 | --output_image_dir outputs \
107 | --save_dir PICABench_data
108 |
109 | # 3. Generated files:
110 | # PICABench_data/input_img/ - Input images (automatically saved from HF dataset)
111 | # PICABench_data/meta_info.json - Standard format JSON, ready for evaluation
112 | ```
113 |
114 | > ⚠️ Paths inside `meta_info.json` are written relative to the chosen `--save_dir`. Pass that same directory to the evaluators via `--image_base_dir` to avoid duplicate folder segments.
115 |
116 | **Parameters:**
117 | - `--output_image_dir`: Directory containing your model's edited output images (a relative path is resolved under `--save_dir`; absolute paths are also accepted)
118 | - `--save_dir`: Root directory to save `meta_info.json` and input images
119 | - `--output_name_pattern`: Output image filename pattern (default `{index:05d}.jpg`), supports `{index}` placeholder (see the example below)
120 | - `--allow_missing`: Allow missing output images, still generate JSON (missing samples will have `output_path` set to `null`)
121 | - `--force_input_save`: Overwrite cached `input_img/*.jpg` files instead of reusing them (default: reuse existing files)
122 |
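123 | For example, if your outputs follow a different naming scheme and some are still missing, combine `--output_name_pattern` with `--allow_missing` (directory and filenames below are placeholders; a relative `--output_image_dir` such as `my_results` resolves to `PICABench_data/my_results/`):
124 |
125 | ```bash
126 | python prepare_meta_info.py \
127 |     --hf_repo Andrew613/PICABench \
128 |     --output_image_dir my_results \
129 |     --save_dir PICABench_data \
130 |     --output_name_pattern "result_{index:05d}.png" \
131 |     --allow_missing
132 | ```
133 |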
123 | ### `meta_info.json` Format
124 |
125 | PICABench expects per-scene metadata in `meta_info.json` plus accompanying images under a shared base directory. Each item should include:
126 |
127 | ```jsonc
128 | {
129 | "index": 1174,
130 | "input_path": "input_img/1174.jpg",
131 | "output_path": "output_img/1174.jpg",
132 | "edit_instruction": "Remove the tulip from the white vase and simultaneously eliminate every instance of it in the window reflection while keeping lighting and shading consistent.",
133 | "physics_category": "Optics",
134 | "physics_law": "Reflection",
135 | "edit_operation": "remove",
136 | "difficulty": "superficial",
137 | "annotated_qa_pairs": [
138 | {
139 | "question": "Is a tulip visible in the window reflection?",
140 | "answer": "No",
141 | "box": { "x": 101.25, "y": 476.90, "width": 169.44, "height": 202.96 }
142 | },
143 | {
144 | "question": "Does the interior of the white vase contain exactly zero tulips?",
145 | "answer": "Yes",
146 | "box": { "x": 327.96, "y": 485.99, "width": 209.80, "height": 206.21 }
147 | },
148 | {
149 | "question": "Is the vase's reflection aligned with the vase?",
150 | "answer": "Yes",
151 | "box": { "x": 117.24, "y": 496.29, "width": 363.74, "height": 183.41 }
152 | }
153 | ],
154 | "edit_area": [
155 | {
156 | "x": 117.24,
157 | "y": 496.29,
158 | "width": 363.74,
159 | "height": 183.41,
160 | "id": "BxnMC34B",
161 | "order": 1
162 | }
163 | ]
164 | }
165 | ```
166 |
167 |
168 | **📋 Field Descriptions**
169 |
170 | - **`annotated_qa_pairs`**: List of QA dictionaries for physics verification. Each contains:
171 | - `question`: Yes/no question about physical correctness
172 | - `answer`: Ground truth ("Yes" or "No")
173 | - `box`: Region of interest `{x, y, width, height}` in 1024px canvas coordinates
174 |
175 | - **`edit_area`**: Bounding boxes of edited regions, used by `PicaEval_consistency.py` to mask out edited pixels when computing non-edited-region PSNR. Set to `"unknown"` if unavailable.
176 |
177 | - **Visualization**: The evaluators auto-generate cropped/annotated images in `visualization_annotated_qa_<viz_mode>/` (e.g. `visualization_annotated_qa_crop_box_and_resize/`) under the base directory.
178 |
179 |
180 |
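181 | As a quick sanity check, you can load the generated file and inspect a few entries (a minimal sketch; the `PICABench_data/` path assumes the Quick Start layout above):
182 |
183 | ```python
184 | import json
185 |
186 | # Load the metadata produced by prepare_meta_info.py
187 | with open("PICABench_data/meta_info.json", encoding="utf-8") as f:
188 |     meta_info = json.load(f)
189 |
190 | # Print a short summary of the first few samples
191 | for item in meta_info[:3]:
192 |     qa_pairs = item.get("annotated_qa_pairs", [])
193 |     print(item["index"], item["physics_law"], item["difficulty"], f"{len(qa_pairs)} QA pairs")
194 | ```
195 |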
181 | ## Evaluation Pipelines
182 |
183 | ### 1. Qwen / vLLM (PICAEval)
184 |
185 | ```bash
186 | python PicaEval_qwen.py \
187 | --input_json_path /path/to/meta_info.json \
188 | --image_base_dir /path/to/images \
189 | --model_path pretrained/Qwen/Qwen2.5-VL-72B-Instruct \
190 | --tensor_parallel_size 4 \
191 | --dtype bfloat16 \
192 | --qa_field annotated_qa_pairs \
193 | --viz_mode crop_box_and_resize \
194 | --max_new_tokens 256 \
195 | --img_size 1024
196 | ```
197 |
198 | Outputs:
199 |
200 | - `*_vllm_output_{img_size}[_{viz_mode}].json` – per-QA predictions with `model_answer`, `model_response`, `model_explanation`, `is_correct`, and optional `visualization_path`.
201 | - `*_vllm_analysis_{img_size}[_{viz_mode}].json` – aggregated accuracy by physics category, law, and operation.
202 |
203 | ### 2. GPT-based Evaluation (PICAEval)
204 |
205 | ```bash
206 | export OPENAI_API_KEY="sk-..."
207 | python PicaEval_gpt.py \
208 | --input_json_path /path/to/meta_info.json \
209 | --image_base_dir /path/to/images \
210 | --qa_field annotated_qa_pairs \
211 | --viz_mode crop_box_and_resize \
212 | --gpt_model gpt-5 \
213 | --num_workers 50 \
214 | --max_attempts 5 \
215 | --api_base_url https://api.openai.com/v1
216 | ```
217 |
218 | **Key Parameters:**
219 | - `--num_workers`: Number of parallel worker threads (default: 50) for concurrent API requests
220 | - `--gpt_model`: OpenAI model name (e.g., `gpt-5`, `gpt-4o`, `gpt-4-turbo`)
221 | - `--api_base_url`: API endpoint URL (default: `https://api.openai.com/v1`)
222 | - `--max_attempts`: Retry attempts for failed API calls (default: 5)
223 |
224 | **Outputs:**
225 | - `*_gpt_output_{img_size}[_{viz_mode}].json` – detailed results with per-question predictions
226 | - `*_gpt_analysis_{img_size}[_{viz_mode}].json` – accuracy statistics by physics category, law, and operation (structure sketched below)
227 |
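228 | The analysis file has roughly the following shape (numbers are placeholders; accuracy is per-sample QA accuracy averaged across samples, as computed by `calculate_accuracy_by_dimension`):
229 |
230 | ```jsonc
231 | {
232 |   "overall_accuracy": 57.3,      // mean per-sample QA accuracy, in %
233 |   "sample_count": 1000,          // samples with at least one evaluated QA pair
234 |   "qa_total": 3000,              // total evaluated QA pairs
235 |   "by_category": {
236 |     "Optics": { "accuracy": 55.1, "sample_count": 400, "qa_total": 1200 }
237 |   },
238 |   "by_law": { /* same structure, keyed by physics law */ },
239 |   "by_operation": { /* same structure, keyed by edit operation */ }
240 | }
241 | ```
242 |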
228 | **Notes:**
229 | - Uses multi-threaded execution with OpenAI SDK for efficient parallel evaluation
230 | - Reuses the same JSON schema for inputs/outputs as the Qwen pipeline, enabling direct comparison
231 | - Images are base64-encoded and sent as data URLs; be mindful of API quotas and rate limits
232 |
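233 | Each question is sent with a structured-output instruction (see `create_structured_prompt` in `PicaEval_gpt.py`); the evaluator expects a JSON reply of the form below and falls back to plain yes/no extraction otherwise:
234 |
235 | ```json
236 | {"answer": "Yes", "explanation": "The window reflection no longer contains a tulip."}
237 | ```
238 |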
233 | ### 3. Non-edited Region Quality (PSNR)
234 |
235 | ```bash
236 | python PicaEval_consistency.py \
237 | --meta_info_path /path/to/meta_info.json \
238 | --base_dir /path/to/images \
239 | --size 512
240 | ```
241 |
242 | Produces `_psnr_output.json` and `_psnr_analysis.json`, containing masked PSNR on non-edited regions or whole-image PSNR when edit regions are unavailable.
243 |
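244 | Concretely, the `edit_area` boxes are rasterized into a mask and PSNR is computed over the remaining pixels, with images normalized to `[0, 1]`. A NumPy restatement of the torch implementation in `PicaEval_consistency.py`:
245 |
246 | ```python
247 | import numpy as np
248 |
249 | def masked_psnr(output_img: np.ndarray, input_img: np.ndarray, keep_mask: np.ndarray) -> float:
250 |     """PSNR over non-edited pixels only; images in [0, 1], keep_mask is 1 outside edit boxes."""
251 |     diff = (output_img - input_img)[keep_mask > 0.5]
252 |     mse = float(np.mean(diff ** 2))
253 |     return 100.0 if mse < 1e-10 else float(20.0 * np.log10(1.0 / np.sqrt(mse)))
254 | ```
255 |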
244 | ## PICA-100K Training Data
245 |
246 | **Dataset**: [Andrew613/PICA-100K](https://huggingface.co/datasets/Andrew613/PICA-100K)
247 |
248 | 100K synthetic editing pairs derived from video frames, designed to improve physical realism in image editing models.
249 |
250 | ### Download
251 |
252 | ```bash
253 | huggingface-cli download Andrew613/PICA-100K \
254 | --repo-type dataset \
255 | --local-dir data/PICA-100K
256 | ```
257 |
258 | ## Leaderboard & Qualitative Explorer
259 |
260 | - Official leaderboard and gallery: [https://picabench.github.io](https://picabench.github.io)
261 | - Eight physics laws × three difficulty tiers provide direct qualitative comparisons.
262 | - PICAEval scores correlate strongly with human judgments (Elo study on the site).
263 |
264 | ## Leaderboard Submission
265 |
266 | To submit your model's results to the PICABench leaderboard:
267 |
268 | **Required Metrics:**
269 | - Accuracy (%) for each sub-category (Light Propagation, Reflection, Refraction, Light Source Effects, Deformation, Causality, Local State Transition, Global State Transition)
270 | - Overall Accuracy (%)
271 |
272 | **Submission:**
273 | Email your `*_analysis*.json` and `*_output*.json` files and model details to:
274 | - [puyuandong01061313@gmail.com](mailto:puyuandong01061313@gmail.com)
275 |
276 | ## Citation
276 |
277 | ```bibtex
278 | @article{pu2025picabench,
279 | title = {PICABench: How Far Are We From Physically Realistic Image Editing?},
280 | author = {Pu, Yuandong and Zhuo, Le and Han, Songhao and Xing, Jinbo and Zhu, Kaiwen and Cao, Shuo and Fu, Bin and Liu, Si and Li, Hongsheng and Qiao, Yu and Zhang, Wenlong and Chen, Xi and Liu, Yihao},
281 | journal = {arXiv preprint arXiv:2510.17681},
282 | year = {2025}
283 | }
284 | ```
285 |
286 | ## License
287 |
288 | This project is released under the Apache License 2.0.
289 |
--------------------------------------------------------------------------------
/PicaEval_consistency.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from typing import Any, Dict, List, Optional
5 |
6 | import numpy as np
7 | import torch
8 | from PIL import Image, ImageDraw
9 | from tqdm import tqdm
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(description='Non-edited Region Quality Assessment - PSNR only')
14 | parser.add_argument(
15 | '--meta_info_path',
16 | required=True,
17 | help='Path to meta_info.json file',
18 | )
19 | parser.add_argument(
20 | '--base_dir',
21 | required=False,
22 | help='Base directory for image files, defaults to meta_info.json directory if not specified',
23 | )
24 | parser.add_argument(
25 | '--size',
26 | default=None,
27 | type=int,
28 | help='Resize image to this size, keep original size if not specified',
29 | )
30 |
31 | return parser.parse_args()
32 |
33 |
34 | def create_edit_mask(image_size, edit_areas):
35 | """Create edit region mask, returns non-edited region mask"""
36 | width, height = image_size
37 | # Create white image (all regions are non-edited)
38 | mask_img = Image.new('L', (width, height), 255)
39 | draw = ImageDraw.Draw(mask_img)
40 |
41 | # Mark edited regions as black (0)
42 | for area in edit_areas:
43 | x = area['x']
44 | y = area['y']
45 | w = area['width']
46 | h = area['height']
47 | # Draw rectangle, edited region in black
48 | draw.rectangle([x, y, x + w, y + h], fill=0)
49 |
50 | # Convert to tensor, non-edited region=1, edited region=0
51 | mask_array = np.array(mask_img) / 255.0 # Convert to [0,1]
52 | mask_tensor = torch.tensor(mask_array, dtype=torch.float32)
53 | return mask_tensor
54 |
55 |
56 | def load_image_with_mask(image_path: str, edit_areas: List[Dict], size=None):
57 | """Load image and create non-edited region mask"""
58 | try:
59 | image = Image.open(image_path).convert('RGB')
60 | original_size = image.size
61 |
62 | if size:
63 | image = image.resize((size, size), Image.BILINEAR)
64 | # Adjust edit_areas coordinates
65 | scale_x = size / original_size[0]
66 | scale_y = size / original_size[1]
67 | scaled_edit_areas = []
68 | for area in edit_areas:
69 | scaled_area = {
70 | 'x': area['x'] * scale_x,
71 | 'y': area['y'] * scale_y,
72 | 'width': area['width'] * scale_x,
73 | 'height': area['height'] * scale_y
74 | }
75 | scaled_edit_areas.append(scaled_area)
76 | edit_areas = scaled_edit_areas
77 | mask_size = (size, size)
78 | else:
79 | mask_size = original_size
80 |
81 | # Create mask
82 | mask = create_edit_mask(mask_size, edit_areas)
83 |
84 | # Convert image to tensor
85 | image_array = np.array(image) / 255.0 # Normalize to [0, 1]
86 | image_array = image_array.astype(np.float32)
87 | image_tensor = torch.tensor(image_array).permute(2, 0, 1).unsqueeze(0) # Shape: (1, 3, H, W)
88 |
89 | # Expand mask dimensions to match image (1, 1, H, W)
90 | mask = mask.unsqueeze(0).unsqueeze(0)
91 |
92 | return image_tensor, mask
93 | except Exception as e:
94 | print(f"Error loading image {image_path}: {e}")
95 | return None, None
96 |
97 |
98 | def compute_masked_psnr(output_img, input_img, mask):
99 | """Compute PSNR only on masked region (mask=1 pixels)"""
100 | mask_bool = mask > 0.5
101 | mask_3ch = mask_bool.expand_as(output_img)
102 |
103 | # Extract non-edited region pixels
104 | output_pixels = output_img[mask_3ch]
105 | input_pixels = input_img[mask_3ch]
106 |
107 | if output_pixels.numel() == 0:
108 | return None
109 |
110 | # Calculate MSE (only on non-edited region)
111 | mse = torch.mean((output_pixels - input_pixels) ** 2)
112 |
113 | if mse < 1e-10:
114 | return 100.0 # Near infinity, return a large value
115 |
116 | psnr = 20 * torch.log10(torch.tensor(1.0) / torch.sqrt(mse))
117 | return psnr.item()
118 |
119 |
120 | def load_image_simple(image_path: str, size=None):
121 | """Simple image loading without mask processing"""
122 | try:
123 | image = Image.open(image_path).convert('RGB')
124 | if size:
125 | image = image.resize((size, size), Image.BILINEAR)
126 | image = np.array(image) / 255.0 # Normalize to [0, 1]
127 | image = image.astype(np.float32)
128 | image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) # Shape: (1, 3, H, W)
129 | return image
130 | except Exception as e:
131 | print(f"Error loading image {image_path}: {e}")
132 | return None
133 |
134 |
135 | def compute_full_image_psnr(output_img, input_img):
136 | """Compute PSNR on full image"""
137 | mse = torch.mean((output_img - input_img) ** 2)
138 |
139 | if mse < 1e-10:
140 | return 100.0
141 |
142 | psnr = 20 * torch.log10(torch.tensor(1.0) / torch.sqrt(mse))
143 | return psnr.item()
144 |
145 |
146 | def evaluate_single_item(item: Dict[str, Any], base_dir: str, size=None) -> Optional[float]:
147 |     """Evaluate single meta_info item and return PSNR (None if it cannot be computed)"""
148 |     # Skip items without an output image (e.g. meta_info generated with --allow_missing), then build full paths
149 |     if not item.get('output_path'):
150 |         print(f"Warning: item {item.get('index', 'unknown')} has no output_path, skipping")
151 |         return None
152 |     input_path = os.path.join(base_dir, item['input_path'])
153 |     output_path = os.path.join(base_dir, item['output_path'])
151 |
152 | # Check if files exist
153 | if not os.path.exists(input_path):
154 | print(f"Warning: input image not found: {input_path}")
155 | return None
156 |
157 | if not os.path.exists(output_path):
158 | print(f"Warning: output image not found: {output_path}")
159 | return None
160 |
161 | # Get edit area information
162 | edit_areas = item.get('edit_area', [])
163 | use_full_image = False
164 |
165 | # Handle edit_area as string
166 | if isinstance(edit_areas, str):
167 | if edit_areas == "unknown":
168 | use_full_image = True
169 | else:
170 | print(f"Warning: unexpected edit_area format for item {item.get('index', 'unknown')}: {edit_areas}")
171 | return None
172 | elif not edit_areas or len(edit_areas) == 0:
173 | # No edit_area info, use full image evaluation
174 | use_full_image = True
175 |
176 | if use_full_image:
177 | # Full image evaluation
178 | input_image = load_image_simple(input_path, size)
179 | output_image = load_image_simple(output_path, size)
180 |
181 | if input_image is None or output_image is None:
182 | return None
183 |
184 | # Ensure images have same size
185 | if input_image.shape != output_image.shape:
186 | h, w = input_image.shape[2], input_image.shape[3]
187 | output_image = torch.nn.functional.interpolate(output_image, size=(h, w), mode='bilinear', align_corners=False)
188 |
189 | # Compute full image PSNR
190 | psnr = compute_full_image_psnr(output_image, input_image)
191 |
192 | else:
193 | # Non-edited region evaluation
194 | input_image, mask = load_image_with_mask(input_path, edit_areas, size)
195 | output_image, _ = load_image_with_mask(output_path, edit_areas, size)
196 |
197 | if input_image is None or output_image is None or mask is None:
198 | return None
199 |
200 | # Ensure images have same size
201 | if input_image.shape != output_image.shape:
202 | h, w = input_image.shape[2], input_image.shape[3]
203 | output_image = torch.nn.functional.interpolate(output_image, size=(h, w), mode='bilinear', align_corners=False)
204 |
205 | # Compute masked region PSNR
206 | psnr = compute_masked_psnr(output_image, input_image, mask)
207 |
208 | return psnr
209 |
210 |
211 | def process_meta_info(meta_info_path: str, base_dir: str, size=None):
212 | """Process entire meta_info.json file"""
213 |
214 | # Load meta_info.json
215 | with open(meta_info_path, 'r', encoding='utf-8') as f:
216 | meta_info = json.load(f)
217 |
218 | print(f"Loaded {len(meta_info)} items from {meta_info_path}")
219 |
220 | # Detect evaluation mode
221 | has_edit_area = False
222 |
223 | # Check first few items for edit_area status
224 | for item in meta_info[:10]:
225 | edit_areas = item.get('edit_area', [])
226 | if isinstance(edit_areas, list) and len(edit_areas) > 0:
227 | has_edit_area = True
228 | break
229 |
230 | if not has_edit_area:
231 | print("Warning: No valid edit_area found in dataset, using full image evaluation mode")
232 | evaluation_type = "full_image"
233 | else:
234 | evaluation_type = "non_edited_region"
235 |
236 | # Store results
237 | detailed_results = []
238 | valid_scores = []
239 |
240 | # Evaluate each item
241 |     for idx, item in tqdm(enumerate(meta_info), total=len(meta_info), desc="Evaluating items"):
242 | psnr = evaluate_single_item(item, base_dir, size)
243 | sample_id = item.get('index', idx)
244 |
245 | # Add to detailed results
246 | result_item = {
247 | 'id': sample_id,
248 | 'input_path': item['input_path'],
249 | 'output_path': item['output_path'],
250 | 'physics_category': item.get('physics_category', 'unknown'),
251 | 'physics_law': item.get('physics_law', 'unknown'),
252 | 'edit_operation': item.get('edit_operation', 'unknown'),
253 | 'difficulty': item.get('difficulty', 'unknown'),
254 | 'psnr': psnr
255 | }
256 | detailed_results.append(result_item)
257 |
258 | # Collect valid scores
259 | if psnr is not None:
260 | valid_scores.append(psnr)
261 |
262 | # Calculate overall statistics
263 | if valid_scores:
264 | overall_stats = {
265 | 'count': len(valid_scores),
266 | 'mean': float(np.mean(valid_scores)),
267 | 'std': float(np.std(valid_scores)),
268 | 'min': float(np.min(valid_scores)),
269 | 'max': float(np.max(valid_scores)),
270 | 'median': float(np.median(valid_scores))
271 | }
272 | else:
273 | overall_stats = {
274 | 'count': 0,
275 | 'mean': None,
276 | 'std': None,
277 | 'min': None,
278 | 'max': None,
279 | 'median': None
280 | }
281 |
282 | # Statistics by physics_category
283 | physics_category_stats = {}
284 | categories = set(item.get('physics_category', 'unknown') for item in meta_info)
285 |
286 | for category in categories:
287 | category_items = [item for item in detailed_results if item['physics_category'] == category]
288 | category_scores = [item['psnr'] for item in category_items if item['psnr'] is not None]
289 |
290 | if category_scores:
291 | physics_category_stats[category] = {
292 | 'count': len(category_scores),
293 | 'mean': float(np.mean(category_scores)),
294 | 'std': float(np.std(category_scores))
295 | }
296 | else:
297 | physics_category_stats[category] = {
298 | 'count': 0,
299 | 'mean': None,
300 | 'std': None
301 | }
302 |
303 | # Statistics by physics_law
304 | physics_law_stats = {}
305 | laws = set(item.get('physics_law', 'unknown') for item in meta_info)
306 |
307 | for law in laws:
308 | law_items = [item for item in detailed_results if item['physics_law'] == law]
309 | law_scores = [item['psnr'] for item in law_items if item['psnr'] is not None]
310 |
311 | if law_scores:
312 | physics_law_stats[law] = {
313 | 'count': len(law_scores),
314 | 'mean': float(np.mean(law_scores)),
315 | 'std': float(np.std(law_scores))
316 | }
317 | else:
318 | physics_law_stats[law] = {
319 | 'count': 0,
320 | 'mean': None,
321 | 'std': None
322 | }
323 |
324 | # Prepare analysis results
325 | final_analysis = {
326 | 'meta_info_path': meta_info_path,
327 | 'total_items': len(meta_info),
328 | 'evaluation_type': evaluation_type,
329 | 'overall_statistics': overall_stats,
330 | 'physics_category_statistics': physics_category_stats,
331 | 'physics_law_statistics': physics_law_stats
332 | }
333 |
334 | return final_analysis, detailed_results
335 |
336 |
337 | def main():
338 | args = parse_args()
339 |
340 | # Determine base directory
341 | if args.base_dir:
342 | base_dir = args.base_dir
343 | else:
344 | base_dir = os.path.dirname(args.meta_info_path)
345 |
346 | print(f"Using base directory: {base_dir}")
347 |
348 | # Process meta_info
349 | analysis_results, detailed_results = process_meta_info(
350 | args.meta_info_path,
351 | base_dir,
352 | args.size
353 | )
354 |
355 | # Generate output file names
356 | meta_info_dir = os.path.dirname(args.meta_info_path)
357 | meta_info_name = os.path.splitext(os.path.basename(args.meta_info_path))[0]
358 |
359 | analysis_output_path = os.path.join(meta_info_dir, f"{meta_info_name}_psnr_analysis.json")
360 | detailed_output_path = os.path.join(meta_info_dir, f"{meta_info_name}_psnr_output.json")
361 |
362 | # Save results
363 | with open(analysis_output_path, 'w', encoding='utf-8') as f:
364 | json.dump(analysis_results, f, indent=2, ensure_ascii=False)
365 |
366 | with open(detailed_output_path, 'w', encoding='utf-8') as f:
367 | json.dump(detailed_results, f, indent=2, ensure_ascii=False)
368 |
369 | print(f"\nResults saved:")
370 | print(f"Analysis: {analysis_output_path}")
371 | print(f"Detailed: {detailed_output_path}")
372 |
373 | # Print overall results
374 | eval_mode = analysis_results['evaluation_type']
375 | eval_label = "Full Image" if eval_mode == "full_image" else "Non-edited Region"
376 | print(f"\n=== Overall PSNR Results ({eval_label}) ===")
377 | stats = analysis_results['overall_statistics']
378 | if stats['mean'] is not None:
379 | print(f"PSNR: {stats['mean']:.4f} ± {stats['std']:.4f} (n={stats['count']})")
380 | print(f" Min: {stats['min']:.4f}, Max: {stats['max']:.4f}, Median: {stats['median']:.4f}")
381 | else:
382 | print(f"PSNR: No valid scores")
383 |
384 | # Print physics_category results
385 | print(f"\n=== PSNR Results by Physics Category ===")
386 | for category, category_stats in analysis_results['physics_category_statistics'].items():
387 | if category_stats['mean'] is not None:
388 | print(f"{category}: {category_stats['mean']:.4f} ± {category_stats['std']:.4f} (n={category_stats['count']})")
389 | else:
390 | print(f"{category}: No valid scores")
391 |
392 | # Print physics_law results
393 | print(f"\n=== PSNR Results by Physics Law ===")
394 | for law, law_stats in analysis_results['physics_law_statistics'].items():
395 | if law_stats['mean'] is not None:
396 | print(f"{law}: {law_stats['mean']:.4f} ± {law_stats['std']:.4f} (n={law_stats['count']})")
397 | else:
398 | print(f"{law}: No valid scores")
399 |
400 |
401 | if __name__ == "__main__":
402 | main()
403 |
--------------------------------------------------------------------------------
/PicaEval_gpt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """GPT-based PhysEdit evaluation (v2).
4 | Matches the Qwen evaluator input/output contract while delegating inference to OpenAI GPT."""
5 |
6 | import os
7 | import sys
8 | import json
9 | import argparse
10 | import random
11 | import time
12 | import base64
13 | import io
14 | import re
15 | from dataclasses import dataclass
16 | from pathlib import Path
17 | from typing import Any, Dict, List, Optional, Tuple
18 | from concurrent.futures import ThreadPoolExecutor, as_completed
19 |
20 | from PIL import Image, ImageDraw
21 | from tqdm import tqdm
22 | from openai import OpenAI
23 |
24 | # ============================================================================
25 | # :: Shared Data Structures & Helpers (kept aligned with PicaEval_qwen)
26 | # ============================================================================
27 |
28 | IMAGE_PATH_KEYS: Tuple[str, ...] = ("output_path", "output_image_path", "output_img_path")
29 | JSON_PATTERN = re.compile(r"\{[^{}]*\"answer\"[^{}]*\"explanation\"[^{}]*\}", re.IGNORECASE | re.DOTALL)
30 |
31 |
32 | @dataclass
33 | class QATask:
34 | item_index: int
35 | qa_field: str
36 | qa_type: str
37 | qa_index: int
38 | image_path: str
39 | question: str
40 | answer: str
41 | source: str
42 |
43 |
44 | def resolve_image_path(item: Dict[str, Any], base_dir: str) -> Optional[str]:
45 | base_path = Path(base_dir).expanduser().resolve()
46 | base_name = base_path.name
47 | for key in IMAGE_PATH_KEYS:
48 | raw_path = item.get(key)
49 | if not raw_path:
50 | continue
51 |
52 | candidate_path = Path(raw_path)
53 | if candidate_path.is_absolute():
54 | return str(candidate_path)
55 |
56 | normalized_rel = Path(*candidate_path.parts)
57 | candidate = base_path / normalized_rel
58 | if candidate.exists():
59 | return str(candidate)
60 |
61 | rel_parts = normalized_rel.parts
62 | if rel_parts and rel_parts[0] == base_name:
63 | alt = base_path / Path(*rel_parts[1:])
64 | if alt.exists():
65 | return str(alt)
66 |
67 | return str(candidate)
68 | return None
69 |
70 |
71 | def iter_qa_entries(container: Any, default_type: str) -> List[Tuple[str, int, Dict[str, Any]]]:
72 | entries: List[Tuple[str, int, Dict[str, Any]]] = []
73 | if isinstance(container, dict):
74 | for qa_type, qa_list in container.items():
75 | for idx, qa in enumerate(qa_list):
76 | entries.append((qa_type, idx, qa))
77 | elif isinstance(container, list):
78 | for idx, qa in enumerate(container):
79 | entries.append((default_type, idx, qa))
80 | return entries
81 |
82 |
83 | def draw_box_on_image(image: Image.Image, box_info: Dict[str, Any], box_color: str = "red") -> Image.Image:
84 | img_copy = image.copy()
85 | draw = ImageDraw.Draw(img_copy)
86 | x = box_info.get("x", 0)
87 | y = box_info.get("y", 0)
88 | width = box_info.get("width", 0)
89 | height = box_info.get("height", 0)
90 | bbox = [x, y, x + width, y + height]
91 | draw.rectangle(bbox, outline=box_color, width=3)
92 | return img_copy
93 |
94 |
95 | def resize_image(image: Image.Image) -> Image.Image:
96 | width, height = image.size
97 | long_edge = max(width, height)
98 | if long_edge == 1024:
99 | return image
100 | scale = 1024.0 / long_edge
101 | new_width = int(width * scale)
102 | new_height = int(height * scale)
103 | return image.resize((new_width, new_height), Image.LANCZOS)
104 |
105 |
106 | def crop_image_with_box(image: Image.Image, box_info: Dict[str, Any], padding: int = 20) -> Image.Image:
107 | x = box_info.get("x", 0)
108 | y = box_info.get("y", 0)
109 | width = box_info.get("width", 0)
110 | height = box_info.get("height", 0)
111 | x1 = max(0, x - padding)
112 | y1 = max(0, y - padding)
113 | x2 = min(image.size[0], x + width + padding)
114 | y2 = min(image.size[1], y + height + padding)
115 | return image.crop((x1, y1, x2, y2))
116 |
117 |
118 | def generate_viz_filename(item_index: int, qa_type: str, question_index: int, viz_mode: str) -> str:
119 | qa_type_short = qa_type.replace(" ", "_").replace("QA", "qa")
120 | return f"{item_index:04d}_{qa_type_short}_{question_index:03d}_{viz_mode}.jpg"
121 |
122 |
123 | def create_visualization_and_question(
124 | output_img_path: str,
125 | qa_info: Dict[str, Any],
126 | item_index: int,
127 | qa_type: str,
128 | qa_index: int,
129 | viz_mode: str,
130 | viz_dir: str,
131 | args,
132 | ) -> Tuple[str, str]:
133 | question = qa_info.get("question", "")
134 | box_info = qa_info.get("box", {})
135 | viz_filename = generate_viz_filename(item_index, qa_type, qa_index, viz_mode)
136 | viz_path = os.path.join(viz_dir, viz_filename)
137 | viz_rel_path = os.path.join(f"visualization_annotated_qa_{viz_mode}", viz_filename)
138 | os.makedirs(viz_dir, exist_ok=True)
139 | try:
140 | image = Image.open(output_img_path).convert("RGB")
141 | image_resized = resize_image(image)
142 | if viz_mode == "draw_box":
143 | gpt_question = question
144 | viz_image = draw_box_on_image(image_resized, box_info, args.box_color)
145 | elif viz_mode == "crop_box":
146 | gpt_question = question
147 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding)
148 | elif viz_mode == "crop_box_and_resize":
149 | gpt_question = question
150 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding)
151 | viz_image = resize_image(viz_image)
152 | else:
153 | raise ValueError(f"Unknown viz_mode: {viz_mode}")
154 | viz_image.save(viz_path, quality=90)
155 | return gpt_question, viz_rel_path
156 | except Exception as exc:
157 | print(f"Error creating visualization for {viz_path}: {exc}")
158 | return question, ""
159 |
160 |
161 | def _normalize_answer(answer: str) -> str:
162 | answer_lower = answer.lower().strip().rstrip('.')
163 | if answer_lower in ["yes", "y", "true"]:
164 | return "Yes"
165 | if answer_lower in ["no", "n", "false"]:
166 | return "No"
167 | return answer.strip()
168 |
169 |
170 | def _parse_json_response(response: str) -> Optional[Tuple[str, str]]:
171 | try:
172 | response_clean = response.strip()
173 | json_candidates: List[str] = []
174 | json_start = response_clean.find('{')
175 | json_end = response_clean.rfind('}') + 1
176 | if json_start != -1 and json_end > json_start:
177 | json_candidates.append(response_clean[json_start:json_end])
178 | json_candidates.extend(JSON_PATTERN.findall(response_clean))
179 | for json_str in json_candidates:
180 | try:
181 | data = json.loads(json_str)
182 | answer = data.get("answer", "").strip()
183 | explanation = data.get("explanation", "").strip()
184 | if answer:
185 | return _normalize_answer(answer), explanation
186 | except json.JSONDecodeError:
187 | continue
188 | except Exception:
189 | pass
190 | return None
191 |
192 |
193 | def extract_yes_no_answer(response: str) -> str:
194 | response_lower = response.lower().strip()
195 | if response_lower.startswith("yes"):
196 | return "Yes"
197 | if response_lower.startswith("no"):
198 | return "No"
199 | if re.search(r"\byes\b", response_lower):
200 | return "Yes"
201 | if re.search(r"\bno\b", response_lower):
202 | return "No"
203 | return response[:10] if response else "Unknown"
204 |
205 |
206 | def parse_structured_response(response: str) -> Tuple[str, str]:
207 | parsed = _parse_json_response(response)
208 | if parsed:
209 | return parsed
210 | return extract_yes_no_answer(response), ""
211 |
212 |
213 | def encode_image(image: Image.Image) -> str:
214 | buffer = io.BytesIO()
215 | if image.mode != "RGB":
216 | image = image.convert("RGB")
217 | image.save(buffer, format="JPEG", quality=90)
218 | return base64.b64encode(buffer.getvalue()).decode("utf-8")
219 |
220 |
221 | def load_image_base64(image_path: str, cache: Dict[str, str]) -> str:
222 | if image_path in cache:
223 | return cache[image_path]
224 | try:
225 | image = Image.open(image_path).convert("RGB")
226 | image = resize_image(image)
227 | except Exception as exc:
228 | print(f"Error loading image {image_path}: {exc}")
229 | image = Image.new("RGB", (512, 512), "white")
230 | encoded = encode_image(image)
231 | cache[image_path] = encoded
232 | return encoded
233 |
234 |
235 | def create_structured_prompt(question: str) -> str:
236 | """Create a prompt with structured JSON output instructions"""
237 | json_instruction = (
238 | "\n\nPlease provide a structured answer in the following JSON format:\n"
239 | '{"answer": "Yes" or "No", "explanation": "Brief explanation of your reasoning"}\n\n'
240 | "Output ONLY valid JSON. No extra text."
241 | )
242 | return question + json_instruction
243 |
244 |
245 | def call_gpt_with_retries(client: OpenAI, prompt: str, image_data_url: str, args) -> str:
246 | """Call GPT API with retry logic using OpenAI SDK (responses.create endpoint)"""
247 | for attempt in range(args.max_attempts):
248 | try:
249 | # Build input payload with text and image
250 | input_payload = [{
251 | "role": "user",
252 | "content": [
253 | {"type": "input_text", "text": prompt},
254 | {"type": "input_image", "image_url": image_data_url},
255 | ],
256 | }]
257 |
258 | # Call OpenAI Responses API (for GPT-5/o1 models)
259 | response = client.responses.create(
260 | model=args.gpt_model,
261 | input=input_payload,
262 | )
263 |
264 | return response.output_text.strip()
265 |
266 |         except Exception as exc:
267 |             if attempt + 1 == args.max_attempts:
268 |                 print(f"GPT call failed ({attempt + 1}/{args.max_attempts}): {exc}; giving up")
269 |                 break
270 |             wait_time = (args.retry_backoff ** attempt) + random.uniform(0, 1)
271 |             print(f"GPT call failed ({attempt + 1}/{args.max_attempts}): {exc}; retry in {wait_time:.1f}s")
272 |             time.sleep(wait_time)
270 |
271 | return "Error: all attempts failed"
272 |
273 |
274 | def get_qa_entry(item: Dict[str, Any], qa_field: str, qa_type: str, qa_index: int) -> Dict[str, Any]:
275 | container = item.get(qa_field)
276 | if container is None and qa_field == "qa_pairs":
277 | container = item.get("annotated_qa_pairs")
278 | if isinstance(container, dict):
279 | return container[qa_type][qa_index]
280 | if isinstance(container, list):
281 | return container[qa_index]
282 | raise KeyError(f"QA entry not found for field={qa_field}, type={qa_type}, index={qa_index}")
283 |
284 |
285 | def extract_qa_tasks_standard(items: List[Dict[str, Any]], image_base_dir: str) -> List[QATask]:
286 | qa_tasks: List[QATask] = []
287 | for item_idx, item in enumerate(items):
288 | image_path = resolve_image_path(item, image_base_dir)
289 | if not image_path:
290 | continue
291 | qa_container = item.get("qa_pairs")
292 | if qa_container is None:
293 | qa_container = item.get("annotated_qa_pairs", {})
294 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "qa"):
295 | question = qa.get("question")
296 | answer = qa.get("answer")
297 | if question and answer:
298 | qa_tasks.append(QATask(item_idx, "qa_pairs", qa_type, qa_idx, image_path, question, answer, "original"))
299 | return qa_tasks
300 |
301 |
302 | def extract_qa_tasks_annotated(items: List[Dict[str, Any]], image_base_dir: str, viz_mode: str, args) -> List[QATask]:
303 | qa_tasks: List[QATask] = []
304 | viz_dir = os.path.join(image_base_dir, f"visualization_annotated_qa_{viz_mode}")
305 | for item_idx, item in enumerate(items):
306 | image_path = resolve_image_path(item, image_base_dir)
307 | if not image_path:
308 | continue
309 | qa_container = item.get("annotated_qa_pairs", {})
310 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "annotated_qa"):
311 | question = qa.get("question")
312 | answer = qa.get("answer")
313 | if not (question and answer):
314 | continue
315 | gpt_question, viz_rel_path = create_visualization_and_question(image_path, qa, item_idx, qa_type, qa_idx, viz_mode, viz_dir, args)
316 | viz_path = os.path.join(image_base_dir, viz_rel_path) if viz_rel_path else image_path
317 | source = "visualization" if viz_rel_path else "original"
318 | qa_tasks.append(QATask(item_idx, "annotated_qa_pairs", qa_type, qa_idx, viz_path, gpt_question, answer, source))
319 | return qa_tasks
320 |
321 |
322 | def process_single_task(task: QATask, items: List[Dict[str, Any]], base64_cache: Dict[str, str],
323 | image_base_dir: str, client: OpenAI, args) -> Dict[str, Any]:
324 | """Worker function to process a single QA task in thread pool"""
325 | try:
326 | # Load and encode image
327 | base64_str = load_image_base64(task.image_path, base64_cache)
328 | data_url = f"data:image/jpeg;base64,{base64_str}"
329 |
330 | # Create prompt and call GPT
331 | prompt = create_structured_prompt(task.question)
332 | model_response = call_gpt_with_retries(client, prompt, data_url, args)
333 |
334 | # Parse response
335 | model_answer, model_explanation = parse_structured_response(model_response)
336 |
337 | # Check correctness
338 | gt_clean = task.answer.lower().strip().rstrip('.')
339 | model_clean = model_answer.lower().strip().rstrip('.')
340 | is_correct = gt_clean == model_clean
341 |
342 | # Return result
343 | result = {
344 | "item_index": task.item_index,
345 | "qa_field": task.qa_field,
346 | "qa_type": task.qa_type,
347 | "qa_index": task.qa_index,
348 | "model_answer": model_answer,
349 | "model_response": model_response,
350 | "model_explanation": model_explanation,
351 | "is_correct": is_correct,
352 | }
353 |
354 | # Add visualization info if applicable
355 | if task.qa_field == "annotated_qa_pairs" and task.source == "visualization":
356 | viz_rel_path = os.path.relpath(task.image_path, image_base_dir)
357 | result["visualization_path"] = viz_rel_path
358 | result["viz_mode"] = args.viz_mode
359 |
360 | return result
361 |
362 | except Exception as e:
363 | return {
364 | "item_index": task.item_index,
365 | "qa_field": task.qa_field,
366 | "qa_type": task.qa_type,
367 | "qa_index": task.qa_index,
368 | "error": str(e),
369 | "model_answer": "Error",
370 | "is_correct": False,
371 | }
372 |
373 |
374 | def evaluate_with_gpt(items: List[Dict[str, Any]], image_base_dir: str, args) -> List[Dict[str, Any]]:
375 | """Main evaluation function using multi-threading"""
376 | # Limit number of items if specified
377 | if args.max_num is not None:
378 | items = items[:args.max_num]
379 |
380 | # Extract QA tasks based on field type
381 | if args.qa_field == "qa_pairs":
382 | qa_tasks = extract_qa_tasks_standard(items, image_base_dir)
383 | elif args.qa_field == "annotated_qa_pairs":
384 | qa_tasks = extract_qa_tasks_annotated(items, image_base_dir, args.viz_mode, args)
385 | else:
386 | raise ValueError(f"Unsupported qa_field: {args.qa_field}")
387 |
388 | if not qa_tasks:
389 | print("No QA tasks found!")
390 | return items
391 |
392 | print(f"Found {len(qa_tasks)} QA tasks")
393 |
394 | # Setup API key
395 | api_key = args.api_key or os.getenv("OPENAI_API_KEY")
396 | if not api_key:
397 | raise RuntimeError("OpenAI API key is required. Set --api_key or OPENAI_API_KEY.")
398 |
399 | # Initialize OpenAI client
400 | client = OpenAI(api_key=api_key, base_url=args.api_base_url)
401 |
402 | # Shared base64 cache (thread-safe for reading)
403 | base64_cache: Dict[str, str] = {}
404 |
405 | # Execute tasks in parallel using thread pool
406 | print(f"Processing with {args.num_workers} workers...")
407 | results = []
408 |
409 | with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
410 | # Submit all tasks
411 | future_to_task = {
412 | executor.submit(process_single_task, task, items, base64_cache, image_base_dir, client, args): task
413 | for task in qa_tasks
414 | }
415 |
416 | # Collect results with progress bar
417 | for future in tqdm(as_completed(future_to_task), total=len(qa_tasks), desc="Calling GPT model"):
418 | try:
419 | result = future.result()
420 | results.append(result)
421 | except Exception as e:
422 | task = future_to_task[future]
423 | print(f"\nError processing task: {e}")
424 | results.append({
425 | "item_index": task.item_index,
426 | "qa_field": task.qa_field,
427 | "qa_type": task.qa_type,
428 | "qa_index": task.qa_index,
429 | "error": str(e),
430 | "model_answer": "Error",
431 | "is_correct": False,
432 | })
433 |
434 | # Merge results back into items
435 | for result in results:
436 | try:
437 | qa_entry = get_qa_entry(items[result["item_index"]], result["qa_field"],
438 | result["qa_type"], result["qa_index"])
439 | qa_entry["model_answer"] = result.get("model_answer", "Error")
440 | qa_entry["model_response"] = result.get("model_response", "")
441 | qa_entry["model_explanation"] = result.get("model_explanation", "")
442 | qa_entry["is_correct"] = result.get("is_correct", False)
443 |
444 | if "visualization_path" in result:
445 | qa_entry["visualization_path"] = result["visualization_path"]
446 | if "viz_mode" in result:
447 | qa_entry["viz_mode"] = result["viz_mode"]
448 | if "error" in result:
449 | qa_entry["error"] = result["error"]
450 | except Exception as e:
451 | print(f"Error merging result: {e}")
452 |
453 | return items
454 |
455 |
456 | def calculate_accuracy_by_dimension(items: List[Dict[str, Any]]) -> Dict[str, Any]:
457 | total_questions = 0
458 | sample_acc_sum = 0.0
459 | sample_count = 0
460 | category_stats: Dict[str, Dict[str, float]] = {}
461 | law_stats: Dict[str, Dict[str, float]] = {}
462 | operation_stats: Dict[str, Dict[str, float]] = {}
463 |
464 | def collect_qas(item: Dict[str, Any]) -> List[Dict[str, Any]]:
465 | qa_sources: List[Dict[str, Any]] = []
466 | for field, default_type in (("qa_pairs", "qa"), ("annotated_qa_pairs", "annotated_qa")):
467 | container = item.get(field)
468 | if not container:
469 | continue
470 | for _, _, qa in iter_qa_entries(container, default_type):
471 | qa_sources.append(qa)
472 | return qa_sources
473 |
474 | def update_stat(stats: Dict[str, Dict[str, float]], key: str, sample_acc: float, qa_total: int) -> None:
475 | if key not in stats:
476 | stats[key] = {"sum_acc": 0.0, "sample_count": 0, "qa_total": 0}
477 | stats[key]["sum_acc"] += sample_acc
478 | stats[key]["sample_count"] += 1
479 | stats[key]["qa_total"] += qa_total
480 |
481 | for item in items:
482 | category = item.get("physics_category", "unknown")
483 | law = item.get("physics_law", "unknown")
484 | operation = item.get("edit_operation", "unknown")
485 | qa_sources = collect_qas(item)
486 | sample_total = 0
487 | sample_correct = 0
488 | for qa in qa_sources:
489 | if "is_correct" in qa:
490 | sample_total += 1
491 | if qa["is_correct"]:
492 | sample_correct += 1
493 | if sample_total == 0:
494 | continue
495 | sample_acc = sample_correct / sample_total
496 | sample_count += 1
497 | sample_acc_sum += sample_acc
498 | total_questions += sample_total
499 | update_stat(category_stats, category, sample_acc, sample_total)
500 | update_stat(law_stats, law, sample_acc, sample_total)
501 | update_stat(operation_stats, operation, sample_acc, sample_total)
502 |
503 | def calc_accuracy(stats: Dict[str, Dict[str, float]]) -> Dict[str, Any]:
504 | result: Dict[str, Any] = {}
505 | for key, value in stats.items():
506 | acc = 100.0 * value["sum_acc"] / value["sample_count"] if value["sample_count"] > 0 else 0.0
507 | result[key] = {
508 | "accuracy": acc,
509 | "sample_count": value["sample_count"],
510 | "qa_total": value["qa_total"],
511 | }
512 | return result
513 |
514 | overall_accuracy = (100.0 * sample_acc_sum / sample_count) if sample_count > 0 else 0.0
515 | return {
516 | "overall_accuracy": overall_accuracy,
517 | "sample_count": sample_count,
518 | "qa_total": total_questions,
519 | "by_category": calc_accuracy(category_stats),
520 | "by_law": calc_accuracy(law_stats),
521 | "by_operation": calc_accuracy(operation_stats),
522 | }
523 |
524 |
525 | def main() -> None:
526 | """Main entry point for GPT evaluation"""
527 | parser = argparse.ArgumentParser(description="GPT-based PhysEdit evaluation with multi-threading")
528 | parser.add_argument("--input_json_path", type=str, required=True, help="Path to meta_info.json")
529 | parser.add_argument("--image_base_dir", type=str, default=None, help="Image base directory; defaults to JSON directory")
530 | parser.add_argument("--qa_field", type=str, default="annotated_qa_pairs",
531 | choices=["qa_pairs", "annotated_qa_pairs"], help="Select qa field")
532 | parser.add_argument("--viz_mode", type=str, default="crop_box_and_resize",
533 | choices=["draw_box", "crop_box", "crop_box_and_resize"], help="Visualization mode")
534 | parser.add_argument("--max_num", type=int, default=None, help="Maximum samples to process")
535 | parser.add_argument("--viz_padding", type=int, default=20, help="Padding pixels for crop mode")
536 | parser.add_argument("--box_color", type=str, default="red", help="Bounding box color")
537 | parser.add_argument("--log_question_changes", action="store_true", help="Log question mutations")
538 | parser.add_argument("--img_size", type=int, default=1024, help="Used for output naming consistency")
539 | parser.add_argument("--api_key", type=str, default="", help="OpenAI API key; overrides OPENAI_API_KEY env")
540 | parser.add_argument("--api_base_url", type=str, default="https://api.openai.com/v1",
541 | help="OpenAI API base URL")
542 | parser.add_argument("--gpt_model", type=str, default="gpt-5", help="OpenAI multimodal model name")
543 | parser.add_argument("--max_attempts", type=int, default=5, help="Max call retries")
544 | parser.add_argument("--retry_backoff", type=float, default=2.0, help="Exponential backoff base")
545 | parser.add_argument("--num_workers", type=int, default=50, help="Number of parallel worker threads")
546 | args = parser.parse_args()
547 |
548 | # Set default image_base_dir
549 | if args.image_base_dir is None:
550 | args.image_base_dir = os.path.dirname(args.input_json_path)
551 |
552 | # Load data
553 | print(f"Loading data: {args.input_json_path}")
554 | with open(args.input_json_path, "r", encoding="utf-8") as f:
555 | data = json.load(f)
556 |
557 | print(f"QA field: {args.qa_field}")
558 | if args.qa_field == "annotated_qa_pairs":
559 | print(f"Visualization mode: {args.viz_mode}")
560 | print(f"Number of workers: {args.num_workers}")
561 |
562 | # Run evaluation
563 | print("Running GPT evaluation...")
564 | results = evaluate_with_gpt(data, args.image_base_dir, args)
565 |
566 | out_dir = os.path.dirname(args.input_json_path)
567 | base_name = os.path.splitext(os.path.basename(args.input_json_path))[0]
568 |
569 | suffix = f"_gpt_output_{args.img_size}"
570 | if args.qa_field == "annotated_qa_pairs":
571 | suffix += f"_{args.viz_mode}"
572 | out_path = os.path.join(out_dir, base_name + suffix + ".json")
573 | with open(out_path, "w", encoding="utf-8") as f:
574 | json.dump(results, f, ensure_ascii=False, indent=2)
575 |
576 | print("Computing statistics...")
577 | stats = calculate_accuracy_by_dimension(results)
578 | analysis_suffix = suffix.replace("_output_", "_analysis_")
579 | analysis_path = os.path.join(out_dir, base_name + analysis_suffix + ".json")
580 | with open(analysis_path, "w", encoding="utf-8") as f:
581 | json.dump(stats, f, ensure_ascii=False, indent=2)
582 |
583 | print("\n=== GPT evaluation finished ===")
584 | print(f"Overall accuracy: {stats['overall_accuracy']:.2f}% (samples: {stats['sample_count']}, qa_total: {stats['qa_total']})")
585 | print("\nBy category:")
586 | for category, stat in stats["by_category"].items():
587 | print(f" {category}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
588 | print("\nBy law:")
589 | for law, stat in stats["by_law"].items():
590 | print(f" {law}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
591 | print("\nBy operation:")
592 | for operation, stat in stats["by_operation"].items():
593 | print(f" {operation}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
594 |
595 | print("\nOutputs saved:")
596 | print(f" Detailed results: {out_path}")
597 | print(f" Analysis: {analysis_path}")
598 |
599 |
600 | if __name__ == "__main__":
601 | random.seed(42)
602 | main()
603 |
--------------------------------------------------------------------------------
/PicaEval_qwen.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # ============================================================================
3 | # :: PhysEdit Evaluation Script (vLLM backend)
4 | # :: Supports qa_pairs / annotated_qa_pairs flows
5 | # :: Includes visualization helpers (draw_box / crop_box modes)
6 | # ============================================================================
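#
# Example invocation (illustrative; adjust paths, checkpoint, and GPU count):
#   python PicaEval_qwen.py \
#     --input_json_path PICABench_data/meta_info.json \
#     --model_path Qwen/Qwen2.5-VL-72B-Instruct \
#     --qa_field annotated_qa_pairs --viz_mode crop_box_and_resize \
#     --tensor_parallel_size 4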
7 |
8 | import multiprocessing as mp
9 | if mp.get_start_method(allow_none=True) != "spawn":
10 | mp.set_start_method("spawn", force=True)
11 |
12 | import os
14 | import json
15 | import argparse
16 | import random
17 | import re
18 | from dataclasses import dataclass
19 | from pathlib import Path
20 | from typing import List, Dict, Any, Tuple, Optional, Iterable
21 | from PIL import Image, ImageDraw
23 | from tqdm import tqdm
24 |
25 | from vllm import LLM, SamplingParams
26 | from transformers import AutoProcessor
27 |
28 | IMAGE_PATH_KEYS: Tuple[str, ...] = ("output_path", "output_image_path", "output_img_path")
29 | JSON_PATTERN = re.compile(r'\{[^{}]*"answer"[^{}]*"explanation"[^{}]*\}', re.IGNORECASE | re.DOTALL)
30 |
31 |
32 | @dataclass
33 | class QATask:
34 | item_index: int
35 | qa_field: str
36 | qa_type: str
37 | qa_index: int
38 | image_path: str
39 | question: str
40 | answer: str
41 | source: str
42 |
43 | def resolve_image_path(item: Dict[str, Any], base_dir: str) -> Optional[str]:
44 | """Resolve image path by inspecting common metadata keys."""
45 | base_path = Path(base_dir).expanduser().resolve()
46 | base_name = base_path.name
47 | for key in IMAGE_PATH_KEYS:
48 | raw_path = item.get(key)
49 | if not raw_path:
50 | continue
51 |
52 | candidate_path = Path(raw_path)
53 | if candidate_path.is_absolute():
54 | return str(candidate_path)
55 |
56 | normalized_rel = Path(*candidate_path.parts)
57 | candidate = base_path / normalized_rel
58 | if candidate.exists():
59 | return str(candidate)
60 |
61 | rel_parts = normalized_rel.parts
62 | if rel_parts and rel_parts[0] == base_name:
63 | alt = base_path / Path(*rel_parts[1:])
64 |                 # The relative path repeats the base directory name; strip
65 |                 # the duplicated prefix and return the remapped location.
66 |                 return str(alt)
67 |
68 | return str(candidate)
69 | return None
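# Illustrative behavior (hypothetical paths):
#   "outputs/00001.jpg"     + base_dir "/data/run" -> "/data/run/outputs/00001.jpg"
#   "run/outputs/00001.jpg" + base_dir "/data/run" -> "/data/run/outputs/00001.jpg"
#       (the duplicated leading "run" is stripped when the literal join does not exist)
#   absolute paths are returned unchanged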
70 |
71 | def iter_qa_entries(container: Any, default_type: str) -> Iterable[Tuple[str, int, Dict[str, Any]]]:
72 | """Iterate QA entries uniformly for dict or list containers."""
73 | if isinstance(container, dict):
74 | for qa_type, qa_list in container.items():
75 | for idx, qa in enumerate(qa_list):
76 | yield qa_type, idx, qa
77 | elif isinstance(container, list):
78 | for idx, qa in enumerate(container):
79 | yield default_type, idx, qa
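# Illustrative container shapes (keys and entries hypothetical):
#   {"state": [qa0, qa1]} -> ("state", 0, qa0), ("state", 1, qa1)
#   [qa0, qa1]            -> (default_type, 0, qa0), (default_type, 1, qa1)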
80 |
81 | def draw_box_on_image(image: Image.Image, box_info: Dict, box_color="red") -> Image.Image:
82 | """Draw a bounding box on the resized image."""
83 | img_copy = image.copy()
84 | draw = ImageDraw.Draw(img_copy)
85 |
86 | # :: Derive bounding box coordinates
87 | x = box_info.get("x", 0)
88 | y = box_info.get("y", 0)
89 | width = box_info.get("width", 0)
90 | height = box_info.get("height", 0)
91 |
92 | # :: Render rectangle outline
93 | bbox = [x, y, x + width, y + height]
94 | draw.rectangle(bbox, outline=box_color, width=3)
95 |
96 | return img_copy
97 |
98 | def resize_image(image: Image.Image) -> Image.Image:
99 | """Resize image proportionally so the long edge becomes 1024 pixels."""
100 | width, height = image.size
101 | long_edge = max(width, height)
102 |
103 | if long_edge == 1024:
104 | return image
105 |
106 | scale = 1024.0 / long_edge
107 | new_width = int(width * scale)
108 | new_height = int(height * scale)
109 |
110 | return image.resize((new_width, new_height), Image.LANCZOS)
111 |
112 | def crop_image_with_box(image: Image.Image, box_info: Dict, padding=20) -> Image.Image:
113 | """Crop image around the bounding box with optional padding."""
114 | x = box_info.get("x", 0)
115 | y = box_info.get("y", 0)
116 | width = box_info.get("width", 0)
117 | height = box_info.get("height", 0)
118 |
119 | # :: Apply padding while respecting image bounds
120 | x1 = max(0, x - padding)
121 | y1 = max(0, y - padding)
122 | x2 = min(image.size[0], x + width + padding)
123 | y2 = min(image.size[1], y + height + padding)
124 | cropped_image = image.crop((x1, y1, x2, y2))
125 | return cropped_image
126 |
127 | def generate_viz_filename(item_index: int, qa_type: str, question_index: int, viz_mode: str) -> str:
128 | """Create a deterministic visualization filename."""
129 | qa_type_short = qa_type.replace(" ", "_").replace("QA", "qa")
130 | return f"{item_index:04d}_{qa_type_short}_{question_index:03d}_{viz_mode}.jpg"
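# e.g. generate_viz_filename(3, "state QA", 1, "crop_box") -> "0003_state_qa_001_crop_box.jpg"
#      (qa_type here is hypothetical; spaces become underscores and "QA" is lower-cased)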
131 |
132 | def create_visualization_and_question(output_img_path: str, qa_info: Dict, item_index: int,
133 | qa_type: str, qa_index: int, viz_mode: str,
134 | viz_dir: str, args) -> Tuple[str, str]:
135 |     """Create the visualization asset and return (question text, relative viz path)."""
136 | question = qa_info.get("question", "")
137 | box_info = qa_info.get("box", {})
138 |
139 | # :: Construct file paths for visualization assets
140 | viz_filename = generate_viz_filename(item_index, qa_type, qa_index, viz_mode)
141 | viz_path = os.path.join(viz_dir, viz_filename)
142 | # :: Relative path must align with viz_dir for downstream loading
143 | viz_rel_path = os.path.join(f"visualization_annotated_qa_{viz_mode}", viz_filename)
144 |
145 | # :: Ensure target directory exists
146 | os.makedirs(viz_dir, exist_ok=True)
147 |
148 | try:
149 | # :: Load and resize image to match annotation scale (long edge = 1024)
150 | image = Image.open(output_img_path).convert("RGB")
151 | image_resized = resize_image(image)
152 |
153 | if viz_mode == "draw_box":
154 | # :: Draw-box mode keeps question intact and overlays box
155 | vllm_question = question
156 | viz_image = draw_box_on_image(image_resized, box_info, args.box_color)
157 | elif viz_mode == "crop_box":
158 | # :: Crop mode preserves question and crops the image
159 | vllm_question = question
160 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding)
161 | elif viz_mode == "crop_box_and_resize":
162 | # :: Crop+resize mode crops the region then upscales
163 | vllm_question = question
164 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding)
165 | viz_image = resize_image(viz_image)
166 | else:
167 | raise ValueError(f"Unknown viz_mode: {viz_mode}")
168 |
169 | # :: Persist visualization image
170 | viz_image.save(viz_path, quality=90)
171 |
172 | # :: Optionally log prompt mutations
173 | if args.log_question_changes and question != vllm_question:
174 | print(f"Question modified ({viz_mode}): {question} -> {vllm_question}")
175 |
176 | return vllm_question, viz_rel_path
177 |
178 | except Exception as e:
179 | print(f"Error creating visualization for {viz_path}: {e}")
180 | return question, ""
181 |
182 | def parse_structured_response(response: str) -> Tuple[str, str]:
183 | """Parse structured JSON answer from model output."""
184 | return _parse_json_response(response) or _fallback_parse_response(response)
185 |
186 | def _parse_json_response(response: str) -> Optional[Tuple[str, str]]:
187 | """Attempt to decode JSON-formatted response."""
188 | try:
189 | response_clean = response.strip()
190 |
191 | # :: Collect possible JSON snippets
192 | json_candidates = []
193 |
194 | # :: Strategy 1: capture outermost braces
195 | json_start = response_clean.find('{')
196 | json_end = response_clean.rfind('}') + 1
197 | if json_start != -1 and json_end > json_start:
198 | json_candidates.append(response_clean[json_start:json_end])
199 |
200 | # :: Strategy 2: apply regex extraction
201 | json_candidates.extend(JSON_PATTERN.findall(response_clean))
202 |
203 | # :: Try parsing each candidate
204 | for json_str in json_candidates:
205 | try:
206 | data = json.loads(json_str)
207 | answer = data.get("answer", "").strip()
208 | explanation = data.get("explanation", "").strip()
209 |
210 | if answer:
211 | return _normalize_answer(answer), explanation
212 | except json.JSONDecodeError:
213 | continue
214 |
215 | except Exception:
216 | pass
217 |
218 | return None
219 |
220 | def _normalize_answer(answer: str) -> str:
221 | """Normalize canonical yes/no answer casing."""
222 | answer_lower = answer.lower().strip().rstrip('.')
223 | if answer_lower in ["yes", "y", "true"]:
224 | return "Yes"
225 | elif answer_lower in ["no", "n", "false"]:
226 | return "No"
227 | return answer.strip()
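# e.g. _normalize_answer("YES.") -> "Yes", _normalize_answer("false") -> "No",
#      _normalize_answer("unclear") -> "unclear"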
228 |
229 | def _fallback_parse_response(response: str) -> Tuple[str, str]:
230 | """Fallback parser when structured JSON is unavailable."""
231 | return extract_yes_no_answer(response), ""
232 |
233 | def extract_yes_no_answer(response: str) -> str:
234 | """Extract yes/no answer heuristically."""
235 | response_lower = response.lower().strip()
236 |
237 | # :: Prefer prefix match
238 | if response_lower.startswith("yes"):
239 | return "Yes"
240 | elif response_lower.startswith("no"):
241 | return "No"
242 |
243 | # :: Search for whole-word matches
244 | if re.search(r'\byes\b', response_lower):
245 | return "Yes"
246 | elif re.search(r'\bno\b', response_lower):
247 | return "No"
248 |
249 | # :: Leave snippet when answer is unclear
250 | return response[:10] if response else "Unknown"
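# Example (illustrative model output):
#   parse_structured_response('{"answer": "yes", "explanation": "Shadows follow the new light source."}')
#   -> ("Yes", "Shadows follow the new light source.")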
251 |
252 | def init_llm(model_path: str, tp: int, dtype: str, gpu_util: float, max_len: int):
253 | """Initialize vLLM engine."""
254 | llm = LLM(
255 | model=model_path,
256 | tensor_parallel_size=tp,
257 | dtype=dtype,
258 | gpu_memory_utilization=gpu_util,
259 | trust_remote_code=True,
260 | max_model_len=max_len,
261 | )
262 | return llm
263 |
264 | def create_structured_message(msgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
265 | """Append structured JSON instructions to the final user turn."""
266 | structured_msg = msgs.copy()
267 | if not structured_msg or structured_msg[-1]["role"] != "user":
268 | return structured_msg
269 |
270 | # :: Structured answer instruction
271 | json_instruction = "\n\nPlease provide a structured answer in the following JSON format:\n{\"answer\": \"Yes\" or \"No\", \"explanation\": \"Brief explanation of your reasoning\"}\n\nOutput ONLY valid JSON. No extra text."
272 |
273 | original_content = structured_msg[-1]["content"]
274 | if isinstance(original_content, list):
275 | # :: Multi-modal payload
276 | structured_content = original_content + [{"type": "text", "text": json_instruction}]
277 | else:
278 | # :: Text-only payload
279 | structured_content = original_content + json_instruction
280 |
281 | structured_msg[-1]["content"] = structured_content
282 | return structured_msg
283 |
284 | def prepare_vllm_batch(qa_tasks: List[QATask]) -> Tuple[List[List[Dict[str, Any]]], List[List[Image.Image]]]:
285 | """Prepare messages and images for vLLM batching."""
286 | msgs_batch: List[List[Dict[str, Any]]] = []
287 | imgs_batch: List[List[Image.Image]] = []
288 | pil_cache: Dict[str, Image.Image] = {}
289 |
290 | for task in tqdm(qa_tasks, desc="Preparing images"):
291 | if task.image_path not in pil_cache:
292 | try:
293 | img = Image.open(task.image_path).convert("RGB")
294 | img = resize_image(img)
295 | pil_cache[task.image_path] = img
296 | except Exception as e:
297 | print(f"Error loading image {task.image_path}: {e}")
298 | pil_cache[task.image_path] = Image.new("RGB", (512, 512), "white")
299 | img = pil_cache[task.image_path]
300 | msgs_batch.append([{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": task.question}]}])
301 | imgs_batch.append([img])
302 |
303 | return msgs_batch, imgs_batch
304 |
305 | def build_vllm_inputs(processor: AutoProcessor,
306 | batch_msgs: List[List[Dict[str, Any]]],
307 | batch_images: List[List[Image.Image]]) -> List[Dict[str, Any]]:
308 | """Build vLLM input payloads."""
309 | # :: Attach structured-answer instructions
310 | structured_msgs = [create_structured_message(msgs) for msgs in batch_msgs]
311 |
312 | # :: Render prompts via chat template
313 | prompts = [
314 | processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
315 | for msgs in structured_msgs
316 | ]
317 |
318 | # :: Assemble vLLM-ready dicts
319 | return [
320 | {"prompt": text, "multi_modal_data": {"image": imgs}}
321 | for text, imgs in zip(prompts, batch_images)
322 | ]
323 |
324 | def vllm_generate_batch(llm: LLM, processor: AutoProcessor,
325 | batch_msgs: List[List[Dict[str, Any]]],
326 | batch_images: List[List[Image.Image]],
327 | max_new_tokens: int) -> List[str]:
328 | """Run batched inference through vLLM."""
329 | # :: Prepare inputs
330 | inputs = build_vllm_inputs(processor, batch_msgs, batch_images)
331 |
332 | # :: Configure sampling strategy
333 | sp = SamplingParams(
334 | temperature=0.0,
335 | top_p=1.0,
336 | max_tokens=max_new_tokens,
337 | stop=["\uFFFD\uFFFD", "\n\uFFFD\uFFFD", "\U0001F4D0\n"]
338 | )
339 |
340 | # :: Execute generation
341 | outputs = llm.generate(inputs, sp)
342 | return [o.outputs[0].text if o.outputs else "" for o in outputs]
343 |
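# Note: temperature=0.0 gives greedy (deterministic) decoding in vLLM, keeping
# the yes/no verdicts reproducible across runs.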
344 | def get_qa_entry(item: Dict[str, Any], qa_field: str, qa_type: str, qa_index: int) -> Dict[str, Any]:
345 | """Locate QA entry for annotation updates."""
346 | container = item.get(qa_field)
347 | if container is None and qa_field == "qa_pairs":
348 | container = item.get("annotated_qa_pairs")
349 | if isinstance(container, dict):
350 | return container[qa_type][qa_index]
351 | if isinstance(container, list):
352 | return container[qa_index]
353 | raise KeyError(f"QA entry not found for field={qa_field}, type={qa_type}, index={qa_index}")
354 |
355 |
356 | def extract_qa_tasks_standard(items: List[Dict[str, Any]], image_base_dir: str) -> List[QATask]:
357 | """Collect tasks from qa_pairs field."""
358 | qa_tasks: List[QATask] = []
359 | for item_idx, item in enumerate(items):
360 | image_path = resolve_image_path(item, image_base_dir)
361 | if not image_path:
362 | continue
363 | qa_container = item.get("qa_pairs")
364 | if qa_container is None:
365 | qa_container = item.get("annotated_qa_pairs", {})
366 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "qa"):
367 | question = qa.get("question")
368 | answer = qa.get("answer")
369 | if question and answer:
370 | qa_tasks.append(QATask(
371 | item_index=item_idx,
372 | qa_field="qa_pairs",
373 | qa_type=qa_type,
374 | qa_index=qa_idx,
375 | image_path=image_path,
376 | question=question,
377 | answer=answer,
378 | source="original"
379 | ))
380 | return qa_tasks
381 |
382 | def extract_qa_tasks_annotated(items: List[Dict[str, Any]], image_base_dir: str, viz_mode: str, args) -> List[QATask]:
383 | """Collect annotated QA tasks and generate visualizations."""
384 | qa_tasks: List[QATask] = []
385 | viz_dir = os.path.join(image_base_dir, f"visualization_annotated_qa_{viz_mode}")
386 | for item_idx, item in enumerate(items):
387 | image_path = resolve_image_path(item, image_base_dir)
388 | if not image_path:
389 | continue
390 | qa_container = item.get("annotated_qa_pairs", {})
391 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "annotated_qa"):
392 | question = qa.get("question")
393 | answer = qa.get("answer")
394 | if not (question and answer):
395 | continue
396 | vllm_question, viz_rel_path = create_visualization_and_question(
397 | image_path, qa, item_idx, qa_type, qa_idx, viz_mode, viz_dir, args
398 | )
399 | viz_path = os.path.join(image_base_dir, viz_rel_path) if viz_rel_path else image_path
400 | source = "visualization" if viz_rel_path else "original"
401 | qa_tasks.append(QATask(
402 | item_index=item_idx,
403 | qa_field="annotated_qa_pairs",
404 | qa_type=qa_type,
405 | qa_index=qa_idx,
406 | image_path=viz_path,
407 | question=vllm_question,
408 | answer=answer,
409 | source=source
410 | ))
411 | return qa_tasks
412 |
413 | def evaluate_physedit_with_vllm(items: List[Dict[str, Any]], image_base_dir: str,
414 | processor: AutoProcessor, llm: LLM, args) -> List[Dict[str, Any]]:
415 | """Main evaluation routine."""
416 | if args.max_num is not None:
417 | items = items[:args.max_num]
418 |
419 | if args.qa_field == "qa_pairs":
420 | qa_tasks = extract_qa_tasks_standard(items, image_base_dir)
421 | elif args.qa_field == "annotated_qa_pairs":
422 | qa_tasks = extract_qa_tasks_annotated(items, image_base_dir, args.viz_mode, args)
423 | else:
424 | raise ValueError(f"Unsupported qa_field: {args.qa_field}")
425 |
426 | if not qa_tasks:
427 | print("No QA tasks found!")
428 | return items
429 |
430 | print(f"Found {len(qa_tasks)} QA tasks")
431 | msgs_batch, imgs_batch = prepare_vllm_batch(qa_tasks)
432 | print("Running VLM inference...")
433 | answers = vllm_generate_batch(llm, processor, msgs_batch, imgs_batch, args.max_new_tokens)
434 |
435 | for task, model_response in zip(qa_tasks, answers):
436 | model_answer, model_explanation = parse_structured_response(model_response)
437 | gt_clean = task.answer.lower().strip().rstrip('.')
438 | model_clean = model_answer.lower().strip().rstrip('.')
439 | is_correct = gt_clean == model_clean
440 | qa_entry = get_qa_entry(items[task.item_index], task.qa_field, task.qa_type, task.qa_index)
441 | qa_entry["model_answer"] = model_answer
442 | qa_entry["model_response"] = model_response
443 | qa_entry["model_explanation"] = model_explanation
444 | qa_entry["is_correct"] = is_correct
445 | if task.qa_field == "annotated_qa_pairs" and task.source == "visualization":
446 |             viz_rel_path = os.path.relpath(task.image_path, image_base_dir)
447 | qa_entry["visualization_path"] = viz_rel_path
448 | qa_entry["viz_mode"] = args.viz_mode
449 |
450 | return items
451 |
452 | def calculate_accuracy_by_dimension(items: List[Dict[str, Any]]) -> Dict[str, Any]:
453 | """Compute accuracy per dimension using sample-wise averaging."""
454 | total_questions = 0
455 | sample_acc_sum = 0.0
456 | sample_count = 0
457 | category_stats = {}
458 | law_stats = {}
459 | operation_stats = {}
460 |
461 | def collect_qas(item: Dict[str, Any]) -> List[Dict[str, Any]]:
462 | qa_sources: List[Dict[str, Any]] = []
463 | for field, default_type in (("qa_pairs", "qa"), ("annotated_qa_pairs", "annotated_qa")):
464 | container = item.get(field)
465 | if not container:
466 | continue
467 | for _, _, qa in iter_qa_entries(container, default_type):
468 | qa_sources.append(qa)
469 | return qa_sources
470 |
471 | def update_stat(stats: Dict, key: str, sample_acc: float, qa_total: int):
472 | if key not in stats:
473 | stats[key] = {"sum_acc": 0.0, "sample_count": 0, "qa_total": 0}
474 | stats[key]["sum_acc"] += sample_acc
475 | stats[key]["sample_count"] += 1
476 | stats[key]["qa_total"] += qa_total
477 |
478 | for item in items:
479 | category = item.get("physics_category", "unknown")
480 | law = item.get("physics_law", "unknown")
481 | operation = item.get("edit_operation", "unknown")
482 |
483 | qa_sources = collect_qas(item)
484 | sample_total = 0
485 | sample_correct = 0
486 | for qa in qa_sources:
487 | if "is_correct" in qa:
488 | sample_total += 1
489 | if qa["is_correct"]:
490 | sample_correct += 1
491 | if sample_total == 0:
492 | continue
493 |
494 | sample_acc = sample_correct / sample_total
495 | sample_count += 1
496 | sample_acc_sum += sample_acc
497 | total_questions += sample_total
498 |
499 | update_stat(category_stats, category, sample_acc, sample_total)
500 | update_stat(law_stats, law, sample_acc, sample_total)
501 | update_stat(operation_stats, operation, sample_acc, sample_total)
502 |
503 | def calc_accuracy(stats: Dict) -> Dict:
504 | result = {}
505 | for key, value in stats.items():
506 | if value["sample_count"] > 0:
507 | acc = 100.0 * value["sum_acc"] / value["sample_count"]
508 | else:
509 | acc = 0.0
510 | result[key] = {
511 | "accuracy": acc,
512 | "sample_count": value["sample_count"],
513 | "qa_total": value["qa_total"]
514 | }
515 | return result
516 |
517 | overall_accuracy = (100.0 * sample_acc_sum / sample_count) if sample_count > 0 else 0.0
518 | return {
519 | "overall_accuracy": overall_accuracy,
520 | "sample_count": sample_count,
521 | "qa_total": total_questions,
522 | "by_category": calc_accuracy(category_stats),
523 | "by_law": calc_accuracy(law_stats),
524 | "by_operation": calc_accuracy(operation_stats)
525 | }
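# Worked example of the sample-wise averaging above (illustrative numbers):
#   sample A answers 2 of 4 QA pairs correctly (0.50), sample B answers 3 of 3 (1.00);
#   overall_accuracy = 100 * (0.50 + 1.00) / 2 = 75.00, sample_count = 2, qa_total = 7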
526 |
527 | def main():
528 | parser = argparse.ArgumentParser()
529 | parser.add_argument("--input_json_path", type=str, required=True,
530 | help="Path to meta_info.json file")
531 | parser.add_argument("--image_base_dir", type=str, default=None,
532 | help="Image root directory; defaults to the JSON directory")
533 | parser.add_argument("--model_path", type=str, default="pretrained/Qwen/Qwen2.5-VL-72B-Instruct",
534 | help="Model checkpoint path or Hugging Face identifier")
535 | parser.add_argument("--qa_field", type=str, default="annotated_qa_pairs",
536 | choices=["qa_pairs", "annotated_qa_pairs"],
537 | help="Choose which QA field to evaluate")
538 | parser.add_argument("--viz_mode", type=str, default="crop_box_and_resize",
539 | choices=["draw_box", "crop_box", "crop_box_and_resize"],
540 | help="Visualization mode (used only for annotated QA)")
541 | parser.add_argument("--tensor_parallel_size", type=int, default=4,
542 | help="Tensor parallel shard count")
543 |     parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16"], help="Model weight dtype for vLLM")
544 |     parser.add_argument("--gpu_mem_util", type=float, default=0.9, help="Fraction of GPU memory vLLM may allocate")
545 |     parser.add_argument("--max_model_len", type=int, default=5120, help="Maximum model context length (tokens)")
546 |     parser.add_argument("--max_new_tokens", type=int, default=256, help="Maximum new tokens generated per answer")
547 |     parser.add_argument("--img_size", type=int, default=1024, choices=[512, 1024], help="Used only for output file naming")
548 | parser.add_argument("--max_num", type=int, default=None, help="Maximum number of samples to process")
549 | parser.add_argument("--viz_padding", type=int, default=20, help="Padding pixels for crop mode")
550 | parser.add_argument("--box_color", default="red", help="Bounding box color")
551 | parser.add_argument("--log_question_changes", action="store_true",
552 | help="Log question text mutations")
553 | args = parser.parse_args()
554 |
555 | # :: Derive default image_base_dir
556 | if args.image_base_dir is None:
557 | args.image_base_dir = os.path.dirname(args.input_json_path)
558 |
559 | # :: Load dataset
560 | print(f"Loading data: {args.input_json_path}")
561 | with open(args.input_json_path, "r", encoding="utf-8") as f:
562 | data = json.load(f)
563 |
564 | print(f"QA field: {args.qa_field}")
565 | if args.qa_field == "annotated_qa_pairs":
566 | print(f"Visualization mode: {args.viz_mode}")
567 |
568 | # :: Initialize model
569 | print("Initializing model...")
570 | processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
571 | llm = init_llm(args.model_path, args.tensor_parallel_size, args.dtype,
572 | args.gpu_mem_util, args.max_model_len)
573 |
574 | # :: Evaluate
575 | print("Starting evaluation...")
576 | results = evaluate_physedit_with_vllm(data, args.image_base_dir, processor, llm, args)
577 |
578 | # :: Persist outputs
579 | out_dir = os.path.dirname(args.input_json_path)
580 | base_name = os.path.splitext(os.path.basename(args.input_json_path))[0]
581 |
582 | # :: Build output filenames
583 | suffix = f"_vllm_output_{args.img_size}"
584 | if args.qa_field == "annotated_qa_pairs":
585 | suffix += f"_{args.viz_mode}"
586 |
587 | out_path = os.path.join(out_dir, base_name + suffix + ".json")
588 | with open(out_path, "w", encoding="utf-8") as f:
589 | json.dump(results, f, ensure_ascii=False, indent=2)
590 |
591 | # :: Compute and save statistics
592 | print("Computing statistics...")
593 | stats = calculate_accuracy_by_dimension(results)
594 |
595 | analysis_suffix = suffix.replace("_output_", "_analysis_")
596 | analysis_path = os.path.join(out_dir, base_name + analysis_suffix + ".json")
597 | with open(analysis_path, "w", encoding="utf-8") as f:
598 | json.dump(stats, f, ensure_ascii=False, indent=2)
599 |
600 | # :: Display summary
601 | print(f"\n=== Evaluation finished ===")
602 | print(f"Overall accuracy: {stats['overall_accuracy']:.2f}% "
603 | f"(samples: {stats['sample_count']}, qa_total: {stats['qa_total']})")
604 | print(f"\nBy category:")
605 | for category, stat in stats["by_category"].items():
606 | print(f" {category}: {stat['accuracy']:.2f}% "
607 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
608 | print(f"\nBy physics law:")
609 | for law, stat in stats["by_law"].items():
610 | print(f" {law}: {stat['accuracy']:.2f}% "
611 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
612 | print(f"\nBy operation type:")
613 | for operation, stat in stats["by_operation"].items():
614 | print(f" {operation}: {stat['accuracy']:.2f}% "
615 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})")
616 |
617 | print(f"\nOutputs saved:")
618 | print(f" Detailed results: {out_path}")
619 | print(f" Statistics: {analysis_path}")
620 |
621 | if __name__ == "__main__":
622 | random.seed(42)
623 | main()
624 |
--------------------------------------------------------------------------------