├── prepare_meta_info.py ├── README.md ├── PicaEval_consistency.py ├── PicaEval_gpt.py └── PicaEval_qwen.py /prepare_meta_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Generate standard meta_info.json from HuggingFace PICABench dataset + model output images. 4 | 5 | Solves the problem: 6 | - Users get a Dataset object from load_dataset but don't know how to organize it into the JSON format required by evaluation scripts 7 | - Automatically saves input images to filesystem, maps output image paths, and generates complete metadata 8 | 9 | Usage: 10 | pip install datasets pillow tqdm 11 | 12 | # Basic usage: assuming output images are in outputs/ directory with filenames matching dataset index 13 | python prepare_meta_info.py --output_image_dir outputs --save_dir PICABench_data 14 | 15 | # Specify output image naming pattern (if not named by index) 16 | python prepare_meta_info.py --output_image_dir my_results --save_dir data \ 17 | --output_name_pattern "{index:05d}.png" 18 | 19 | Output: 20 | - save_dir/input_img/ # Input images (saved from HF dataset) 21 | - save_dir/meta_info.json # Standard format JSON, ready for evaluation 22 | """ 23 | 24 | import argparse 25 | import json 26 | import os 27 | from pathlib import Path, PurePosixPath 28 | from typing import Any, Dict, List, Optional 29 | 30 | from datasets import load_dataset 31 | from PIL import Image 32 | from tqdm import tqdm 33 | 34 | 35 | def find_output_image(output_dir: Path, index: int, pattern: str) -> Optional[str]: 36 | """ 37 | Find output image based on index and naming pattern. 38 | pattern supports {index} placeholder, e.g. "{index:05d}.jpg" 39 | """ 40 | # Try user-specified pattern 41 | filename = pattern.format(index=index) 42 | candidate = output_dir / filename 43 | if candidate.exists(): 44 | return filename 45 | 46 | # Try common extensions 47 | for ext in [".jpg", ".jpeg", ".png", ".webp"]: 48 | for fmt in [f"{index:05d}{ext}", f"{index:04d}{ext}", f"{index}{ext}"]: 49 | candidate = output_dir / fmt 50 | if candidate.exists(): 51 | return fmt 52 | 53 | return None 54 | 55 | 56 | def save_input_image(img: Image.Image, path: Path) -> None: 57 | """Save input image to specified path""" 58 | path.parent.mkdir(parents=True, exist_ok=True) 59 | img.convert("RGB").save(path, quality=95) 60 | 61 | 62 | def build_meta_item( 63 | idx: int, 64 | example: Dict[str, Any], 65 | input_filename: str, 66 | output_filename: Optional[str], 67 | output_dir_for_json: str, 68 | ) -> Dict[str, Any]: 69 | """Build a single meta_info record""" 70 | if output_filename: 71 | output_path = ( 72 | str(PurePosixPath(output_dir_for_json) / output_filename) 73 | if output_dir_for_json 74 | else output_filename 75 | ) 76 | else: 77 | output_path = None 78 | 79 | item = { 80 | "index": idx, 81 | "input_path": f"input_img/{input_filename}", 82 | "output_path": output_path, 83 | "edit_instruction": example.get("edit_instruction", ""), 84 | "physics_category": example.get("physics_category", "unknown"), 85 | "physics_law": example.get("physics_law", "unknown"), 86 | "edit_operation": example.get("edit_operation", "unknown"), 87 | "difficulty": example.get("difficulty", "unknown"), 88 | "annotated_qa_pairs": example.get("annotated_qa_pairs", []), 89 | "edit_area": example.get("edit_area", "unknown"), 90 | } 91 | return item 92 | 93 | 94 | def main() -> None: 95 | parser = argparse.ArgumentParser( 96 | description="Generate meta_info.json from HF PICABench dataset + output 
image directory" 97 | ) 98 | parser.add_argument( 99 | "--hf_repo", 100 | type=str, 101 | default="PICABench", 102 | help="HuggingFace dataset repository name", 103 | ) 104 | parser.add_argument( 105 | "--split", 106 | type=str, 107 | default="picabench", 108 | help="Dataset split name", 109 | ) 110 | parser.add_argument( 111 | "--output_image_dir", 112 | type=str, 113 | required=True, 114 | help="Directory containing model-generated output images (relative to save_dir or absolute path)", 115 | ) 116 | parser.add_argument( 117 | "--save_dir", 118 | type=str, 119 | required=True, 120 | help="Root directory to save meta_info.json and input_img/", 121 | ) 122 | parser.add_argument( 123 | "--output_name_pattern", 124 | type=str, 125 | default="{index:05d}.jpg", 126 | help="Output image filename pattern, supports {index} placeholder, default {index:05d}.jpg", 127 | ) 128 | parser.add_argument( 129 | "--allow_missing", 130 | action="store_true", 131 | help="Allow missing output images, still generate JSON (output_path will be null)", 132 | ) 133 | parser.add_argument( 134 | "--force_input_save", 135 | action="store_true", 136 | help="Overwrite input images even if they already exist under save_dir/input_img", 137 | ) 138 | args = parser.parse_args() 139 | 140 | save_dir = Path(args.save_dir).resolve() 141 | input_dir = save_dir / "input_img" 142 | input_dir.mkdir(parents=True, exist_ok=True) 143 | 144 | # Output image directory: supports relative path (relative to save_dir) or absolute path 145 | output_dir_arg = Path(args.output_image_dir) 146 | output_dir = output_dir_arg if output_dir_arg.is_absolute() else save_dir / output_dir_arg 147 | output_dir = output_dir.resolve() 148 | 149 | try: 150 | output_dir_for_json = output_dir.relative_to(save_dir).as_posix() 151 | except ValueError: 152 | output_dir_for_json = output_dir.as_posix() 153 | 154 | if output_dir_for_json in ("", "."): 155 | output_dir_for_json = "" 156 | 157 | print(f"Loading dataset: {args.hf_repo} (split={args.split})") 158 | dataset = load_dataset(args.hf_repo, split=args.split) 159 | print(f"Dataset size: {len(dataset)}") 160 | 161 | meta_info: List[Dict[str, Any]] = [] 162 | missing_outputs: List[int] = [] 163 | 164 | for idx, example in tqdm(enumerate(dataset), total=len(dataset), desc="Processing samples"): 165 | # 1. Save input image 166 | input_img = example.get("input_image") 167 | if input_img is None: 168 | # Fallback: try loading from image_path 169 | img_path = example.get("image_path") 170 | if img_path and os.path.exists(img_path): 171 | input_img = Image.open(img_path) 172 | else: 173 | print(f"Warning: sample {idx} has no input image, skipping") 174 | continue 175 | 176 | input_filename = f"{idx:05d}.jpg" 177 | input_path = input_dir / input_filename 178 | if input_path.exists(): 179 | if args.force_input_save: 180 | save_input_image(input_img, input_path) 181 | else: 182 | save_input_image(input_img, input_path) 183 | 184 | # 2. Find corresponding output image 185 | output_filename = find_output_image(output_dir, idx, args.output_name_pattern) 186 | if output_filename is None: 187 | missing_outputs.append(idx) 188 | if not args.allow_missing: 189 | continue # Skip samples without output images 190 | 191 | # 3. 
Build meta_info entry 192 | item = build_meta_item(idx, example, input_filename, output_filename, output_dir_for_json) 193 | meta_info.append(item) 194 | 195 | # Save JSON 196 | json_path = save_dir / "meta_info.json" 197 | with open(json_path, "w", encoding="utf-8") as f: 198 | json.dump(meta_info, f, ensure_ascii=False, indent=2) 199 | 200 | print(f"\n✓ Successfully generated meta_info.json: {json_path}") 201 | print(f" - Total samples: {len(meta_info)}") 202 | print(f" - Input images saved to: {input_dir}") 203 | 204 | if missing_outputs: 205 | print(f"\n⚠ Warning: {len(missing_outputs)} samples missing output images") 206 | if len(missing_outputs) <= 10: 207 | print(f" Missing indices: {missing_outputs}") 208 | else: 209 | print(f" Missing indices (first 10): {missing_outputs[:10]}") 210 | 211 | if not args.allow_missing: 212 | print(" These samples are excluded from JSON") 213 | else: 214 | print(" These samples have output_path set to null") 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | # PICABench: How Far Are We from Physically Realistic Image Editing? 4 |

5 | 6 |

Benchmark, evaluator, and data suite for physically realistic image editing.

7 | 8 | [![Huggingface Paper](https://img.shields.io/badge/Paper-2510.17681-ffcc00?style=for-the-badge&logo=huggingface&logoColor=black)](https://huggingface.co/papers/2510.17681) 9 | [![arXiv](https://img.shields.io/badge/arXiv-2510.17681-b31b1b?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2510.17681) 10 | [![Website](https://img.shields.io/badge/Project-Website-007ec6?style=for-the-badge)](https://picabench.github.io) 11 | 12 | [PICABench Dataset](https://huggingface.co/datasets/Andrew613/PICABench) 13 | [PICA-100K Dataset](https://huggingface.co/datasets/Andrew613/PICA-100K) 14 | 15 | 16 | 17 |
18 | PICABench teaser 19 |
20 | 21 | ## 📋 Table of Contents 22 | 23 | - [Overview](#overview) 24 | - [Quick Start](#quick-start) 25 | - [Installation](#installation) 26 | - [Data Preparation](#data-preparation) 27 | - [Evaluation Pipelines](#evaluation-pipelines) 28 | - [PICA-100K Training Data](#pica-100k-training-data) 29 | - [Leaderboard & Qualitative Explorer](#leaderboard--qualitative-explorer) 30 | - [Leaderboard Submission](#leaderboard-submission) 31 | - [Citation](#citation) 32 | 33 | ## Overview 34 | 35 | PICABench probes how far current editing models are from physically realistic image manipulation. It ties together: 36 | 37 | - **PICABench benchmark** – physics-aware editing cases spanning eight laws across *Optics*, *Mechanics*, and *State Transition*, each labeled with superficial/intermediate/explicit difficulty tiers. 38 | - **PICAEval metric** – region-grounded, QA-based verification with human-annotated regions of interest (ROIs) and spatially anchored yes/no questions. 39 | - **PICA-100K dataset** – synthetic, video-derived training data that boosts physics consistency when used for fine-tuning. 40 | 41 | The leaderboard shows that even top proprietary systems only reach ~60% accuracy, indicating a significant physics-awareness gap. 42 | 43 | ## ⚡ Quick Start 44 | 45 | Evaluate your model's physics-aware editing from a folder of output images in 3 steps: 46 | 47 | ```bash 48 | # 1) Install dependencies (choose GPT or Qwen) 49 | pip install openai Pillow tqdm datasets huggingface_hub # GPT 50 | # or 51 | pip install vllm transformers Pillow tqdm datasets huggingface_hub # Qwen/vLLM 52 | 53 | # 2) Build meta_info.json from HF dataset + your outputs 54 | # (Assume your edited images are under ./outputs as 00000.jpg, 00001.jpg, ...) 55 | python prepare_meta_info.py \ 56 | --hf_repo Andrew613/PICABench \ 57 | --output_image_dir outputs \ 58 | --save_dir PICABench_data 59 | 60 | # 3) Run evaluation (multi-threaded) 61 | export OPENAI_API_KEY="sk-..." 62 | python PicaEval_gpt.py \ 63 | --input_json_path PICABench_data/meta_info.json \ 64 | --gpt_model gpt-4o \ 65 | --num_workers 16 66 | ``` 67 | 68 | Notes: 69 | - When `meta_info.json` lives in `PICABench_data/`, you can omit `--image_base_dir` (defaults to the JSON directory). 70 | - If your output images are outside `PICABench_data/`, `prepare_meta_info.py` will write absolute paths and the evaluators will resolve them automatically. 71 | 72 | Results are saved as `PICABench_data/meta_info_gpt_output_1024_crop_box_and_resize.json` and the corresponding `_analysis_...json`. 73 | 74 | ## Installation 75 | 76 | We recommend using a Python 3.10+ virtual environment: 77 | 78 | ```bash 79 | conda create -n picabench python=3.10 80 | conda activate picabench 81 | ``` 82 | 83 | Install dependencies based on your evaluation needs: 84 | 85 | ```bash 86 | # For GPT evaluation (multi-threaded with OpenAI SDK) 87 | pip install openai Pillow tqdm huggingface_hub 88 | 89 | # For Qwen evaluation (with vLLM acceleration) 90 | pip install vllm transformers Pillow tqdm 91 | ``` 92 | 93 | ## Data Preparation 94 | 95 | ### Generating `meta_info.json` from HuggingFace Dataset 96 | 97 | If you've already generated edited images with your model but don't know how to organize them into the `meta_info.json` format required by evaluation scripts, use the provided conversion script: 98 | 99 | ```bash 100 | # 1. Install dependencies 101 | pip install datasets pillow tqdm 102 | 103 | # 2. 
Assuming your model outputs are in outputs/ directory with filenames 00000.jpg, 00001.jpg, ... 104 | python prepare_meta_info.py \ 105 | --hf_repo Andrew613/PICABench \ 106 | --output_image_dir outputs \ 107 | --save_dir PICABench_data 108 | 109 | # 3. Generated files: 110 | # PICABench_data/input_img/ - Input images (automatically saved from HF dataset) 111 | # PICABench_data/meta_info.json - Standard format JSON, ready for evaluation 112 | ``` 113 | 114 | > ⚠️ Paths inside `meta_info.json` are written relative to the chosen `--save_dir`. Pass that same directory to the evaluators via `--image_base_dir` to avoid duplicate folder segments. 115 | 116 | **Parameters:** 117 | - `--output_image_dir`: Directory containing your model's edited output images 118 | - `--save_dir`: Root directory to save `meta_info.json` and input images 119 | - `--output_name_pattern`: Output image filename pattern (default `{index:05d}.jpg`), supports `{index}` placeholder 120 | - `--allow_missing`: Allow missing output images, still generate JSON (missing samples will have `output_path` set to `null`) 121 | - `--force_input_save`: Overwrite cached `input_img/*.jpg` files instead of reusing them (default: reuse existing files) 122 | 123 | ### `meta_info.json` Format 124 | 125 | PICABench expects per-scene metadata in `meta_info.json` plus accompanying images under a shared base directory. Each item should include: 126 | 127 | ```jsonc 128 | { 129 | "index": 1174, 130 | "input_path": "input_img/1174.jpg", 131 | "output_path": "output_img/1174.jpg", 132 | "edit_instruction": "Remove the tulip from the white vase and simultaneously eliminate every instance of it in the window reflection while keeping lighting and shading consistent.", 133 | "physics_category": "Optics", 134 | "physics_law": "Reflection", 135 | "edit_operation": "remove", 136 | "difficulty": "superficial", 137 | "annotated_qa_pairs": [ 138 | { 139 | "question": "Is a tulip visible in the window reflection?", 140 | "answer": "No", 141 | "box": { "x": 101.25, "y": 476.90, "width": 169.44, "height": 202.96 } 142 | }, 143 | { 144 | "question": "Does the interior of the white vase contain exactly zero tulips?", 145 | "answer": "Yes", 146 | "box": { "x": 327.96, "y": 485.99, "width": 209.80, "height": 206.21 } 147 | }, 148 | { 149 | "question": "Is the vase's reflection aligned with the vase?", 150 | "answer": "Yes", 151 | "box": { "x": 117.24, "y": 496.29, "width": 363.74, "height": 183.41 } 152 | } 153 | ], 154 | "edit_area": [ 155 | { 156 | "x": 117.24, 157 | "y": 496.29, 158 | "width": 363.74, 159 | "height": 183.41, 160 | "id": "BxnMC34B", 161 | "order": 1 162 | } 163 | ] 164 | } 165 | ``` 166 | 167 |
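The `box` values above live on a canvas whose long edge is 1024 px. Below is a minimal sketch (not one of the official scripts) of how such a region can be cropped out of an output image for a quick visual check, mirroring the `crop_box_and_resize` behaviour of the evaluators; the 20 px padding matches their default, and the file paths are taken from the sample entry above under `PICABench_data/`:

```python
from pathlib import Path
from PIL import Image

def crop_roi(image_path: str, box: dict, padding: int = 20) -> Image.Image:
    """Resize the long edge to 1024 px, then crop the QA box with padding."""
    img = Image.open(image_path).convert("RGB")
    scale = 1024.0 / max(img.size)
    if scale != 1.0:
        img = img.resize((int(img.width * scale), int(img.height * scale)), Image.LANCZOS)
    x, y, w, h = box["x"], box["y"], box["width"], box["height"]
    return img.crop((
        max(0, x - padding),
        max(0, y - padding),
        min(img.width, x + w + padding),
        min(img.height, y + h + padding),
    ))

# Crop the first annotated QA region of the sample entry shown above.
base = Path("PICABench_data")
roi = crop_roi(str(base / "output_img/1174.jpg"),
               {"x": 101.25, "y": 476.90, "width": 169.44, "height": 202.96})
roi.save("roi_preview.jpg", quality=90)
```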
168 | 📋 Field Descriptions (click to expand) 169 | 170 | - **`annotated_qa_pairs`**: List of QA dictionaries for physics verification. Each contains: 171 | - `question`: Yes/no question about physical correctness 172 | - `answer`: Ground truth ("Yes" or "No") 173 | - `box`: Region of interest `{x, y, width, height}` in 1024px canvas coordinates 174 | 175 | - **`edit_area`**: Bounding boxes of edited regions (used for visualization cropping). Set to `"unknown"` if unavailable. 176 | 177 | - **Visualization**: Scripts auto-generate cropped/annotated images in `visualization_annotated_qa_crop_box_and_resize/` under the base directory. 178 | 179 |
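If you assemble `meta_info.json` yourself instead of using `prepare_meta_info.py`, a quick sanity check before launching the evaluators can save wasted API calls. A minimal sketch, assuming the layout produced above (paths relative to `PICABench_data/`, `annotated_qa_pairs` stored as a list):

```python
import json
from pathlib import Path

base = Path("PICABench_data")
items = json.loads((base / "meta_info.json").read_text(encoding="utf-8"))

problems = []
for item in items:
    # The evaluators need an existing output image for every scored sample.
    out = item.get("output_path")
    if not out or not (base / out).exists():
        problems.append(f"sample {item.get('index')}: missing output image {out!r}")
    # Every annotated QA pair needs a yes/no ground truth and an ROI box.
    for qa in item.get("annotated_qa_pairs", []):
        if qa.get("answer") not in ("Yes", "No"):
            problems.append(f"sample {item.get('index')}: unexpected answer {qa.get('answer')!r}")
        if not isinstance(qa.get("box"), dict):
            problems.append(f"sample {item.get('index')}: QA pair without a box")

print(f"checked {len(items)} samples, found {len(problems)} problems")
print("\n".join(problems[:20]))
```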
180 | 181 | ## Evaluation Pipelines 182 | 183 | ### 1. Qwen / vLLM (PICAEval) 184 | 185 | ```bash 186 | python PicaEval_qwen.py \ 187 | --input_json_path /path/to/meta_info.json \ 188 | --image_base_dir /path/to/images \ 189 | --model_path pretrained/Qwen/Qwen2.5-VL-72B-Instruct \ 190 | --tensor_parallel_size 4 \ 191 | --dtype bfloat16 \ 192 | --qa_field annotated_qa_pairs \ 193 | --viz_mode crop_box_and_resize \ 194 | --max_new_tokens 256 \ 195 | --img_size 1024 196 | ``` 197 | 198 | Outputs: 199 | 200 | - `_vllm_output_[_mode].json` – per-QA predictions with `model_answer`, `model_response`, `model_explanation`, `is_correct`, and optional `visualization_path`. 201 | - `_vllm_analysis_[_mode].json` – aggregated accuracy by physics category, law, and operation. 202 | 203 | ### 2. GPT-based Evaluation (PICAEval) 204 | 205 | ```bash 206 | export OPENAI_API_KEY="sk-..." 207 | python PicaEval_gpt.py \ 208 | --input_json_path /path/to/meta_info.json \ 209 | --image_base_dir /path/to/images \ 210 | --qa_field annotated_qa_pairs \ 211 | --viz_mode crop_box_and_resize \ 212 | --gpt_model gpt-5 \ 213 | --num_workers 50 \ 214 | --max_attempts 5 \ 215 | --api_base_url https://api.openai.com/v1 216 | ``` 217 | 218 | **Key Parameters:** 219 | - `--num_workers`: Number of parallel worker threads (default: 50) for concurrent API requests 220 | - `--gpt_model`: OpenAI model name (e.g., `gpt-5`, `gpt-4o`, `gpt-4-turbo`) 221 | - `--api_base_url`: API endpoint URL (default: `https://api.openai.com/v1`) 222 | - `--max_attempts`: Retry attempts for failed API calls (default: 5) 223 | 224 | **Outputs:** 225 | - `_gpt_output_[_{mode}].json` – detailed results with per-question predictions 226 | - `_gpt_analysis_[_{mode}].json` – accuracy statistics by physics category, law, and operation 227 | 228 | **Notes:** 229 | - Uses multi-threaded execution with OpenAI SDK for efficient parallel evaluation 230 | - Reuses the same JSON schema for inputs/outputs as the Qwen pipeline, enabling direct comparison 231 | - Images are base64-encoded and sent as data URLs; be mindful of API quotas and rate limits 232 | 233 | ### 3. Non-edited Region Quality (PSNR) 234 | 235 | ```bash 236 | python PicaEval_consistency.py \ 237 | --meta_info_path /path/to/meta_info.json \ 238 | --base_dir /path/to/images \ 239 | --size 512 240 | ``` 241 | 242 | Produces `_psnr_output.json` and `_psnr_analysis.json`, containing masked PSNR on non-edited regions or whole-image PSNR when edit regions are unavailable. 243 | 244 | ## PICA-100K Training Data 245 | 246 | **Dataset**: [Andrew613/PICA-100K](https://huggingface.co/datasets/Andrew613/PICA-100K) 247 | 248 | 100K synthetic editing pairs derived from video frames, designed to improve physical realism in image editing models. 249 | 250 | ### Download 251 | 252 | ```bash 253 | huggingface-cli download Andrew613/PICA-100K \ 254 | --repo-type dataset \ 255 | --local-dir data/PICA-100K 256 | ``` 257 | 258 | ## Leaderboard & Qualitative Explorer 259 | 260 | - Official leaderboard and gallery: [https://picabench.github.io](https://picabench.github.io) 261 | - Eight physics laws × three difficulty tiers provide direct qualitative comparisons. 262 | - PICAEval scores correlate strongly with human judgments (Elo study on the site). 
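Both evaluators aggregate their scores into the same analysis JSON layout (`overall_accuracy` plus `by_category` / `by_law` / `by_operation` blocks). A minimal sketch of pulling the headline numbers out of it; the filename assumes the GPT pipeline defaults used in the Quick Start (`--img_size 1024`, `--viz_mode crop_box_and_resize`):

```python
import json
from pathlib import Path

analysis_path = Path("PICABench_data/meta_info_gpt_analysis_1024_crop_box_and_resize.json")
stats = json.loads(analysis_path.read_text(encoding="utf-8"))

print(f"Overall: {stats['overall_accuracy']:.2f}% "
      f"({stats['sample_count']} samples, {stats['qa_total']} questions)")
for law, entry in sorted(stats["by_law"].items()):
    print(f"  {law:<28} {entry['accuracy']:6.2f}%  (n={entry['sample_count']})")
```

The per-law accuracies printed here correspond to the sub-category numbers requested for leaderboard submission below.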
263 | 264 | ## Leaderboard Submission 265 | 266 | To submit your model's results to the PICABench leaderboard: 267 | 268 | **Required Metrics:** 269 | - Accuracy (%) for each sub-category (Light Propagation, Reflection, Refraction, Light Source Effects, Deformation, Causality, Local State Transition, Global State Transition) 270 | - Overall Accuracy (%) 271 | 272 | **Submission:** 273 | Email your `*_analysis*.json` and `*_output*.json` files and model details to: 274 | - [puyuandong01061313@gmail.com](mailto:puyuandong01061313@gmail.com) 275 | ## Citation 276 | 277 | ```bibtex 278 | @article{pu2025picabench, 279 | title = {PICABench: How Far Are We From Physically Realistic Image Editing?}, 280 | author = {Pu, Yuandong and Zhuo, Le and Han, Songhao and Xing, Jinbo and Zhu, Kaiwen and Cao, Shuo and Fu, Bin and Liu, Si and Li, Hongsheng and Qiao, Yu and Zhang, Wenlong and Chen, Xi and Liu, Yihao}, 281 | journal = {arXiv preprint arXiv:2510.17681}, 282 | year = {2025} 283 | } 284 | ``` 285 | 286 | ## License 287 | 288 | This project is released under the Apache License 2.0. 289 | -------------------------------------------------------------------------------- /PicaEval_consistency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import Dict, List, Any 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image, ImageDraw 9 | from tqdm import tqdm 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='Non-edited Region Quality Assessment - PSNR only') 14 | parser.add_argument( 15 | '--meta_info_path', 16 | required=True, 17 | help='Path to meta_info.json file', 18 | ) 19 | parser.add_argument( 20 | '--base_dir', 21 | required=False, 22 | help='Base directory for image files, defaults to meta_info.json directory if not specified', 23 | ) 24 | parser.add_argument( 25 | '--size', 26 | default=None, 27 | type=int, 28 | help='Resize image to this size, keep original size if not specified', 29 | ) 30 | 31 | return parser.parse_args() 32 | 33 | 34 | def create_edit_mask(image_size, edit_areas): 35 | """Create edit region mask, returns non-edited region mask""" 36 | width, height = image_size 37 | # Create white image (all regions are non-edited) 38 | mask_img = Image.new('L', (width, height), 255) 39 | draw = ImageDraw.Draw(mask_img) 40 | 41 | # Mark edited regions as black (0) 42 | for area in edit_areas: 43 | x = area['x'] 44 | y = area['y'] 45 | w = area['width'] 46 | h = area['height'] 47 | # Draw rectangle, edited region in black 48 | draw.rectangle([x, y, x + w, y + h], fill=0) 49 | 50 | # Convert to tensor, non-edited region=1, edited region=0 51 | mask_array = np.array(mask_img) / 255.0 # Convert to [0,1] 52 | mask_tensor = torch.tensor(mask_array, dtype=torch.float32) 53 | return mask_tensor 54 | 55 | 56 | def load_image_with_mask(image_path: str, edit_areas: List[Dict], size=None): 57 | """Load image and create non-edited region mask""" 58 | try: 59 | image = Image.open(image_path).convert('RGB') 60 | original_size = image.size 61 | 62 | if size: 63 | image = image.resize((size, size), Image.BILINEAR) 64 | # Adjust edit_areas coordinates 65 | scale_x = size / original_size[0] 66 | scale_y = size / original_size[1] 67 | scaled_edit_areas = [] 68 | for area in edit_areas: 69 | scaled_area = { 70 | 'x': area['x'] * scale_x, 71 | 'y': area['y'] * scale_y, 72 | 'width': area['width'] * scale_x, 73 | 'height': area['height'] * scale_y 74 | } 75 | 
scaled_edit_areas.append(scaled_area) 76 | edit_areas = scaled_edit_areas 77 | mask_size = (size, size) 78 | else: 79 | mask_size = original_size 80 | 81 | # Create mask 82 | mask = create_edit_mask(mask_size, edit_areas) 83 | 84 | # Convert image to tensor 85 | image_array = np.array(image) / 255.0 # Normalize to [0, 1] 86 | image_array = image_array.astype(np.float32) 87 | image_tensor = torch.tensor(image_array).permute(2, 0, 1).unsqueeze(0) # Shape: (1, 3, H, W) 88 | 89 | # Expand mask dimensions to match image (1, 1, H, W) 90 | mask = mask.unsqueeze(0).unsqueeze(0) 91 | 92 | return image_tensor, mask 93 | except Exception as e: 94 | print(f"Error loading image {image_path}: {e}") 95 | return None, None 96 | 97 | 98 | def compute_masked_psnr(output_img, input_img, mask): 99 | """Compute PSNR only on masked region (mask=1 pixels)""" 100 | mask_bool = mask > 0.5 101 | mask_3ch = mask_bool.expand_as(output_img) 102 | 103 | # Extract non-edited region pixels 104 | output_pixels = output_img[mask_3ch] 105 | input_pixels = input_img[mask_3ch] 106 | 107 | if output_pixels.numel() == 0: 108 | return None 109 | 110 | # Calculate MSE (only on non-edited region) 111 | mse = torch.mean((output_pixels - input_pixels) ** 2) 112 | 113 | if mse < 1e-10: 114 | return 100.0 # Near infinity, return a large value 115 | 116 | psnr = 20 * torch.log10(torch.tensor(1.0) / torch.sqrt(mse)) 117 | return psnr.item() 118 | 119 | 120 | def load_image_simple(image_path: str, size=None): 121 | """Simple image loading without mask processing""" 122 | try: 123 | image = Image.open(image_path).convert('RGB') 124 | if size: 125 | image = image.resize((size, size), Image.BILINEAR) 126 | image = np.array(image) / 255.0 # Normalize to [0, 1] 127 | image = image.astype(np.float32) 128 | image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) # Shape: (1, 3, H, W) 129 | return image 130 | except Exception as e: 131 | print(f"Error loading image {image_path}: {e}") 132 | return None 133 | 134 | 135 | def compute_full_image_psnr(output_img, input_img): 136 | """Compute PSNR on full image""" 137 | mse = torch.mean((output_img - input_img) ** 2) 138 | 139 | if mse < 1e-10: 140 | return 100.0 141 | 142 | psnr = 20 * torch.log10(torch.tensor(1.0) / torch.sqrt(mse)) 143 | return psnr.item() 144 | 145 | 146 | def evaluate_single_item(item: Dict[str, Any], base_dir: str, size=None) -> float: 147 | """Evaluate single meta_info item and return PSNR""" 148 | # Build full paths 149 | input_path = os.path.join(base_dir, item['input_path']) 150 | output_path = os.path.join(base_dir, item['output_path']) 151 | 152 | # Check if files exist 153 | if not os.path.exists(input_path): 154 | print(f"Warning: input image not found: {input_path}") 155 | return None 156 | 157 | if not os.path.exists(output_path): 158 | print(f"Warning: output image not found: {output_path}") 159 | return None 160 | 161 | # Get edit area information 162 | edit_areas = item.get('edit_area', []) 163 | use_full_image = False 164 | 165 | # Handle edit_area as string 166 | if isinstance(edit_areas, str): 167 | if edit_areas == "unknown": 168 | use_full_image = True 169 | else: 170 | print(f"Warning: unexpected edit_area format for item {item.get('index', 'unknown')}: {edit_areas}") 171 | return None 172 | elif not edit_areas or len(edit_areas) == 0: 173 | # No edit_area info, use full image evaluation 174 | use_full_image = True 175 | 176 | if use_full_image: 177 | # Full image evaluation 178 | input_image = load_image_simple(input_path, size) 179 | output_image = 
load_image_simple(output_path, size) 180 | 181 | if input_image is None or output_image is None: 182 | return None 183 | 184 | # Ensure images have same size 185 | if input_image.shape != output_image.shape: 186 | h, w = input_image.shape[2], input_image.shape[3] 187 | output_image = torch.nn.functional.interpolate(output_image, size=(h, w), mode='bilinear', align_corners=False) 188 | 189 | # Compute full image PSNR 190 | psnr = compute_full_image_psnr(output_image, input_image) 191 | 192 | else: 193 | # Non-edited region evaluation 194 | input_image, mask = load_image_with_mask(input_path, edit_areas, size) 195 | output_image, _ = load_image_with_mask(output_path, edit_areas, size) 196 | 197 | if input_image is None or output_image is None or mask is None: 198 | return None 199 | 200 | # Ensure images have same size 201 | if input_image.shape != output_image.shape: 202 | h, w = input_image.shape[2], input_image.shape[3] 203 | output_image = torch.nn.functional.interpolate(output_image, size=(h, w), mode='bilinear', align_corners=False) 204 | 205 | # Compute masked region PSNR 206 | psnr = compute_masked_psnr(output_image, input_image, mask) 207 | 208 | return psnr 209 | 210 | 211 | def process_meta_info(meta_info_path: str, base_dir: str, size=None): 212 | """Process entire meta_info.json file""" 213 | 214 | # Load meta_info.json 215 | with open(meta_info_path, 'r', encoding='utf-8') as f: 216 | meta_info = json.load(f) 217 | 218 | print(f"Loaded {len(meta_info)} items from {meta_info_path}") 219 | 220 | # Detect evaluation mode 221 | has_edit_area = False 222 | 223 | # Check first few items for edit_area status 224 | for item in meta_info[:10]: 225 | edit_areas = item.get('edit_area', []) 226 | if isinstance(edit_areas, list) and len(edit_areas) > 0: 227 | has_edit_area = True 228 | break 229 | 230 | if not has_edit_area: 231 | print("Warning: No valid edit_area found in dataset, using full image evaluation mode") 232 | evaluation_type = "full_image" 233 | else: 234 | evaluation_type = "non_edited_region" 235 | 236 | # Store results 237 | detailed_results = [] 238 | valid_scores = [] 239 | 240 | # Evaluate each item 241 | for idx, item in tqdm(enumerate(meta_info), desc="Evaluating items"): 242 | psnr = evaluate_single_item(item, base_dir, size) 243 | sample_id = item.get('index', idx) 244 | 245 | # Add to detailed results 246 | result_item = { 247 | 'id': sample_id, 248 | 'input_path': item['input_path'], 249 | 'output_path': item['output_path'], 250 | 'physics_category': item.get('physics_category', 'unknown'), 251 | 'physics_law': item.get('physics_law', 'unknown'), 252 | 'edit_operation': item.get('edit_operation', 'unknown'), 253 | 'difficulty': item.get('difficulty', 'unknown'), 254 | 'psnr': psnr 255 | } 256 | detailed_results.append(result_item) 257 | 258 | # Collect valid scores 259 | if psnr is not None: 260 | valid_scores.append(psnr) 261 | 262 | # Calculate overall statistics 263 | if valid_scores: 264 | overall_stats = { 265 | 'count': len(valid_scores), 266 | 'mean': float(np.mean(valid_scores)), 267 | 'std': float(np.std(valid_scores)), 268 | 'min': float(np.min(valid_scores)), 269 | 'max': float(np.max(valid_scores)), 270 | 'median': float(np.median(valid_scores)) 271 | } 272 | else: 273 | overall_stats = { 274 | 'count': 0, 275 | 'mean': None, 276 | 'std': None, 277 | 'min': None, 278 | 'max': None, 279 | 'median': None 280 | } 281 | 282 | # Statistics by physics_category 283 | physics_category_stats = {} 284 | categories = set(item.get('physics_category', 'unknown') for 
item in meta_info) 285 | 286 | for category in categories: 287 | category_items = [item for item in detailed_results if item['physics_category'] == category] 288 | category_scores = [item['psnr'] for item in category_items if item['psnr'] is not None] 289 | 290 | if category_scores: 291 | physics_category_stats[category] = { 292 | 'count': len(category_scores), 293 | 'mean': float(np.mean(category_scores)), 294 | 'std': float(np.std(category_scores)) 295 | } 296 | else: 297 | physics_category_stats[category] = { 298 | 'count': 0, 299 | 'mean': None, 300 | 'std': None 301 | } 302 | 303 | # Statistics by physics_law 304 | physics_law_stats = {} 305 | laws = set(item.get('physics_law', 'unknown') for item in meta_info) 306 | 307 | for law in laws: 308 | law_items = [item for item in detailed_results if item['physics_law'] == law] 309 | law_scores = [item['psnr'] for item in law_items if item['psnr'] is not None] 310 | 311 | if law_scores: 312 | physics_law_stats[law] = { 313 | 'count': len(law_scores), 314 | 'mean': float(np.mean(law_scores)), 315 | 'std': float(np.std(law_scores)) 316 | } 317 | else: 318 | physics_law_stats[law] = { 319 | 'count': 0, 320 | 'mean': None, 321 | 'std': None 322 | } 323 | 324 | # Prepare analysis results 325 | final_analysis = { 326 | 'meta_info_path': meta_info_path, 327 | 'total_items': len(meta_info), 328 | 'evaluation_type': evaluation_type, 329 | 'overall_statistics': overall_stats, 330 | 'physics_category_statistics': physics_category_stats, 331 | 'physics_law_statistics': physics_law_stats 332 | } 333 | 334 | return final_analysis, detailed_results 335 | 336 | 337 | def main(): 338 | args = parse_args() 339 | 340 | # Determine base directory 341 | if args.base_dir: 342 | base_dir = args.base_dir 343 | else: 344 | base_dir = os.path.dirname(args.meta_info_path) 345 | 346 | print(f"Using base directory: {base_dir}") 347 | 348 | # Process meta_info 349 | analysis_results, detailed_results = process_meta_info( 350 | args.meta_info_path, 351 | base_dir, 352 | args.size 353 | ) 354 | 355 | # Generate output file names 356 | meta_info_dir = os.path.dirname(args.meta_info_path) 357 | meta_info_name = os.path.splitext(os.path.basename(args.meta_info_path))[0] 358 | 359 | analysis_output_path = os.path.join(meta_info_dir, f"{meta_info_name}_psnr_analysis.json") 360 | detailed_output_path = os.path.join(meta_info_dir, f"{meta_info_name}_psnr_output.json") 361 | 362 | # Save results 363 | with open(analysis_output_path, 'w', encoding='utf-8') as f: 364 | json.dump(analysis_results, f, indent=2, ensure_ascii=False) 365 | 366 | with open(detailed_output_path, 'w', encoding='utf-8') as f: 367 | json.dump(detailed_results, f, indent=2, ensure_ascii=False) 368 | 369 | print(f"\nResults saved:") 370 | print(f"Analysis: {analysis_output_path}") 371 | print(f"Detailed: {detailed_output_path}") 372 | 373 | # Print overall results 374 | eval_mode = analysis_results['evaluation_type'] 375 | eval_label = "Full Image" if eval_mode == "full_image" else "Non-edited Region" 376 | print(f"\n=== Overall PSNR Results ({eval_label}) ===") 377 | stats = analysis_results['overall_statistics'] 378 | if stats['mean'] is not None: 379 | print(f"PSNR: {stats['mean']:.4f} ± {stats['std']:.4f} (n={stats['count']})") 380 | print(f" Min: {stats['min']:.4f}, Max: {stats['max']:.4f}, Median: {stats['median']:.4f}") 381 | else: 382 | print(f"PSNR: No valid scores") 383 | 384 | # Print physics_category results 385 | print(f"\n=== PSNR Results by Physics Category ===") 386 | for category, 
category_stats in analysis_results['physics_category_statistics'].items(): 387 | if category_stats['mean'] is not None: 388 | print(f"{category}: {category_stats['mean']:.4f} ± {category_stats['std']:.4f} (n={category_stats['count']})") 389 | else: 390 | print(f"{category}: No valid scores") 391 | 392 | # Print physics_law results 393 | print(f"\n=== PSNR Results by Physics Law ===") 394 | for law, law_stats in analysis_results['physics_law_statistics'].items(): 395 | if law_stats['mean'] is not None: 396 | print(f"{law}: {law_stats['mean']:.4f} ± {law_stats['std']:.4f} (n={law_stats['count']})") 397 | else: 398 | print(f"{law}: No valid scores") 399 | 400 | 401 | if __name__ == "__main__": 402 | main() 403 | -------------------------------------------------------------------------------- /PicaEval_gpt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """GPT-based PhysEdit evaluation (v2). 4 | Matches the Qwen evaluator input/output contract while delegating inference to OpenAI GPT.""" 5 | 6 | import os 7 | import sys 8 | import json 9 | import argparse 10 | import random 11 | import time 12 | import base64 13 | import io 14 | import re 15 | from dataclasses import dataclass 16 | from pathlib import Path 17 | from typing import Any, Dict, List, Optional, Tuple 18 | from concurrent.futures import ThreadPoolExecutor, as_completed 19 | 20 | from PIL import Image, ImageDraw 21 | from tqdm import tqdm 22 | from openai import OpenAI 23 | 24 | # ============================================================================ 25 | # :: Shared Data Structures & Helpers (kept aligned with PicaEval_qwen) 26 | # ============================================================================ 27 | 28 | IMAGE_PATH_KEYS: Tuple[str, ...] 
= ("output_path", "output_image_path", "output_img_path") 29 | JSON_PATTERN = re.compile(r"\{[^{}]*\"answer\"[^{}]*\"explanation\"[^{}]*\}", re.IGNORECASE | re.DOTALL) 30 | 31 | 32 | @dataclass 33 | class QATask: 34 | item_index: int 35 | qa_field: str 36 | qa_type: str 37 | qa_index: int 38 | image_path: str 39 | question: str 40 | answer: str 41 | source: str 42 | 43 | 44 | def resolve_image_path(item: Dict[str, Any], base_dir: str) -> Optional[str]: 45 | base_path = Path(base_dir).expanduser().resolve() 46 | base_name = base_path.name 47 | for key in IMAGE_PATH_KEYS: 48 | raw_path = item.get(key) 49 | if not raw_path: 50 | continue 51 | 52 | candidate_path = Path(raw_path) 53 | if candidate_path.is_absolute(): 54 | return str(candidate_path) 55 | 56 | normalized_rel = Path(*candidate_path.parts) 57 | candidate = base_path / normalized_rel 58 | if candidate.exists(): 59 | return str(candidate) 60 | 61 | rel_parts = normalized_rel.parts 62 | if rel_parts and rel_parts[0] == base_name: 63 | alt = base_path / Path(*rel_parts[1:]) 64 | if alt.exists(): 65 | return str(alt) 66 | 67 | return str(candidate) 68 | return None 69 | 70 | 71 | def iter_qa_entries(container: Any, default_type: str) -> List[Tuple[str, int, Dict[str, Any]]]: 72 | entries: List[Tuple[str, int, Dict[str, Any]]] = [] 73 | if isinstance(container, dict): 74 | for qa_type, qa_list in container.items(): 75 | for idx, qa in enumerate(qa_list): 76 | entries.append((qa_type, idx, qa)) 77 | elif isinstance(container, list): 78 | for idx, qa in enumerate(container): 79 | entries.append((default_type, idx, qa)) 80 | return entries 81 | 82 | 83 | def draw_box_on_image(image: Image.Image, box_info: Dict[str, Any], box_color: str = "red") -> Image.Image: 84 | img_copy = image.copy() 85 | draw = ImageDraw.Draw(img_copy) 86 | x = box_info.get("x", 0) 87 | y = box_info.get("y", 0) 88 | width = box_info.get("width", 0) 89 | height = box_info.get("height", 0) 90 | bbox = [x, y, x + width, y + height] 91 | draw.rectangle(bbox, outline=box_color, width=3) 92 | return img_copy 93 | 94 | 95 | def resize_image(image: Image.Image) -> Image.Image: 96 | width, height = image.size 97 | long_edge = max(width, height) 98 | if long_edge == 1024: 99 | return image 100 | scale = 1024.0 / long_edge 101 | new_width = int(width * scale) 102 | new_height = int(height * scale) 103 | return image.resize((new_width, new_height), Image.LANCZOS) 104 | 105 | 106 | def crop_image_with_box(image: Image.Image, box_info: Dict[str, Any], padding: int = 20) -> Image.Image: 107 | x = box_info.get("x", 0) 108 | y = box_info.get("y", 0) 109 | width = box_info.get("width", 0) 110 | height = box_info.get("height", 0) 111 | x1 = max(0, x - padding) 112 | y1 = max(0, y - padding) 113 | x2 = min(image.size[0], x + width + padding) 114 | y2 = min(image.size[1], y + height + padding) 115 | return image.crop((x1, y1, x2, y2)) 116 | 117 | 118 | def generate_viz_filename(item_index: int, qa_type: str, question_index: int, viz_mode: str) -> str: 119 | qa_type_short = qa_type.replace(" ", "_").replace("QA", "qa") 120 | return f"{item_index:04d}_{qa_type_short}_{question_index:03d}_{viz_mode}.jpg" 121 | 122 | 123 | def create_visualization_and_question( 124 | output_img_path: str, 125 | qa_info: Dict[str, Any], 126 | item_index: int, 127 | qa_type: str, 128 | qa_index: int, 129 | viz_mode: str, 130 | viz_dir: str, 131 | args, 132 | ) -> Tuple[str, str]: 133 | question = qa_info.get("question", "") 134 | box_info = qa_info.get("box", {}) 135 | viz_filename = 
generate_viz_filename(item_index, qa_type, qa_index, viz_mode) 136 | viz_path = os.path.join(viz_dir, viz_filename) 137 | viz_rel_path = os.path.join(f"visualization_annotated_qa_{viz_mode}", viz_filename) 138 | os.makedirs(viz_dir, exist_ok=True) 139 | try: 140 | image = Image.open(output_img_path).convert("RGB") 141 | image_resized = resize_image(image) 142 | if viz_mode == "draw_box": 143 | gpt_question = question 144 | viz_image = draw_box_on_image(image_resized, box_info, args.box_color) 145 | elif viz_mode == "crop_box": 146 | gpt_question = question 147 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding) 148 | elif viz_mode == "crop_box_and_resize": 149 | gpt_question = question 150 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding) 151 | viz_image = resize_image(viz_image) 152 | else: 153 | raise ValueError(f"Unknown viz_mode: {viz_mode}") 154 | viz_image.save(viz_path, quality=90) 155 | return gpt_question, viz_rel_path 156 | except Exception as exc: 157 | print(f"Error creating visualization for {viz_path}: {exc}") 158 | return question, "" 159 | 160 | 161 | def _normalize_answer(answer: str) -> str: 162 | answer_lower = answer.lower().strip().rstrip('.') 163 | if answer_lower in ["yes", "y", "true"]: 164 | return "Yes" 165 | if answer_lower in ["no", "n", "false"]: 166 | return "No" 167 | return answer.strip() 168 | 169 | 170 | def _parse_json_response(response: str) -> Optional[Tuple[str, str]]: 171 | try: 172 | response_clean = response.strip() 173 | json_candidates: List[str] = [] 174 | json_start = response_clean.find('{') 175 | json_end = response_clean.rfind('}') + 1 176 | if json_start != -1 and json_end > json_start: 177 | json_candidates.append(response_clean[json_start:json_end]) 178 | json_candidates.extend(JSON_PATTERN.findall(response_clean)) 179 | for json_str in json_candidates: 180 | try: 181 | data = json.loads(json_str) 182 | answer = data.get("answer", "").strip() 183 | explanation = data.get("explanation", "").strip() 184 | if answer: 185 | return _normalize_answer(answer), explanation 186 | except json.JSONDecodeError: 187 | continue 188 | except Exception: 189 | pass 190 | return None 191 | 192 | 193 | def extract_yes_no_answer(response: str) -> str: 194 | response_lower = response.lower().strip() 195 | if response_lower.startswith("yes"): 196 | return "Yes" 197 | if response_lower.startswith("no"): 198 | return "No" 199 | if re.search(r"\byes\b", response_lower): 200 | return "Yes" 201 | if re.search(r"\bno\b", response_lower): 202 | return "No" 203 | return response[:10] if response else "Unknown" 204 | 205 | 206 | def parse_structured_response(response: str) -> Tuple[str, str]: 207 | parsed = _parse_json_response(response) 208 | if parsed: 209 | return parsed 210 | return extract_yes_no_answer(response), "" 211 | 212 | 213 | def encode_image(image: Image.Image) -> str: 214 | buffer = io.BytesIO() 215 | if image.mode != "RGB": 216 | image = image.convert("RGB") 217 | image.save(buffer, format="JPEG", quality=90) 218 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 219 | 220 | 221 | def load_image_base64(image_path: str, cache: Dict[str, str]) -> str: 222 | if image_path in cache: 223 | return cache[image_path] 224 | try: 225 | image = Image.open(image_path).convert("RGB") 226 | image = resize_image(image) 227 | except Exception as exc: 228 | print(f"Error loading image {image_path}: {exc}") 229 | image = Image.new("RGB", (512, 512), "white") 230 | encoded = encode_image(image) 231 | 
cache[image_path] = encoded 232 | return encoded 233 | 234 | 235 | def create_structured_prompt(question: str) -> str: 236 | """Create a prompt with structured JSON output instructions""" 237 | json_instruction = ( 238 | "\n\nPlease provide a structured answer in the following JSON format:\n" 239 | '{"answer": "Yes" or "No", "explanation": "Brief explanation of your reasoning"}\n\n' 240 | "Output ONLY valid JSON. No extra text." 241 | ) 242 | return question + json_instruction 243 | 244 | 245 | def call_gpt_with_retries(client: OpenAI, prompt: str, image_data_url: str, args) -> str: 246 | """Call GPT API with retry logic using OpenAI SDK (responses.create endpoint)""" 247 | for attempt in range(args.max_attempts): 248 | try: 249 | # Build input payload with text and image 250 | input_payload = [{ 251 | "role": "user", 252 | "content": [ 253 | {"type": "input_text", "text": prompt}, 254 | {"type": "input_image", "image_url": image_data_url}, 255 | ], 256 | }] 257 | 258 | # Call OpenAI Responses API (for GPT-5/o1 models) 259 | response = client.responses.create( 260 | model=args.gpt_model, 261 | input=input_payload, 262 | ) 263 | 264 | return response.output_text.strip() 265 | 266 | except Exception as exc: 267 | wait_time = (args.retry_backoff ** attempt) + random.uniform(0, 1) 268 | print(f"GPT call failed ({attempt + 1}/{args.max_attempts}): {exc}; retry in {wait_time:.1f}s") 269 | time.sleep(wait_time) 270 | 271 | return "Error: all attempts failed" 272 | 273 | 274 | def get_qa_entry(item: Dict[str, Any], qa_field: str, qa_type: str, qa_index: int) -> Dict[str, Any]: 275 | container = item.get(qa_field) 276 | if container is None and qa_field == "qa_pairs": 277 | container = item.get("annotated_qa_pairs") 278 | if isinstance(container, dict): 279 | return container[qa_type][qa_index] 280 | if isinstance(container, list): 281 | return container[qa_index] 282 | raise KeyError(f"QA entry not found for field={qa_field}, type={qa_type}, index={qa_index}") 283 | 284 | 285 | def extract_qa_tasks_standard(items: List[Dict[str, Any]], image_base_dir: str) -> List[QATask]: 286 | qa_tasks: List[QATask] = [] 287 | for item_idx, item in enumerate(items): 288 | image_path = resolve_image_path(item, image_base_dir) 289 | if not image_path: 290 | continue 291 | qa_container = item.get("qa_pairs") 292 | if qa_container is None: 293 | qa_container = item.get("annotated_qa_pairs", {}) 294 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "qa"): 295 | question = qa.get("question") 296 | answer = qa.get("answer") 297 | if question and answer: 298 | qa_tasks.append(QATask(item_idx, "qa_pairs", qa_type, qa_idx, image_path, question, answer, "original")) 299 | return qa_tasks 300 | 301 | 302 | def extract_qa_tasks_annotated(items: List[Dict[str, Any]], image_base_dir: str, viz_mode: str, args) -> List[QATask]: 303 | qa_tasks: List[QATask] = [] 304 | viz_dir = os.path.join(image_base_dir, f"visualization_annotated_qa_{viz_mode}") 305 | for item_idx, item in enumerate(items): 306 | image_path = resolve_image_path(item, image_base_dir) 307 | if not image_path: 308 | continue 309 | qa_container = item.get("annotated_qa_pairs", {}) 310 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "annotated_qa"): 311 | question = qa.get("question") 312 | answer = qa.get("answer") 313 | if not (question and answer): 314 | continue 315 | gpt_question, viz_rel_path = create_visualization_and_question(image_path, qa, item_idx, qa_type, qa_idx, viz_mode, viz_dir, args) 316 | viz_path = os.path.join(image_base_dir, 
viz_rel_path) if viz_rel_path else image_path 317 | source = "visualization" if viz_rel_path else "original" 318 | qa_tasks.append(QATask(item_idx, "annotated_qa_pairs", qa_type, qa_idx, viz_path, gpt_question, answer, source)) 319 | return qa_tasks 320 | 321 | 322 | def process_single_task(task: QATask, items: List[Dict[str, Any]], base64_cache: Dict[str, str], 323 | image_base_dir: str, client: OpenAI, args) -> Dict[str, Any]: 324 | """Worker function to process a single QA task in thread pool""" 325 | try: 326 | # Load and encode image 327 | base64_str = load_image_base64(task.image_path, base64_cache) 328 | data_url = f"data:image/jpeg;base64,{base64_str}" 329 | 330 | # Create prompt and call GPT 331 | prompt = create_structured_prompt(task.question) 332 | model_response = call_gpt_with_retries(client, prompt, data_url, args) 333 | 334 | # Parse response 335 | model_answer, model_explanation = parse_structured_response(model_response) 336 | 337 | # Check correctness 338 | gt_clean = task.answer.lower().strip().rstrip('.') 339 | model_clean = model_answer.lower().strip().rstrip('.') 340 | is_correct = gt_clean == model_clean 341 | 342 | # Return result 343 | result = { 344 | "item_index": task.item_index, 345 | "qa_field": task.qa_field, 346 | "qa_type": task.qa_type, 347 | "qa_index": task.qa_index, 348 | "model_answer": model_answer, 349 | "model_response": model_response, 350 | "model_explanation": model_explanation, 351 | "is_correct": is_correct, 352 | } 353 | 354 | # Add visualization info if applicable 355 | if task.qa_field == "annotated_qa_pairs" and task.source == "visualization": 356 | viz_rel_path = os.path.relpath(task.image_path, image_base_dir) 357 | result["visualization_path"] = viz_rel_path 358 | result["viz_mode"] = args.viz_mode 359 | 360 | return result 361 | 362 | except Exception as e: 363 | return { 364 | "item_index": task.item_index, 365 | "qa_field": task.qa_field, 366 | "qa_type": task.qa_type, 367 | "qa_index": task.qa_index, 368 | "error": str(e), 369 | "model_answer": "Error", 370 | "is_correct": False, 371 | } 372 | 373 | 374 | def evaluate_with_gpt(items: List[Dict[str, Any]], image_base_dir: str, args) -> List[Dict[str, Any]]: 375 | """Main evaluation function using multi-threading""" 376 | # Limit number of items if specified 377 | if args.max_num is not None: 378 | items = items[:args.max_num] 379 | 380 | # Extract QA tasks based on field type 381 | if args.qa_field == "qa_pairs": 382 | qa_tasks = extract_qa_tasks_standard(items, image_base_dir) 383 | elif args.qa_field == "annotated_qa_pairs": 384 | qa_tasks = extract_qa_tasks_annotated(items, image_base_dir, args.viz_mode, args) 385 | else: 386 | raise ValueError(f"Unsupported qa_field: {args.qa_field}") 387 | 388 | if not qa_tasks: 389 | print("No QA tasks found!") 390 | return items 391 | 392 | print(f"Found {len(qa_tasks)} QA tasks") 393 | 394 | # Setup API key 395 | api_key = args.api_key or os.getenv("OPENAI_API_KEY") 396 | if not api_key: 397 | raise RuntimeError("OpenAI API key is required. 
Set --api_key or OPENAI_API_KEY.") 398 | 399 | # Initialize OpenAI client 400 | client = OpenAI(api_key=api_key, base_url=args.api_base_url) 401 | 402 | # Shared base64 cache (thread-safe for reading) 403 | base64_cache: Dict[str, str] = {} 404 | 405 | # Execute tasks in parallel using thread pool 406 | print(f"Processing with {args.num_workers} workers...") 407 | results = [] 408 | 409 | with ThreadPoolExecutor(max_workers=args.num_workers) as executor: 410 | # Submit all tasks 411 | future_to_task = { 412 | executor.submit(process_single_task, task, items, base64_cache, image_base_dir, client, args): task 413 | for task in qa_tasks 414 | } 415 | 416 | # Collect results with progress bar 417 | for future in tqdm(as_completed(future_to_task), total=len(qa_tasks), desc="Calling GPT model"): 418 | try: 419 | result = future.result() 420 | results.append(result) 421 | except Exception as e: 422 | task = future_to_task[future] 423 | print(f"\nError processing task: {e}") 424 | results.append({ 425 | "item_index": task.item_index, 426 | "qa_field": task.qa_field, 427 | "qa_type": task.qa_type, 428 | "qa_index": task.qa_index, 429 | "error": str(e), 430 | "model_answer": "Error", 431 | "is_correct": False, 432 | }) 433 | 434 | # Merge results back into items 435 | for result in results: 436 | try: 437 | qa_entry = get_qa_entry(items[result["item_index"]], result["qa_field"], 438 | result["qa_type"], result["qa_index"]) 439 | qa_entry["model_answer"] = result.get("model_answer", "Error") 440 | qa_entry["model_response"] = result.get("model_response", "") 441 | qa_entry["model_explanation"] = result.get("model_explanation", "") 442 | qa_entry["is_correct"] = result.get("is_correct", False) 443 | 444 | if "visualization_path" in result: 445 | qa_entry["visualization_path"] = result["visualization_path"] 446 | if "viz_mode" in result: 447 | qa_entry["viz_mode"] = result["viz_mode"] 448 | if "error" in result: 449 | qa_entry["error"] = result["error"] 450 | except Exception as e: 451 | print(f"Error merging result: {e}") 452 | 453 | return items 454 | 455 | 456 | def calculate_accuracy_by_dimension(items: List[Dict[str, Any]]) -> Dict[str, Any]: 457 | total_questions = 0 458 | sample_acc_sum = 0.0 459 | sample_count = 0 460 | category_stats: Dict[str, Dict[str, float]] = {} 461 | law_stats: Dict[str, Dict[str, float]] = {} 462 | operation_stats: Dict[str, Dict[str, float]] = {} 463 | 464 | def collect_qas(item: Dict[str, Any]) -> List[Dict[str, Any]]: 465 | qa_sources: List[Dict[str, Any]] = [] 466 | for field, default_type in (("qa_pairs", "qa"), ("annotated_qa_pairs", "annotated_qa")): 467 | container = item.get(field) 468 | if not container: 469 | continue 470 | for _, _, qa in iter_qa_entries(container, default_type): 471 | qa_sources.append(qa) 472 | return qa_sources 473 | 474 | def update_stat(stats: Dict[str, Dict[str, float]], key: str, sample_acc: float, qa_total: int) -> None: 475 | if key not in stats: 476 | stats[key] = {"sum_acc": 0.0, "sample_count": 0, "qa_total": 0} 477 | stats[key]["sum_acc"] += sample_acc 478 | stats[key]["sample_count"] += 1 479 | stats[key]["qa_total"] += qa_total 480 | 481 | for item in items: 482 | category = item.get("physics_category", "unknown") 483 | law = item.get("physics_law", "unknown") 484 | operation = item.get("edit_operation", "unknown") 485 | qa_sources = collect_qas(item) 486 | sample_total = 0 487 | sample_correct = 0 488 | for qa in qa_sources: 489 | if "is_correct" in qa: 490 | sample_total += 1 491 | if qa["is_correct"]: 492 | sample_correct 
+= 1 493 | if sample_total == 0: 494 | continue 495 | sample_acc = sample_correct / sample_total 496 | sample_count += 1 497 | sample_acc_sum += sample_acc 498 | total_questions += sample_total 499 | update_stat(category_stats, category, sample_acc, sample_total) 500 | update_stat(law_stats, law, sample_acc, sample_total) 501 | update_stat(operation_stats, operation, sample_acc, sample_total) 502 | 503 | def calc_accuracy(stats: Dict[str, Dict[str, float]]) -> Dict[str, Any]: 504 | result: Dict[str, Any] = {} 505 | for key, value in stats.items(): 506 | acc = 100.0 * value["sum_acc"] / value["sample_count"] if value["sample_count"] > 0 else 0.0 507 | result[key] = { 508 | "accuracy": acc, 509 | "sample_count": value["sample_count"], 510 | "qa_total": value["qa_total"], 511 | } 512 | return result 513 | 514 | overall_accuracy = (100.0 * sample_acc_sum / sample_count) if sample_count > 0 else 0.0 515 | return { 516 | "overall_accuracy": overall_accuracy, 517 | "sample_count": sample_count, 518 | "qa_total": total_questions, 519 | "by_category": calc_accuracy(category_stats), 520 | "by_law": calc_accuracy(law_stats), 521 | "by_operation": calc_accuracy(operation_stats), 522 | } 523 | 524 | 525 | def main() -> None: 526 | """Main entry point for GPT evaluation""" 527 | parser = argparse.ArgumentParser(description="GPT-based PhysEdit evaluation with multi-threading") 528 | parser.add_argument("--input_json_path", type=str, required=True, help="Path to meta_info.json") 529 | parser.add_argument("--image_base_dir", type=str, default=None, help="Image base directory; defaults to JSON directory") 530 | parser.add_argument("--qa_field", type=str, default="annotated_qa_pairs", 531 | choices=["qa_pairs", "annotated_qa_pairs"], help="Select qa field") 532 | parser.add_argument("--viz_mode", type=str, default="crop_box_and_resize", 533 | choices=["draw_box", "crop_box", "crop_box_and_resize"], help="Visualization mode") 534 | parser.add_argument("--max_num", type=int, default=None, help="Maximum samples to process") 535 | parser.add_argument("--viz_padding", type=int, default=20, help="Padding pixels for crop mode") 536 | parser.add_argument("--box_color", type=str, default="red", help="Bounding box color") 537 | parser.add_argument("--log_question_changes", action="store_true", help="Log question mutations") 538 | parser.add_argument("--img_size", type=int, default=1024, help="Used for output naming consistency") 539 | parser.add_argument("--api_key", type=str, default="", help="OpenAI API key; overrides OPENAI_API_KEY env") 540 | parser.add_argument("--api_base_url", type=str, default="https://api.openai.com/v1", 541 | help="OpenAI API base URL") 542 | parser.add_argument("--gpt_model", type=str, default="gpt-5", help="OpenAI multimodal model name") 543 | parser.add_argument("--max_attempts", type=int, default=5, help="Max call retries") 544 | parser.add_argument("--retry_backoff", type=float, default=2.0, help="Exponential backoff base") 545 | parser.add_argument("--num_workers", type=int, default=50, help="Number of parallel worker threads") 546 | args = parser.parse_args() 547 | 548 | # Set default image_base_dir 549 | if args.image_base_dir is None: 550 | args.image_base_dir = os.path.dirname(args.input_json_path) 551 | 552 | # Load data 553 | print(f"Loading data: {args.input_json_path}") 554 | with open(args.input_json_path, "r", encoding="utf-8") as f: 555 | data = json.load(f) 556 | 557 | print(f"QA field: {args.qa_field}") 558 | if args.qa_field == "annotated_qa_pairs": 559 | 
print(f"Visualization mode: {args.viz_mode}") 560 | print(f"Number of workers: {args.num_workers}") 561 | 562 | # Run evaluation 563 | print("Running GPT evaluation...") 564 | results = evaluate_with_gpt(data, args.image_base_dir, args) 565 | 566 | out_dir = os.path.dirname(args.input_json_path) 567 | base_name = os.path.splitext(os.path.basename(args.input_json_path))[0] 568 | 569 | suffix = f"_gpt_output_{args.img_size}" 570 | if args.qa_field == "annotated_qa_pairs": 571 | suffix += f"_{args.viz_mode}" 572 | out_path = os.path.join(out_dir, base_name + suffix + ".json") 573 | with open(out_path, "w", encoding="utf-8") as f: 574 | json.dump(results, f, ensure_ascii=False, indent=2) 575 | 576 | print("Computing statistics...") 577 | stats = calculate_accuracy_by_dimension(results) 578 | analysis_suffix = suffix.replace("_output_", "_analysis_") 579 | analysis_path = os.path.join(out_dir, base_name + analysis_suffix + ".json") 580 | with open(analysis_path, "w", encoding="utf-8") as f: 581 | json.dump(stats, f, ensure_ascii=False, indent=2) 582 | 583 | print("\n=== GPT evaluation finished ===") 584 | print(f"Overall accuracy: {stats['overall_accuracy']:.2f}% (samples: {stats['sample_count']}, qa_total: {stats['qa_total']})") 585 | print("\nBy category:") 586 | for category, stat in stats["by_category"].items(): 587 | print(f" {category}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 588 | print("\nBy law:") 589 | for law, stat in stats["by_law"].items(): 590 | print(f" {law}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 591 | print("\nBy operation:") 592 | for operation, stat in stats["by_operation"].items(): 593 | print(f" {operation}: {stat['accuracy']:.2f}% (samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 594 | 595 | print("\nOutputs saved:") 596 | print(f" Detailed results: {out_path}") 597 | print(f" Analysis: {analysis_path}") 598 | 599 | 600 | if __name__ == "__main__": 601 | random.seed(42) 602 | main() 603 | -------------------------------------------------------------------------------- /PicaEval_qwen.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ============================================================================ 3 | # :: PhysEdit Evaluation Script (vLLM backend) 4 | # :: Supports qa_pairs / annotated_qa_pairs flows 5 | # :: Includes visualization helpers (draw_box / crop_box modes) 6 | # ============================================================================ 7 | 8 | import multiprocessing as mp 9 | if mp.get_start_method(allow_none=True) != "spawn": 10 | mp.set_start_method("spawn", force=True) 11 | 12 | import os 13 | import sys 14 | import json 15 | import argparse 16 | import random 17 | import re 18 | from dataclasses import dataclass 19 | from pathlib import Path 20 | from typing import List, Dict, Any, Tuple, Optional, Iterable 21 | from PIL import Image, ImageDraw 22 | import PIL 23 | from tqdm import tqdm 24 | 25 | from vllm import LLM, SamplingParams 26 | from transformers import AutoProcessor 27 | 28 | IMAGE_PATH_KEYS: Tuple[str, ...] 
= ("output_path", "output_image_path", "output_img_path") 29 | JSON_PATTERN = re.compile(r'\{[^{}]*"answer"[^{}]*"explanation"[^{}]*\}', re.IGNORECASE | re.DOTALL) 30 | 31 | 32 | @dataclass 33 | class QATask: 34 | item_index: int 35 | qa_field: str 36 | qa_type: str 37 | qa_index: int 38 | image_path: str 39 | question: str 40 | answer: str 41 | source: str 42 | 43 | def resolve_image_path(item: Dict[str, Any], base_dir: str) -> Optional[str]: 44 | """Resolve image path by inspecting common metadata keys.""" 45 | base_path = Path(base_dir).expanduser().resolve() 46 | base_name = base_path.name 47 | for key in IMAGE_PATH_KEYS: 48 | raw_path = item.get(key) 49 | if not raw_path: 50 | continue 51 | 52 | candidate_path = Path(raw_path) 53 | if candidate_path.is_absolute(): 54 | return str(candidate_path) 55 | 56 | normalized_rel = Path(*candidate_path.parts) 57 | candidate = base_path / normalized_rel 58 | if candidate.exists(): 59 | return str(candidate) 60 | 61 | rel_parts = normalized_rel.parts 62 | if rel_parts and rel_parts[0] == base_name: 63 | alt = base_path / Path(*rel_parts[1:]) 64 | if alt.exists(): 65 | return str(alt) 66 | return str(alt) 67 | 68 | return str(candidate) 69 | return None 70 | 71 | def iter_qa_entries(container: Any, default_type: str) -> Iterable[Tuple[str, int, Dict[str, Any]]]: 72 | """Iterate QA entries uniformly for dict or list containers.""" 73 | if isinstance(container, dict): 74 | for qa_type, qa_list in container.items(): 75 | for idx, qa in enumerate(qa_list): 76 | yield qa_type, idx, qa 77 | elif isinstance(container, list): 78 | for idx, qa in enumerate(container): 79 | yield default_type, idx, qa 80 | 81 | def draw_box_on_image(image: Image.Image, box_info: Dict, box_color="red") -> Image.Image: 82 | """Draw a bounding box on the resized image.""" 83 | img_copy = image.copy() 84 | draw = ImageDraw.Draw(img_copy) 85 | 86 | # :: Derive bounding box coordinates 87 | x = box_info.get("x", 0) 88 | y = box_info.get("y", 0) 89 | width = box_info.get("width", 0) 90 | height = box_info.get("height", 0) 91 | 92 | # :: Render rectangle outline 93 | bbox = [x, y, x + width, y + height] 94 | draw.rectangle(bbox, outline=box_color, width=3) 95 | 96 | return img_copy 97 | 98 | def resize_image(image: Image.Image) -> Image.Image: 99 | """Resize image proportionally so the long edge becomes 1024 pixels.""" 100 | width, height = image.size 101 | long_edge = max(width, height) 102 | 103 | if long_edge == 1024: 104 | return image 105 | 106 | scale = 1024.0 / long_edge 107 | new_width = int(width * scale) 108 | new_height = int(height * scale) 109 | 110 | return image.resize((new_width, new_height), Image.LANCZOS) 111 | 112 | def crop_image_with_box(image: Image.Image, box_info: Dict, padding=20) -> Image.Image: 113 | """Crop image around the bounding box with optional padding.""" 114 | x = box_info.get("x", 0) 115 | y = box_info.get("y", 0) 116 | width = box_info.get("width", 0) 117 | height = box_info.get("height", 0) 118 | 119 | # :: Apply padding while respecting image bounds 120 | x1 = max(0, x - padding) 121 | y1 = max(0, y - padding) 122 | x2 = min(image.size[0], x + width + padding) 123 | y2 = min(image.size[1], y + height + padding) 124 | cropped_image = image.crop((x1, y1, x2, y2)) 125 | return cropped_image 126 | 127 | def generate_viz_filename(item_index: int, qa_type: str, question_index: int, viz_mode: str) -> str: 128 | """Create a deterministic visualization filename.""" 129 | qa_type_short = qa_type.replace(" ", "_").replace("QA", "qa") 130 | return 
f"{item_index:04d}_{qa_type_short}_{question_index:03d}_{viz_mode}.jpg" 131 | 132 | def create_visualization_and_question(output_img_path: str, qa_info: Dict, item_index: int, 133 | qa_type: str, qa_index: int, viz_mode: str, 134 | viz_dir: str, args) -> Tuple[str, str]: 135 | """Create visualization asset and return the text prompt for the model.""" 136 | question = qa_info.get("question", "") 137 | box_info = qa_info.get("box", {}) 138 | 139 | # :: Construct file paths for visualization assets 140 | viz_filename = generate_viz_filename(item_index, qa_type, qa_index, viz_mode) 141 | viz_path = os.path.join(viz_dir, viz_filename) 142 | # :: Relative path must align with viz_dir for downstream loading 143 | viz_rel_path = os.path.join(f"visualization_annotated_qa_{viz_mode}", viz_filename) 144 | 145 | # :: Ensure target directory exists 146 | os.makedirs(viz_dir, exist_ok=True) 147 | 148 | try: 149 | # :: Load and resize image to match annotation scale (long edge = 1024) 150 | image = Image.open(output_img_path).convert("RGB") 151 | image_resized = resize_image(image) 152 | 153 | if viz_mode == "draw_box": 154 | # :: Draw-box mode keeps question intact and overlays box 155 | vllm_question = question 156 | viz_image = draw_box_on_image(image_resized, box_info, args.box_color) 157 | elif viz_mode == "crop_box": 158 | # :: Crop mode preserves question and crops the image 159 | vllm_question = question 160 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding) 161 | elif viz_mode == "crop_box_and_resize": 162 | # :: Crop+resize mode crops the region then upscales 163 | vllm_question = question 164 | viz_image = crop_image_with_box(image_resized, box_info, args.viz_padding) 165 | viz_image = resize_image(viz_image) 166 | else: 167 | raise ValueError(f"Unknown viz_mode: {viz_mode}") 168 | 169 | # :: Persist visualization image 170 | viz_image.save(viz_path, quality=90) 171 | 172 | # :: Optionally log prompt mutations 173 | if args.log_question_changes and question != vllm_question: 174 | print(f"Question modified ({viz_mode}): {question} -> {vllm_question}") 175 | 176 | return vllm_question, viz_rel_path 177 | 178 | except Exception as e: 179 | print(f"Error creating visualization for {viz_path}: {e}") 180 | return question, "" 181 | 182 | def parse_structured_response(response: str) -> Tuple[str, str]: 183 | """Parse structured JSON answer from model output.""" 184 | return _parse_json_response(response) or _fallback_parse_response(response) 185 | 186 | def _parse_json_response(response: str) -> Tuple[str, str] | None: 187 | """Attempt to decode JSON-formatted response.""" 188 | try: 189 | response_clean = response.strip() 190 | 191 | # :: Collect possible JSON snippets 192 | json_candidates = [] 193 | 194 | # :: Strategy 1: capture outermost braces 195 | json_start = response_clean.find('{') 196 | json_end = response_clean.rfind('}') + 1 197 | if json_start != -1 and json_end > json_start: 198 | json_candidates.append(response_clean[json_start:json_end]) 199 | 200 | # :: Strategy 2: apply regex extraction 201 | json_candidates.extend(JSON_PATTERN.findall(response_clean)) 202 | 203 | # :: Try parsing each candidate 204 | for json_str in json_candidates: 205 | try: 206 | data = json.loads(json_str) 207 | answer = data.get("answer", "").strip() 208 | explanation = data.get("explanation", "").strip() 209 | 210 | if answer: 211 | return _normalize_answer(answer), explanation 212 | except json.JSONDecodeError: 213 | continue 214 | 215 | except Exception: 216 | pass 217 | 218 | 
return None 219 | 220 | def _normalize_answer(answer: str) -> str: 221 | """Normalize canonical yes/no answer casing.""" 222 | answer_lower = answer.lower().strip().rstrip('.') 223 | if answer_lower in ["yes", "y", "true"]: 224 | return "Yes" 225 | elif answer_lower in ["no", "n", "false"]: 226 | return "No" 227 | return answer.strip() 228 | 229 | def _fallback_parse_response(response: str) -> Tuple[str, str]: 230 | """Fallback parser when structured JSON is unavailable.""" 231 | return extract_yes_no_answer(response), "" 232 | 233 | def extract_yes_no_answer(response: str) -> str: 234 | """Extract yes/no answer heuristically.""" 235 | response_lower = response.lower().strip() 236 | 237 | # :: Prefer prefix match 238 | if response_lower.startswith("yes"): 239 | return "Yes" 240 | elif response_lower.startswith("no"): 241 | return "No" 242 | 243 | # :: Search for whole-word matches 244 | if re.search(r'\byes\b', response_lower): 245 | return "Yes" 246 | elif re.search(r'\bno\b', response_lower): 247 | return "No" 248 | 249 | # :: Leave snippet when answer is unclear 250 | return response[:10] if response else "Unknown" 251 | 252 | def init_llm(model_path: str, tp: int, dtype: str, gpu_util: float, max_len: int): 253 | """Initialize vLLM engine.""" 254 | llm = LLM( 255 | model=model_path, 256 | tensor_parallel_size=tp, 257 | dtype=dtype, 258 | gpu_memory_utilization=gpu_util, 259 | trust_remote_code=True, 260 | max_model_len=max_len, 261 | ) 262 | return llm 263 | 264 | def create_structured_message(msgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 265 | """Append structured JSON instructions to the final user turn.""" 266 | structured_msg = msgs.copy() 267 | if not structured_msg or structured_msg[-1]["role"] != "user": 268 | return structured_msg 269 | 270 | # :: Structured answer instruction 271 | json_instruction = "\n\nPlease provide a structured answer in the following JSON format:\n{\"answer\": \"Yes\" or \"No\", \"explanation\": \"Brief explanation of your reasoning\"}\n\nOutput ONLY valid JSON. No extra text." 
272 | 273 | original_content = structured_msg[-1]["content"] 274 | if isinstance(original_content, list): 275 | # :: Multi-modal payload 276 | structured_content = original_content + [{"type": "text", "text": json_instruction}] 277 | else: 278 | # :: Text-only payload 279 | structured_content = original_content + json_instruction 280 | 281 | structured_msg[-1]["content"] = structured_content 282 | return structured_msg 283 | 284 | def prepare_vllm_batch(qa_tasks: List[QATask]) -> Tuple[List[List[Dict[str, Any]]], List[List[Image.Image]]]: 285 | """Prepare messages and images for vLLM batching.""" 286 | msgs_batch: List[List[Dict[str, Any]]] = [] 287 | imgs_batch: List[List[Image.Image]] = [] 288 | pil_cache: Dict[str, Image.Image] = {} 289 | 290 | for task in tqdm(qa_tasks, desc="Preparing images"): 291 | if task.image_path not in pil_cache: 292 | try: 293 | img = Image.open(task.image_path).convert("RGB") 294 | img = resize_image(img) 295 | pil_cache[task.image_path] = img 296 | except Exception as e: 297 | print(f"Error loading image {task.image_path}: {e}") 298 | pil_cache[task.image_path] = Image.new("RGB", (512, 512), "white") 299 | img = pil_cache[task.image_path] 300 | msgs_batch.append([{"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": task.question}]}]) 301 | imgs_batch.append([img]) 302 | 303 | return msgs_batch, imgs_batch 304 | 305 | def build_vllm_inputs(processor: AutoProcessor, 306 | batch_msgs: List[List[Dict[str, Any]]], 307 | batch_images: List[List[Image.Image]]) -> List[Dict[str, Any]]: 308 | """Build vLLM input payloads.""" 309 | # :: Attach structured-answer instructions 310 | structured_msgs = [create_structured_message(msgs) for msgs in batch_msgs] 311 | 312 | # :: Render prompts via chat template 313 | prompts = [ 314 | processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) 315 | for msgs in structured_msgs 316 | ] 317 | 318 | # :: Assemble vLLM-ready dicts 319 | return [ 320 | {"prompt": text, "multi_modal_data": {"image": imgs}} 321 | for text, imgs in zip(prompts, batch_images) 322 | ] 323 | 324 | def vllm_generate_batch(llm: LLM, processor: AutoProcessor, 325 | batch_msgs: List[List[Dict[str, Any]]], 326 | batch_images: List[List[Image.Image]], 327 | max_new_tokens: int) -> List[str]: 328 | """Run batched inference through vLLM.""" 329 | # :: Prepare inputs 330 | inputs = build_vllm_inputs(processor, batch_msgs, batch_images) 331 | 332 | # :: Configure sampling strategy 333 | sp = SamplingParams( 334 | temperature=0.0, 335 | top_p=1.0, 336 | max_tokens=max_new_tokens, 337 | stop=["\uFFFD\uFFFD", "\n\uFFFD\uFFFD", "\U0001F4D0\n"] 338 | ) 339 | 340 | # :: Execute generation 341 | outputs = llm.generate(inputs, sp) 342 | return [o.outputs[0].text if o.outputs else "" for o in outputs] 343 | 344 | def get_qa_entry(item: Dict[str, Any], qa_field: str, qa_type: str, qa_index: int) -> Dict[str, Any]: 345 | """Locate QA entry for annotation updates.""" 346 | container = item.get(qa_field) 347 | if container is None and qa_field == "qa_pairs": 348 | container = item.get("annotated_qa_pairs") 349 | if isinstance(container, dict): 350 | return container[qa_type][qa_index] 351 | if isinstance(container, list): 352 | return container[qa_index] 353 | raise KeyError(f"QA entry not found for field={qa_field}, type={qa_type}, index={qa_index}") 354 | 355 | 356 | def extract_qa_tasks_standard(items: List[Dict[str, Any]], image_base_dir: str) -> List[QATask]: 357 | """Collect tasks from qa_pairs field.""" 358 | 
qa_tasks: List[QATask] = [] 359 | for item_idx, item in enumerate(items): 360 | image_path = resolve_image_path(item, image_base_dir) 361 | if not image_path: 362 | continue 363 | qa_container = item.get("qa_pairs") 364 | if qa_container is None: 365 | qa_container = item.get("annotated_qa_pairs", {}) 366 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "qa"): 367 | question = qa.get("question") 368 | answer = qa.get("answer") 369 | if question and answer: 370 | qa_tasks.append(QATask( 371 | item_index=item_idx, 372 | qa_field="qa_pairs", 373 | qa_type=qa_type, 374 | qa_index=qa_idx, 375 | image_path=image_path, 376 | question=question, 377 | answer=answer, 378 | source="original" 379 | )) 380 | return qa_tasks 381 | 382 | def extract_qa_tasks_annotated(items: List[Dict[str, Any]], image_base_dir: str, viz_mode: str, args) -> List[QATask]: 383 | """Collect annotated QA tasks and generate visualizations.""" 384 | qa_tasks: List[QATask] = [] 385 | viz_dir = os.path.join(image_base_dir, f"visualization_annotated_qa_{viz_mode}") 386 | for item_idx, item in enumerate(items): 387 | image_path = resolve_image_path(item, image_base_dir) 388 | if not image_path: 389 | continue 390 | qa_container = item.get("annotated_qa_pairs", {}) 391 | for qa_type, qa_idx, qa in iter_qa_entries(qa_container, "annotated_qa"): 392 | question = qa.get("question") 393 | answer = qa.get("answer") 394 | if not (question and answer): 395 | continue 396 | vllm_question, viz_rel_path = create_visualization_and_question( 397 | image_path, qa, item_idx, qa_type, qa_idx, viz_mode, viz_dir, args 398 | ) 399 | viz_path = os.path.join(image_base_dir, viz_rel_path) if viz_rel_path else image_path 400 | source = "visualization" if viz_rel_path else "original" 401 | qa_tasks.append(QATask( 402 | item_index=item_idx, 403 | qa_field="annotated_qa_pairs", 404 | qa_type=qa_type, 405 | qa_index=qa_idx, 406 | image_path=viz_path, 407 | question=vllm_question, 408 | answer=answer, 409 | source=source 410 | )) 411 | return qa_tasks 412 | 413 | def evaluate_physedit_with_vllm(items: List[Dict[str, Any]], image_base_dir: str, 414 | processor: AutoProcessor, llm: LLM, args) -> List[Dict[str, Any]]: 415 | """Main evaluation routine.""" 416 | if args.max_num is not None: 417 | items = items[:args.max_num] 418 | 419 | if args.qa_field == "qa_pairs": 420 | qa_tasks = extract_qa_tasks_standard(items, image_base_dir) 421 | elif args.qa_field == "annotated_qa_pairs": 422 | qa_tasks = extract_qa_tasks_annotated(items, image_base_dir, args.viz_mode, args) 423 | else: 424 | raise ValueError(f"Unsupported qa_field: {args.qa_field}") 425 | 426 | if not qa_tasks: 427 | print("No QA tasks found!") 428 | return items 429 | 430 | print(f"Found {len(qa_tasks)} QA tasks") 431 | msgs_batch, imgs_batch = prepare_vllm_batch(qa_tasks) 432 | print("Running VLM inference...") 433 | answers = vllm_generate_batch(llm, processor, msgs_batch, imgs_batch, args.max_new_tokens) 434 | 435 | for task, model_response in zip(qa_tasks, answers): 436 | model_answer, model_explanation = parse_structured_response(model_response) 437 | gt_clean = task.answer.lower().strip().rstrip('.') 438 | model_clean = model_answer.lower().strip().rstrip('.') 439 | is_correct = gt_clean == model_clean 440 | qa_entry = get_qa_entry(items[task.item_index], task.qa_field, task.qa_type, task.qa_index) 441 | qa_entry["model_answer"] = model_answer 442 | qa_entry["model_response"] = model_response 443 | qa_entry["model_explanation"] = model_explanation 444 | qa_entry["is_correct"] = 
is_correct 445 | if task.qa_field == "annotated_qa_pairs" and task.source == "visualization": 446 | viz_rel_path = os.path.relpath(task.image_path, args.image_base_dir) 447 | qa_entry["visualization_path"] = viz_rel_path 448 | qa_entry["viz_mode"] = args.viz_mode 449 | 450 | return items 451 | 452 | def calculate_accuracy_by_dimension(items: List[Dict[str, Any]]) -> Dict[str, Any]: 453 | """Compute accuracy per dimension using sample-wise averaging.""" 454 | total_questions = 0 455 | sample_acc_sum = 0.0 456 | sample_count = 0 457 | category_stats = {} 458 | law_stats = {} 459 | operation_stats = {} 460 | 461 | def collect_qas(item: Dict[str, Any]) -> List[Dict[str, Any]]: 462 | qa_sources: List[Dict[str, Any]] = [] 463 | for field, default_type in (("qa_pairs", "qa"), ("annotated_qa_pairs", "annotated_qa")): 464 | container = item.get(field) 465 | if not container: 466 | continue 467 | for _, _, qa in iter_qa_entries(container, default_type): 468 | qa_sources.append(qa) 469 | return qa_sources 470 | 471 | def update_stat(stats: Dict, key: str, sample_acc: float, qa_total: int): 472 | if key not in stats: 473 | stats[key] = {"sum_acc": 0.0, "sample_count": 0, "qa_total": 0} 474 | stats[key]["sum_acc"] += sample_acc 475 | stats[key]["sample_count"] += 1 476 | stats[key]["qa_total"] += qa_total 477 | 478 | for item in items: 479 | category = item.get("physics_category", "unknown") 480 | law = item.get("physics_law", "unknown") 481 | operation = item.get("edit_operation", "unknown") 482 | 483 | qa_sources = collect_qas(item) 484 | sample_total = 0 485 | sample_correct = 0 486 | for qa in qa_sources: 487 | if "is_correct" in qa: 488 | sample_total += 1 489 | if qa["is_correct"]: 490 | sample_correct += 1 491 | if sample_total == 0: 492 | continue 493 | 494 | sample_acc = sample_correct / sample_total 495 | sample_count += 1 496 | sample_acc_sum += sample_acc 497 | total_questions += sample_total 498 | 499 | update_stat(category_stats, category, sample_acc, sample_total) 500 | update_stat(law_stats, law, sample_acc, sample_total) 501 | update_stat(operation_stats, operation, sample_acc, sample_total) 502 | 503 | def calc_accuracy(stats: Dict) -> Dict: 504 | result = {} 505 | for key, value in stats.items(): 506 | if value["sample_count"] > 0: 507 | acc = 100.0 * value["sum_acc"] / value["sample_count"] 508 | else: 509 | acc = 0.0 510 | result[key] = { 511 | "accuracy": acc, 512 | "sample_count": value["sample_count"], 513 | "qa_total": value["qa_total"] 514 | } 515 | return result 516 | 517 | overall_accuracy = (100.0 * sample_acc_sum / sample_count) if sample_count > 0 else 0.0 518 | return { 519 | "overall_accuracy": overall_accuracy, 520 | "sample_count": sample_count, 521 | "qa_total": total_questions, 522 | "by_category": calc_accuracy(category_stats), 523 | "by_law": calc_accuracy(law_stats), 524 | "by_operation": calc_accuracy(operation_stats) 525 | } 526 | 527 | def main(): 528 | parser = argparse.ArgumentParser() 529 | parser.add_argument("--input_json_path", type=str, required=True, 530 | help="Path to meta_info.json file") 531 | parser.add_argument("--image_base_dir", type=str, default=None, 532 | help="Image root directory; defaults to the JSON directory") 533 | parser.add_argument("--model_path", type=str, default="pretrained/Qwen/Qwen2.5-VL-72B-Instruct", 534 | help="Model checkpoint path or Hugging Face identifier") 535 | parser.add_argument("--qa_field", type=str, default="annotated_qa_pairs", 536 | choices=["qa_pairs", "annotated_qa_pairs"], 537 | help="Choose which QA field to 
evaluate") 538 | parser.add_argument("--viz_mode", type=str, default="crop_box_and_resize", 539 | choices=["draw_box", "crop_box", "crop_box_and_resize"], 540 | help="Visualization mode (used only for annotated QA)") 541 | parser.add_argument("--tensor_parallel_size", type=int, default=4, 542 | help="Tensor parallel shard count") 543 | parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16","float16"]) 544 | parser.add_argument("--gpu_mem_util", type=float, default=0.9) 545 | parser.add_argument("--max_model_len", type=int, default=5120) 546 | parser.add_argument("--max_new_tokens", type=int, default=256) 547 | parser.add_argument("--img_size", type=int, default=1024, choices=[512, 1024]) 548 | parser.add_argument("--max_num", type=int, default=None, help="Maximum number of samples to process") 549 | parser.add_argument("--viz_padding", type=int, default=20, help="Padding pixels for crop mode") 550 | parser.add_argument("--box_color", default="red", help="Bounding box color") 551 | parser.add_argument("--log_question_changes", action="store_true", 552 | help="Log question text mutations") 553 | args = parser.parse_args() 554 | 555 | # :: Derive default image_base_dir 556 | if args.image_base_dir is None: 557 | args.image_base_dir = os.path.dirname(args.input_json_path) 558 | 559 | # :: Load dataset 560 | print(f"Loading data: {args.input_json_path}") 561 | with open(args.input_json_path, "r", encoding="utf-8") as f: 562 | data = json.load(f) 563 | 564 | print(f"QA field: {args.qa_field}") 565 | if args.qa_field == "annotated_qa_pairs": 566 | print(f"Visualization mode: {args.viz_mode}") 567 | 568 | # :: Initialize model 569 | print("Initializing model...") 570 | processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True) 571 | llm = init_llm(args.model_path, args.tensor_parallel_size, args.dtype, 572 | args.gpu_mem_util, args.max_model_len) 573 | 574 | # :: Evaluate 575 | print("Starting evaluation...") 576 | results = evaluate_physedit_with_vllm(data, args.image_base_dir, processor, llm, args) 577 | 578 | # :: Persist outputs 579 | out_dir = os.path.dirname(args.input_json_path) 580 | base_name = os.path.splitext(os.path.basename(args.input_json_path))[0] 581 | 582 | # :: Build output filenames 583 | suffix = f"_vllm_output_{args.img_size}" 584 | if args.qa_field == "annotated_qa_pairs": 585 | suffix += f"_{args.viz_mode}" 586 | 587 | out_path = os.path.join(out_dir, base_name + suffix + ".json") 588 | with open(out_path, "w", encoding="utf-8") as f: 589 | json.dump(results, f, ensure_ascii=False, indent=2) 590 | 591 | # :: Compute and save statistics 592 | print("Computing statistics...") 593 | stats = calculate_accuracy_by_dimension(results) 594 | 595 | analysis_suffix = suffix.replace("_output_", "_analysis_") 596 | analysis_path = os.path.join(out_dir, base_name + analysis_suffix + ".json") 597 | with open(analysis_path, "w", encoding="utf-8") as f: 598 | json.dump(stats, f, ensure_ascii=False, indent=2) 599 | 600 | # :: Display summary 601 | print(f"\n=== Evaluation finished ===") 602 | print(f"Overall accuracy: {stats['overall_accuracy']:.2f}% " 603 | f"(samples: {stats['sample_count']}, qa_total: {stats['qa_total']})") 604 | print(f"\nBy category:") 605 | for category, stat in stats["by_category"].items(): 606 | print(f" {category}: {stat['accuracy']:.2f}% " 607 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 608 | print(f"\nBy physics law:") 609 | for law, stat in stats["by_law"].items(): 610 | print(f" {law}: 
{stat['accuracy']:.2f}% " 611 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 612 | print(f"\nBy operation type:") 613 | for operation, stat in stats["by_operation"].items(): 614 | print(f" {operation}: {stat['accuracy']:.2f}% " 615 | f"(samples: {stat['sample_count']}, qa_total: {stat['qa_total']})") 616 | 617 | print(f"\nOutputs saved:") 618 | print(f" Detailed results: {out_path}") 619 | print(f" Statistics: {analysis_path}") 620 | 621 | if __name__ == "__main__": 622 | random.seed(42) 623 | main() 624 | --------------------------------------------------------------------------------