├── asserts
│   ├── fig1_teaser.png
│   └── fig2_overview.png
├── generate_response
│   ├── run_llava_qa.sh
│   ├── run_qwen_code.sh
│   ├── run_llava_code.sh
│   ├── run_internvl3_qa.sh
│   ├── run_qwen_qa.sh
│   ├── run_internvl2_5_qa.sh
│   ├── run_internvl3_code.sh
│   ├── run_internvl2_5_code.sh
│   ├── internvl2_5_qa.py
│   ├── internvl3_qa.py
│   ├── openai_qa.py
│   ├── llava_qa.py
│   ├── o1_qa.py
│   ├── qwen_qa.py
│   ├── gemini_qa.py
│   ├── o4-mini_qa.py
│   ├── geminipro_qa.py
│   ├── claude_qa.py
│   ├── internvl2_5_code.py
│   ├── o1_code.py
│   └── geminipro_code.py
├── requirements.txt
├── README.md
└── calculate_similarity
    ├── render_img.py
    ├── clip_score.py
    └── gemini_evaluate.py
/asserts/fig1_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mikivishy/FullFront/HEAD/asserts/fig1_teaser.png
--------------------------------------------------------------------------------
/asserts/fig2_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mikivishy/FullFront/HEAD/asserts/fig2_overview.png
--------------------------------------------------------------------------------
/generate_response/run_llava_qa.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=llava_qa \
8 | -c 128 \
9 | --gres=gpu:8 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python llava_qa.py
--------------------------------------------------------------------------------
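Note: the run_*.sh scripts in this directory all assume a Slurm cluster; the `srun` flags (`--partition=MoE`, `--quotatype`) and the NCCL settings (`NCCL_SOCKET_IFNAME=bond0`, `NCCL_IB_HCA=mlx5_0`) are site-specific and will need adjusting. A rough single-node equivalent without Slurm (assuming 8 local GPUs, matching the scripts' default `tensor_parallel_size=8`) is simply:

```bash
# Hypothetical non-Slurm invocation; adjust GPU indices to your machine
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python llava_qa.py
```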
/generate_response/run_qwen_code.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=qwen_code \
8 | -c 42 \
9 | --gres=gpu:4 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python qwen_code.py
--------------------------------------------------------------------------------
/generate_response/run_llava_code.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=llava_code \
8 | -c 42 \
9 | --gres=gpu:4 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=auto \
14 | python llava_code.py
--------------------------------------------------------------------------------
/generate_response/run_internvl3_qa.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 |     --job-name=internvl3_qa \
8 | -c 128 \
9 | --gres=gpu:8 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python internvl3_qa.py
--------------------------------------------------------------------------------
/generate_response/run_qwen_qa.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=qwen_qa \
8 | -c 128 \
9 | --gres=gpu:8 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 |     python qwen_qa.py
--------------------------------------------------------------------------------
/generate_response/run_internvl2_5_qa.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 |     --job-name=internvl2_5_qa \
8 | -c 128 \
9 | --gres=gpu:8 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python internvl2_5_qa.py
--------------------------------------------------------------------------------
/generate_response/run_internvl3_code.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=internvl3_code \
8 | -c 42 \
9 | --gres=gpu:4 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python internvl3_code.py
--------------------------------------------------------------------------------
/generate_response/run_internvl2_5_code.sh:
--------------------------------------------------------------------------------
1 | export NCCL_SOCKET_IFNAME=bond0
2 | export NCCL_IB_HCA=mlx5_0
3 |
4 | srun \
5 | --partition=MoE \
6 | --mpi=pmi2 \
7 | --job-name=internvl2_5_code \
8 | -c 42 \
9 | --gres=gpu:4 \
10 | --nodes=1 \
11 | --ntasks-per-node=1 \
12 | --kill-on-bad-exit=1 \
13 | --quotatype=spot \
14 | python internvl2_5_code.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # FullFront Project Dependencies
2 | # Model Response Generation
3 | anthropic>=0.8.0
4 | openai>=1.0.0
5 | google-generativeai>=0.3.0
6 | transformers>=4.35.0
7 | vllm  # local-model inference (llava/qwen/internvl scripts)
8 | pillow>=9.0.0
9 | pandas>=2.0.0
10 | numpy>=1.24.0
11 | tqdm>=4.65.0
12 | requests>=2.30.0
13 | 
14 | # HTML Rendering
15 | playwright>=1.40.0
16 | 
17 | # Similarity Calculation
18 | torch>=2.0.0
19 | torchvision>=0.15.0
20 | clip @ git+https://github.com/openai/CLIP.git  # OpenAI CLIP (not published on PyPI)
21 | opencv-python>=4.8.0
22 | sentence-transformers>=2.2.2
23 | beautifulsoup4>=4.12.0  # also provides the bs4 import
24 | html-similarity>=0.3.3
25 | lxml>=4.9.0
26 | # difflib ships with the Python standard library; no install needed
27 | 
28 | # Data Processing
29 | pyarrow>=14.0.0
30 | fastparquet>=2023.10.0
--------------------------------------------------------------------------------
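Beyond `pip install -r requirements.txt`, Playwright needs a browser binary before `render_img.py` can run; the script launches Chromium, so install it once per environment:

```bash
pip install -r requirements.txt
playwright install chromium  # downloads the Chromium build Playwright drives
```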
/README.md:
--------------------------------------------------------------------------------
1 | # FullFront
2 | 
3 | 
4 | ![FullFront teaser](asserts/fig1_teaser.png)
5 | 
6 | 
7 | FullFront is a comprehensive benchmark for evaluating Multimodal Large Language Models (MLLMs) across the entire front-end engineering workflow. This project provides code generation, page understanding, and evaluation tools to measure MLLMs' performance at various stages of front-end development.
8 | 
9 | 
10 | ![FullFront overview](asserts/fig2_overview.png)
11 | 
12 | 
13 | ## Project Overview
14 |
15 | The FullFront benchmark covers three core tasks in front-end engineering:
16 | 1. **Webpage Design** - Assessing the model's ability to organize and structure visual elements
17 | 2. **Webpage Perception QA** - Evaluating the model's understanding of visual organization, element characteristics, and spatial relationships
18 | 3. **Webpage Code Generation** - Focusing on the ability to accurately translate visual designs into functional code
19 |
20 | ## Key Features
21 |
22 | - Supports evaluation of multiple mainstream multimodal models (Claude, OpenAI, Gemini, etc.)
23 | - Provides a complete code generation and evaluation workflow
24 | - Includes image similarity and code quality assessment metrics
25 | - Automatically renders HTML into images for evaluation
26 |
27 | ## Installation
28 |
29 | 1. Clone this repository:
30 | ```bash
31 | git clone https://github.com/Mikivishy/FullFront.git
32 | cd FullFront
33 | ```
34 |
35 | 2. Install dependencies:
36 | ```bash
37 | pip install -r requirements.txt
38 | ```
39 |
40 | ## Usage Guide
41 |
42 | ### Generating Model Responses
43 |
44 | The `generate_response` directory contains scripts for generating responses from different models:
45 |
46 | 1. **Set API Keys**: Set the API key in the script for the model you are using (for example, the `api_key` variable in `openai_qa.py`).
47 |
48 | 2. **Run Generation Scripts**:
49 | ```bash
50 | cd generate_response
51 | python claude_code.py # Generate code using Claude model
52 | python openai_code.py # Generate code using OpenAI model
53 | python gemini_code.py # Generate code using Gemini model
54 | ```
55 |
56 | 3. **Use Shell Scripts for Batch Processing**:
57 | ```bash
58 | bash run_llava_code.sh # Run LLaVA model code generation tasks
59 | bash run_qwen_qa.sh # Run Qwen model QA tasks
60 | ```
61 |
62 | The generated results will be saved in the `generate_response/results/{model_name}` directory.
63 |
64 | ### Rendering HTML to Images
65 |
66 | Use `calculate_similarity/render_img.py` to render generated HTML into images:
67 |
68 | ```bash
69 | python calculate_similarity/render_img.py
70 | ```
71 |
72 | You can modify the input and output directories in this script:
73 | ```python
74 | html_folder = "./path/to/your/html/files"
75 | screenshot_folder = "./path/to/save/screenshots"
76 | ```
77 |
78 | ### Calculating Similarity Scores
79 |
80 | 1. **CLIP Similarity**: Evaluate semantic similarity between generated images and target images
81 | ```bash
82 | python calculate_similarity/clip_score.py
83 | ```
84 |
85 | 2. **Code Similarity**: Evaluate structure and content similarity between generated code and standard code
86 | ```bash
87 | python calculate_similarity/code_score.py
88 | ```
89 |
90 | 3. **Gemini Evaluation**: Use Gemini model to evaluate generated content
91 | ```bash
92 | python calculate_similarity/gemini_evaluate.py
93 | ```
94 |
95 | ### Result Analysis
96 |
97 | Evaluation results will be saved in the `calculate_similarity/results/` directory, containing the following metrics:
98 | - CLIP similarity score
99 | - Code structure similarity
100 | - Code content similarity
--------------------------------------------------------------------------------
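`clip_score.py` itself is not reproduced in this section. As an illustration of the CLIP-similarity step described in the README above, here is a minimal sketch using the Hugging Face `transformers` CLIP implementation; the checkpoint name, paths, and helper function are assumptions for illustration, not the repo's actual code:

```python
# Minimal CLIP image-image similarity sketch; NOT the repo's clip_score.py.
# Checkpoint (openai/clip-vit-base-patch32) and paths are illustrative.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_similarity(path_a: str, path_b: str) -> float:
    """Cosine similarity between CLIP embeddings of two screenshots."""
    images = [Image.open(path_a).convert("RGB"), Image.open(path_b).convert("RGB")]
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    emb = emb / emb.norm(dim=-1, keepdim=True)  # L2-normalize, so dot = cosine
    return float(emb[0] @ emb[1])

print(clip_similarity("output/images/gen_001.png", "refs/ref_001.png"))
```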
/calculate_similarity/render_img.py:
--------------------------------------------------------------------------------
1 | """
2 | HTML to Image Renderer
3 |
4 | Usage:
5 | 1. Run the script directly: python render_img.py
6 | 2. Or import the capture_screenshot function:
7 | from render_img import capture_screenshot
8 | capture_screenshot(html_file_path, screenshot_folder)
9 |
10 | This script renders HTML files to PNG images using Playwright.
11 | It waits for all images to load, handles lazy-loaded content, and
12 | captures full-page screenshots.
13 | """
14 |
15 | from playwright.sync_api import sync_playwright
16 | import os
17 | import time
18 |
19 | def capture_screenshot(html_file_path, screenshot_folder):
20 | """
21 | Capture a screenshot of an HTML file after all images are loaded.
22 |
23 | Args:
24 | html_file_path: Path to the HTML file
25 | screenshot_folder: Path to save screenshots
26 |
27 | Returns:
28 | bool: True if successful, False otherwise
29 | """
30 | try:
31 | # Ensure HTML file exists
32 | if not os.path.exists(html_file_path):
33 | raise FileNotFoundError(f"HTML file not found: {html_file_path}")
34 |
35 | # Get absolute path
36 | absolute_path = os.path.abspath(html_file_path)
37 | file_url = f"file://{absolute_path}"
38 |
39 | # Generate screenshot filename
40 | html_file_name = os.path.splitext(os.path.basename(html_file_path))[0]
41 | screenshot_path = os.path.join(screenshot_folder, f"{html_file_name}.png")
42 |
43 | with sync_playwright() as p:
44 | # Launch browser
45 | browser = p.chromium.launch(headless=True)
46 | page = browser.new_page()
47 |
48 | # Navigate to HTML file
49 | page.goto(file_url)
50 |
51 | # Wait for page to load
52 | page.wait_for_load_state("load")
53 |
54 | # Wait for all images to fully load
55 | page.evaluate("""
56 | () => {
57 | return new Promise((resolve) => {
58 | const images = document.querySelectorAll('img');
59 | let loadedImages = 0;
60 |
61 | // If no images, resolve immediately
62 | if (images.length === 0) {
63 | return resolve();
64 | }
65 |
66 | // Add load event handlers to each image
67 | images.forEach(img => {
68 | // Already loaded images
69 | if (img.complete) {
70 | loadedImages++;
71 | if (loadedImages === images.length) {
72 | resolve();
73 | }
74 | } else {
75 | // Listen for load events
76 | img.addEventListener('load', () => {
77 | loadedImages++;
78 | if (loadedImages === images.length) {
79 | resolve();
80 | }
81 | });
82 |
83 | // Listen for error events
84 | img.addEventListener('error', () => {
85 | loadedImages++;
86 | if (loadedImages === images.length) {
87 | resolve();
88 | }
89 | });
90 | }
91 | });
92 | });
93 | }
94 | """)
95 |
96 | # Simulate scrolling to load lazy-loaded elements
97 | page.evaluate("""
98 | () => {
99 | return new Promise((resolve) => {
100 | let totalHeight = 0;
101 | const distance = 100;
102 | const timer = setInterval(() => {
103 | const scrollHeight = document.body.scrollHeight;
104 | window.scrollBy(0, distance);
105 | totalHeight += distance;
106 |
107 | // Stop if reached bottom or scrolled enough
108 | if(totalHeight >= scrollHeight || totalHeight > 10000){
109 | clearInterval(timer);
110 | // Scroll back to top
111 | window.scrollTo(0, 0);
112 | setTimeout(resolve, 100); // Give time for page to stabilize
113 | }
114 | }, 100);
115 | });
116 | }
117 | """)
118 |
119 | # Capture screenshot
120 | page.screenshot(path=screenshot_path, full_page=True)
121 |
122 | # Close browser
123 | browser.close()
124 |
125 | print(f"Screenshot saved to: {screenshot_path}")
126 | return True
127 | except Exception as e:
128 | print(f"Error processing file {html_file_path}: {str(e)}")
129 | return False
130 |
131 | if __name__ == "__main__":
132 | start_time = time.time()
133 | # HTML files directory
134 | html_folder = "./input/html" # Replace with your HTML files directory
135 | # Screenshots output directory
136 | screenshot_folder = "./output/images" # Replace with your screenshots directory
137 |
138 | # Ensure screenshot folder exists
139 | if not os.path.exists(screenshot_folder):
140 | os.makedirs(screenshot_folder)
141 |
142 | # Get all HTML files
143 | html_files = [f for f in os.listdir(html_folder) if f.endswith(".html")]
144 |
145 | # Sort alphabetically
146 | html_files.sort()
147 |
148 | # Get existing screenshot names (without extension)
149 | existing_screenshots = set()
150 | if os.path.exists(screenshot_folder):
151 | for file in os.listdir(screenshot_folder):
152 | if file.endswith(".png"):
153 | existing_screenshots.add(os.path.splitext(file)[0])
154 |
155 | # Track failed files
156 | failed_indices = []
157 | # Count skipped files
158 | skipped_count = 0
159 |
160 | # Process each HTML file
161 | for index, html_file in enumerate(html_files):
162 | # Get HTML filename without extension
163 | html_file_name = os.path.splitext(html_file)[0]
164 |
165 | # Skip if screenshot already exists
166 | if html_file_name in existing_screenshots:
167 | print(f"Screenshot exists, skipping [{index+1}/{len(html_files)}]: {html_file}")
168 | skipped_count += 1
169 | continue
170 |
171 | print(f"Processing [{index+1}/{len(html_files)}]: {html_file}")
172 | html_file_path = os.path.join(html_folder, html_file)
173 | success = capture_screenshot(html_file_path, screenshot_folder)
174 | if not success:
175 | failed_indices.append(index)
176 |
177 | # Print processing statistics
178 | print("\nProcessing Statistics:")
179 | print(f"Total files: {len(html_files)}")
180 | print(f"Skipped: {skipped_count}")
181 | print(f"Newly processed: {len(html_files) - skipped_count - len(failed_indices)}")
182 |
183 | # Print failed files
184 | if failed_indices:
185 | print("\nFailed files indices:")
186 | for idx in failed_indices:
187 | print(f"Index {idx}: {html_files[idx]}")
188 | print(f"Failed count: {len(failed_indices)}")
189 | else:
190 | print("\nAll files processed successfully!")
191 | end_time = time.time()
192 | print(f"Total processing time: {end_time - start_time} seconds")
--------------------------------------------------------------------------------
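As the docstring notes, `capture_screenshot` can also be driven from another script. A minimal sketch (paths are placeholders; note the function does not create the output folder itself, so create it first):

```python
# Render one HTML file to ./screenshots/<name>.png (paths are illustrative)
import os
from render_img import capture_screenshot

os.makedirs("screenshots", exist_ok=True)
if capture_screenshot("samples/page_001.html", "screenshots"):
    print("Rendered samples/page_001.html -> screenshots/page_001.png")
```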
/generate_response/internvl2_5_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | InternVL2_5 QA Generation Tool
3 |
4 | Usage:
5 | 1. Place your data files (.parquet) in a 'data' folder in the same directory
6 | 2. Run this script: python internvl2_5_qa.py
7 | 3. Results will be saved in the 'results' folder
8 |
9 | You can configure model settings, categories, and other parameters at the bottom of this script.
10 | """
11 |
12 | import os
13 | import json
14 | import base64
15 | import glob
16 | import pandas as pd
17 | from PIL import Image
18 | from io import BytesIO
19 | from vllm import LLM, SamplingParams
20 |
21 | def load_parquet_data(file_path):
22 | """Load data from a single Parquet file."""
23 | try:
24 | df = pd.read_parquet(file_path)
25 | return df.to_dict('records')
26 | except Exception as e:
27 | print(f"Failed to load file {file_path}: {e}")
28 | return []
29 |
30 | def load_existing_results(output_path):
31 | """Load existing results file to skip already processed data."""
32 | if os.path.exists(output_path):
33 | try:
34 | with open(output_path, 'r', encoding='utf-8') as f:
35 | return json.load(f)
36 | except Exception as e:
37 | print(f"Failed to load existing results from {output_path}: {e}")
38 | return []
39 |
40 | def save_results_to_json(results, output_path="output.json"):
41 | """Save results to a JSON file."""
42 | try:
43 | with open(output_path, 'w', encoding='utf-8') as f:
44 | json.dump(results, f, ensure_ascii=False, indent=4)
45 | print(f"Results saved to: {output_path}")
46 | except Exception as e:
47 | print(f"Failed to save results to {output_path}: {e}")
48 |
49 | def create_prompt(prompt, question, choices):
50 | """Create a prompt for InternVL2_5"""
51 | return f"USER: \n{prompt}\n'Question:'{question}\n'Choices:'{choices}\nASSISTANT:"
52 |
53 | def process_data(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None):
54 | """Process data list, generate model responses, and save results in real-time."""
55 | results = []
56 | sampling_params = SamplingParams(
57 | temperature=temperature,
58 | max_tokens=max_tokens,
59 | )
60 |
61 | # Create a set of processed question IDs to skip
62 | processed_ids = set()
63 | if existing_results:
64 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
65 | results = existing_results.copy()
66 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these.")
67 |
68 | total_items = len(data_list)
69 | newly_processed_count = 0
70 |
71 | for i, item in enumerate(data_list):
72 | question_id = item.get('Question_id')
73 |
74 | # Skip already processed questions
75 | if question_id in processed_ids:
76 | continue
77 |
78 | newly_processed_count += 1
79 |
80 | try:
81 | # Decode base64 image
82 | base64_image = item['Image']
83 | image_bytes = base64.b64decode(base64_image)
84 | image_io = BytesIO(image_bytes)
85 | # Convert to PIL image object
86 | image = Image.open(image_io)
87 |
88 | prompt = item['Prompt']
89 | question = item['Question']
90 | choices = item['Choices']
91 |
92 | # Create prompt
93 | internvl_prompt = create_prompt(prompt, question, choices)
94 |
95 | # Single image inference
96 | inputs = {
97 | "prompt": internvl_prompt,
98 | "multi_modal_data": {
99 | "image": image
100 | },
101 | }
102 |
103 | # Generate response
104 | outputs = llm.generate([inputs], sampling_params=sampling_params)
105 | response = outputs[0].outputs[0].text.strip()
106 |
107 | result = {
108 | "Question_id": question_id,
109 | "Response": response,
110 | "Answer": item.get('Answer'),
111 | "Category": item.get('Category'),
112 | "Png_id": item.get('Png_id')
113 | }
114 |
115 | results.append(result)
116 |
117 | # Save results after every 10 new questions or at the last new question
118 | if newly_processed_count > 0 and (newly_processed_count % 10 == 0 or i == total_items - 1):
119 | if output_path:
120 | save_results_to_json(results, output_path)
121 | print(f"Processed {i + 1}/{total_items} questions (new: {newly_processed_count}), saving results to {output_path}.")
122 |
123 | except Exception as e:
124 | print(f"Error processing Question_id {question_id}: {e}")
125 | result = {
126 | "Question_id": question_id,
127 | "Response": f"Error: {e}",
128 | "Answer": item.get('Answer'),
129 | "Category": item.get('Category'),
130 | "Png_id": item.get('Png_id')
131 | }
132 | results.append(result)
133 |
134 | # Save results when an error occurs
135 | if output_path:
136 | save_results_to_json(results, output_path)
137 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
138 |
139 | # Ensure final results are saved if any new items were processed
140 | if newly_processed_count > 0 and output_path:
141 | save_results_to_json(results, output_path)
142 | print(f"File processing complete, final results saved to: {output_path}")
143 |
144 | return results
145 |
146 | def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
147 | """Main function to process Parquet datasets and generate responses, creating a separate output file for each dataset."""
148 | # Ensure output directory exists
149 | os.makedirs(output_base_dir, exist_ok=True)
150 |
151 | # Find all Parquet files
152 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
153 | files_to_process = []
154 |
155 | # Filter files by categories if specified
156 | if categories:
157 | # Ensure category names don't include .parquet extension for matching
158 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
159 | for file_path in all_parquet_files:
160 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
161 | if basename_no_ext in category_basenames:
162 | files_to_process.append(file_path)
163 | print(f"Will process specified category files: {files_to_process}")
164 | # Check for any categories not found
165 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
166 | missing_categories = category_basenames - found_basenames
167 | if missing_categories:
168 | print(f"Warning: The following specified category files were not found: {missing_categories}")
169 |
170 | else:
171 | files_to_process = all_parquet_files
172 | print(f"Will process all .parquet files in {data_folder} folder.")
173 |
174 | if not files_to_process:
175 | print("No Parquet files found to process.")
176 | return
177 |
178 | # --- Model Loading ---
179 | print("Loading model...")
180 | try:
181 | llm = LLM(
182 | model=model_path,
183 | max_model_len=32768,
184 | tensor_parallel_size=tensor_parallel_size,
185 | enforce_eager=True
186 | )
187 | print("Model loaded successfully.")
188 | except Exception as e:
189 | print(f"Error loading model: {e}")
190 | return
191 |
192 | # --- Process each file ---
193 | for file_path in files_to_process:
194 | print("-" * 50)
195 | print(f"Starting to process file: {file_path}")
196 |
197 | # 1. Dynamically generate output filename from input filename
198 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
199 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
200 | # Combine "internvl2_5_" prefix with filename to create JSON filename
201 | output_filename = f"internvl2_5_{dataset_name}.json"
202 | current_output_path = os.path.join(output_base_dir, output_filename)
203 | print(f"Results will be saved to: {current_output_path}")
204 |
205 | # 2. Load existing results for current file
206 | existing_results_for_current_file = load_existing_results(current_output_path)
207 |
208 | # 3. Load Parquet data
209 | data_list = load_parquet_data(file_path)
210 |
211 | # 4. Process data
212 | if data_list:
213 | process_data(
214 | data_list,
215 | llm,
216 | temperature,
217 | max_tokens,
218 | existing_results=existing_results_for_current_file,
219 | output_path=current_output_path
220 | )
221 | print(f"File {file_path} processing completed.")
222 | else:
223 | print(f"No data in file {file_path} or loading failed.")
224 |
225 | print("=" * 50)
226 | print("Processing of all specified files completed.")
227 |
228 | if __name__ == "__main__":
229 | # Relative paths
230 | data_folder = "data" # Data folder path
231 | model_path = "models/InternVL2_5-78B" # Model weights path
232 | output_dir = "results" # Output directory for results
233 |
234 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"] # Specify dataset categories to process
235 | # categories = None # Set to None to process all parquet files
236 |
237 | max_tokens = 1024 # Maximum tokens to generate
238 | temperature = 0 # Temperature parameter
239 | tensor_parallel_size = 8 # Use 8 GPUs in parallel
240 |
241 | # Call main function with output directory
242 | main(data_folder, model_path, output_dir, categories, max_tokens, temperature, tensor_parallel_size)
--------------------------------------------------------------------------------
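The QA scripts all read the same parquet columns: `Question_id`, `Image` (a base64-encoded screenshot), `Prompt`, `Question`, `Choices`, plus the pass-through fields `Answer`, `Category`, and `Png_id`. A sketch of building a compatible file follows; all field values are illustrative, not drawn from the benchmark:

```python
# Build a parquet file matching the columns the QA scripts read.
# Values below are examples only.
import base64
import pandas as pd

with open("example.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

rows = [{
    "Question_id": "q_0001",
    "Image": img_b64,  # base64 string; the scripts decode it back to a PIL image
    "Prompt": "Answer the multiple-choice question about this webpage.",
    "Question": "What color is the navigation bar?",
    "Choices": "A. Blue  B. White  C. Black  D. Green",
    "Answer": "A",
    "Category": "example_category",
    "Png_id": "page_0001",
}]
pd.DataFrame(rows).to_parquet("data/Synthetic_QA.parquet", index=False)
```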
/generate_response/internvl3_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | InternVL3 QA Inference Script
3 |
4 | Usage:
5 | python internvl3_qa.py [--data_folder DATA_PATH] [--model_path MODEL_PATH] [--output_dir OUTPUT_DIR]
6 | [--categories CATEGORY1 CATEGORY2 ...] [--max_tokens MAX_TOKENS]
7 | [--temperature TEMP] [--tensor_parallel_size TP_SIZE]
8 |
9 | Description:
10 | This script processes parquet files containing image-text QA data and generates responses using InternVL3.
11 | It supports processing specific categories of data and continues from previous runs.
12 | """
13 |
14 | import os
15 | import json
16 | import base64
17 | import glob
18 | import pandas as pd
19 | from PIL import Image
20 | from io import BytesIO
21 | from vllm import LLM, SamplingParams
22 |
23 | def load_parquet_data(file_path):
24 | """Load data from a single Parquet file."""
25 | try:
26 | df = pd.read_parquet(file_path)
27 | return df.to_dict('records')
28 | except Exception as e:
29 | print(f"Failed to load file {file_path}: {e}")
30 | return []
31 |
32 | def load_existing_results(output_path):
33 | """Load existing results file to skip already processed data."""
34 | if os.path.exists(output_path):
35 | try:
36 | with open(output_path, 'r', encoding='utf-8') as f:
37 | return json.load(f)
38 | except Exception as e:
39 | print(f"Failed to load existing results file {output_path}: {e}")
40 | return []
41 |
42 | def save_results_to_json(results, output_path="output.json"):
43 | """Save results to a JSON file."""
44 | try:
45 | with open(output_path, 'w', encoding='utf-8') as f:
46 | json.dump(results, f, ensure_ascii=False, indent=4)
47 | print(f"Results saved to: {output_path}")
48 | except Exception as e:
49 | print(f"Failed to save results to {output_path}: {e}")
50 |
51 | def create_prompt(prompt, question, choices):
52 | """Create prompt for InternVL3"""
53 | return f"USER: \n{prompt}\n'Question:'{question}\n'Choices:'{choices}\nASSISTANT:"
54 |
55 | def process_data(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None):
56 | """Process data list, generate model responses, and save results in real-time."""
57 | results = []
58 | sampling_params = SamplingParams(
59 | temperature=temperature,
60 | max_tokens=max_tokens,
61 | )
62 |
63 | # Create a set of processed question IDs to skip
64 | processed_ids = set()
65 | if existing_results:
66 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
67 | results = existing_results.copy()
68 | print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip these.")
69 |
70 | total_items = len(data_list)
71 | newly_processed_count = 0
72 |
73 | for i, item in enumerate(data_list):
74 | question_id = item.get('Question_id')
75 |
76 | # Skip already processed questions
77 | if question_id in processed_ids:
78 | continue
79 |
80 | newly_processed_count += 1
81 |
82 | try:
83 | # Decode base64 image
84 | base64_image = item['Image']
85 | image_bytes = base64.b64decode(base64_image)
86 | image_io = BytesIO(image_bytes)
87 | # Convert to PIL image object
88 | image = Image.open(image_io)
89 |
90 | prompt = item['Prompt']
91 | question = item['Question']
92 | choices = item['Choices']
93 |
94 | # Create prompt
95 | internvl_prompt = create_prompt(prompt, question, choices)
96 |
97 | # Single image inference
98 | inputs = {
99 | "prompt": internvl_prompt,
100 | "multi_modal_data": {
101 | "image": image
102 | },
103 | }
104 |
105 | # Generate response
106 | outputs = llm.generate([inputs], sampling_params=sampling_params)
107 | response = outputs[0].outputs[0].text.strip()
108 |
109 | result = {
110 | "Question_id": question_id,
111 | "Response": response,
112 | "Answer": item.get('Answer'),
113 | "Category": item.get('Category'),
114 | "Png_id": item.get('Png_id')
115 | }
116 |
117 | results.append(result)
118 |
119 | # Save results every 10 new questions or at the last question
120 | if newly_processed_count > 0 and (newly_processed_count % 10 == 0 or i == total_items - 1):
121 | if output_path:
122 | save_results_to_json(results, output_path)
123 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved to {output_path}.")
124 |
125 | except Exception as e:
126 | print(f"Error processing Question_id {question_id}: {e}")
127 | result = {
128 | "Question_id": question_id,
129 | "Response": f"Processing error: {e}",
130 | "Answer": item.get('Answer'),
131 | "Category": item.get('Category'),
132 | "Png_id": item.get('Png_id')
133 | }
134 | results.append(result)
135 |
136 | # Save results when an error occurs
137 | if output_path:
138 | save_results_to_json(results, output_path)
139 | print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")
140 |
141 | # Ensure final results are saved if any new items were processed
142 | if newly_processed_count > 0 and output_path:
143 | save_results_to_json(results, output_path)
144 | print(f"File processing complete, final results saved to: {output_path}")
145 |
146 | return results
147 |
148 | def main(data_folder="./data", model_path="./model", output_base_dir="./results",
149 | categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
150 | """Main function to process Parquet datasets and generate responses."""
151 | # Ensure output directory exists
152 | os.makedirs(output_base_dir, exist_ok=True)
153 |
154 | # Find all Parquet files
155 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
156 | files_to_process = []
157 |
158 | # Filter files based on categories
159 | if categories:
160 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
161 | for file_path in all_parquet_files:
162 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
163 | if basename_no_ext in category_basenames:
164 | files_to_process.append(file_path)
165 | print(f"Will process specific category files: {files_to_process}")
166 | # Check for missing categories
167 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
168 | missing_categories = category_basenames - found_basenames
169 | if missing_categories:
170 | print(f"Warning: The following specified categories were not found: {missing_categories}")
171 | else:
172 | files_to_process = all_parquet_files
173 | print(f"Will process all .parquet files in {data_folder}.")
174 |
175 | if not files_to_process:
176 | print("No Parquet files found to process.")
177 | return
178 |
179 | # Load model
180 | print("Loading model...")
181 | try:
182 | llm = LLM(
183 | model=model_path,
184 | tensor_parallel_size=tensor_parallel_size,
185 | max_model_len=32768,
186 | enforce_eager=True
187 | )
188 | print("Model loaded successfully.")
189 | except Exception as e:
190 | print(f"Error loading model: {e}")
191 | return
192 |
193 | # Process each file
194 | for file_path in files_to_process:
195 | print("-" * 50)
196 | print(f"Processing file: {file_path}")
197 |
198 | # Generate output filename from input filename
199 | base_name = os.path.basename(file_path)
200 | dataset_name = os.path.splitext(base_name)[0]
201 | output_filename = f"internvl3_{dataset_name}.json"
202 | current_output_path = os.path.join(output_base_dir, output_filename)
203 | print(f"Results will be saved to: {current_output_path}")
204 |
205 | # Load existing results for current file
206 | existing_results_for_current_file = load_existing_results(current_output_path)
207 |
208 | # Load Parquet data
209 | data_list = load_parquet_data(file_path)
210 |
211 | # Process data
212 | if data_list:
213 | process_data(
214 | data_list,
215 | llm,
216 | temperature,
217 | max_tokens,
218 | existing_results=existing_results_for_current_file,
219 | output_path=current_output_path
220 | )
221 | print(f"File {file_path} processing complete.")
222 | else:
223 | print(f"No data in file {file_path} or loading failed.")
224 |
225 | print("=" * 50)
226 | print("All specified files have been processed.")
227 |
228 | if __name__ == "__main__":
229 | import argparse
230 |
231 | parser = argparse.ArgumentParser(description="InternVL3 QA Inference")
232 | parser.add_argument("--data_folder", default="./data", help="Path to data folder containing parquet files")
233 | parser.add_argument("--model_path", default="./model", help="Path to model weights")
234 | parser.add_argument("--output_dir", default="./results", help="Directory to save results")
235 | parser.add_argument("--categories", nargs="+", default=["Real-world_QA", "Synthetic_QA", "Multi-window_QA"],
236 | help="Categories to process (parquet file names without extension)")
237 | parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens to generate")
238 | parser.add_argument("--temperature", type=float, default=0, help="Temperature parameter")
239 | parser.add_argument("--tensor_parallel_size", type=int, default=8, help="Number of GPUs for tensor parallelism")
240 |
241 | args = parser.parse_args()
242 |
243 | main(
244 | args.data_folder,
245 | args.model_path,
246 | args.output_dir,
247 | args.categories,
248 | args.max_tokens,
249 | args.temperature,
250 | args.tensor_parallel_size
251 | )
--------------------------------------------------------------------------------
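A sample invocation matching the argparse flags above (the model path is a placeholder):

```bash
python internvl3_qa.py \
    --data_folder ./data \
    --model_path ./models/InternVL3-78B \
    --output_dir ./results \
    --categories Real-world_QA Synthetic_QA \
    --max_tokens 1024 \
    --temperature 0 \
    --tensor_parallel_size 8
```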
/generate_response/openai_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | OpenAI Vision API Question Answering Script
3 |
4 | Usage:
5 | 1. Place your dataset parquet files in the './datasets' folder
6 | 2. Configure your OpenAI API key in the main function
7 | 3. Set the desired model name (default: gpt-4o-2024-11-20)
8 | 4. Run the script: python openai_qa.py
9 | 5. Results will be saved to './results/openai' directory
10 |
11 | The script processes parquet files containing images (base64 encoded) and questions,
12 | sends them to OpenAI vision API, and saves the responses as JSON files.
13 | """
14 |
15 | import os
16 | import json
17 | import glob
18 | import pandas as pd
19 | from openai import OpenAI
20 | import time
21 |
22 | def load_parquet_data(file_path):
23 | try:
24 | df = pd.read_parquet(file_path)
25 | return df.to_dict('records')
26 | except Exception as e:
27 | print(f"Failed to load file {file_path}: {e}")
28 | return []
29 |
30 | def load_existing_results(output_path):
31 | if os.path.exists(output_path):
32 | try:
33 | with open(output_path, 'r', encoding='utf-8') as f:
34 | return json.load(f)
35 | except Exception as e:
36 | print(f"Failed to load existing results from {output_path}: {e}")
37 | return []
38 |
39 | def save_results_to_json(results, output_path="output.json"):
40 | try:
41 | with open(output_path, 'w', encoding='utf-8') as f:
42 | json.dump(results, f, ensure_ascii=False, indent=4)
43 | print(f"Results saved to: {output_path}")
44 | except Exception as e:
45 | print(f"Failed to save results to {output_path}: {e}")
46 |
47 |
48 | def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
49 | results = []
50 |
51 | processed_ids = set()
52 | if existing_results:
53 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
54 | results = existing_results.copy()
55 | print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip them.")
56 |
57 | total_items = len(data_list)
58 | newly_processed_count = 0
59 |
60 | for i, item in enumerate(data_list):
61 | question_id = item.get('Question_id')
62 |
63 | if question_id in processed_ids:
64 | print(f"Skipping already processed question ID: {question_id}")
65 | continue
66 |
67 | newly_processed_count += 1
68 |
69 | try:
70 | base64_image = item['Image']
71 |
72 | prompt = item['Prompt']
73 | question = item['Question']
74 | choices = item['Choices']
75 |
76 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
77 |
78 | try:
79 | response = client.chat.completions.create(
80 | model=model_name,
81 | messages=[
82 | {
83 | "role": "user",
84 | "content": [
85 | {
86 | "type": "text",
87 | "text": prompt_text
88 | },
89 | {
90 | "type": "image_url",
91 | "image_url": {
92 | "url": f"data:image/jpeg;base64,{base64_image}"
93 | }
94 | }
95 | ]
96 | }
97 | ],
98 | temperature=temperature,
99 | max_tokens=max_tokens
100 | )
101 |
102 | model_response = response.choices[0].message.content
103 |
104 | except Exception as api_error:
105 | print(f"API call error, attempting retry: {api_error}")
106 | time.sleep(2)
107 | try:
108 | response = client.chat.completions.create(
109 | model=model_name,
110 | messages=[
111 | {
112 | "role": "user",
113 | "content": [
114 | {
115 | "type": "text",
116 | "text": prompt_text
117 | },
118 | {
119 | "type": "image_url",
120 | "image_url": {
121 | "url": f"data:image/jpeg;base64,{base64_image}"
122 | }
123 | }
124 | ]
125 | }
126 | ],
127 | temperature=temperature,
128 | max_tokens=max_tokens
129 | )
130 | model_response = response.choices[0].message.content
131 | except Exception as retry_error:
132 | raise Exception(f"Retry failed: {retry_error}")
133 |
134 | result = {
135 | "Question_id": question_id,
136 | "Response": model_response,
137 | "Answer": item.get('Answer'),
138 | "Category": item.get('Category'),
139 | "Png_id": item.get('Png_id')
140 | }
141 |
142 | results.append(result)
143 |
144 | # Save results every 5 new questions or at the last item
145 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
146 | if output_path:
147 | save_results_to_json(results, output_path)
148 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), results saved to {output_path}.")
149 |
150 | except Exception as e:
151 | print(f"Error processing Question_id {question_id}: {e}")
152 | result = {
153 | "Question_id": question_id,
154 | "Response": f"Processing error: {e}",
155 | "Answer": item.get('Answer'),
156 | "Category": item.get('Category'),
157 | "Png_id": item.get('Png_id')
158 | }
159 | results.append(result)
160 |
161 | # Save results when error occurs
162 | if output_path:
163 | save_results_to_json(results, output_path)
164 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
165 |
166 | # Ensure final results are saved at the end if any new items were processed
167 | if newly_processed_count > 0 and output_path:
168 | save_results_to_json(results, output_path)
169 | print(f"File processing complete, final results saved to: {output_path}")
170 |
171 | return results
172 |
173 | def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
174 | """Main function to process parquet datasets and generate output files for each dataset."""
175 | # Ensure output directory exists
176 | os.makedirs(output_base_dir, exist_ok=True)
177 |
178 | # Find all parquet files
179 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
180 | files_to_process = []
181 |
182 | # Filter files by categories if specified
183 | if categories:
184 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
185 | for file_path in all_parquet_files:
186 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
187 | if basename_no_ext in category_basenames:
188 | files_to_process.append(file_path)
189 | print(f"Will process specified category files: {files_to_process}")
190 |
191 | # Check for missing categories
192 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
193 | missing_categories = category_basenames - found_basenames
194 | if missing_categories:
195 | print(f"Warning: The following specified categories were not found: {missing_categories}")
196 | else:
197 | files_to_process = all_parquet_files
198 | print(f"Will process all .parquet files in the {data_folder} folder.")
199 |
200 | if not files_to_process:
201 | print("No parquet files found to process.")
202 | return
203 |
204 | # Initialize OpenAI API client
205 | print("Initializing OpenAI API client...")
206 | try:
207 | client = OpenAI(api_key=api_key)
208 | print("OpenAI API client initialized successfully.")
209 | except Exception as e:
210 | print(f"Error initializing OpenAI API client: {e}")
211 | return
212 |
213 | # Process each file
214 | for file_path in files_to_process:
215 | print("-" * 50)
216 | print(f"Starting to process file: {file_path}")
217 |
218 | # Generate output filename from input filename
219 | base_name = os.path.basename(file_path)
220 | dataset_name = os.path.splitext(base_name)[0]
221 | output_filename = f"openai_{dataset_name}.json"
222 | current_output_path = os.path.join(output_base_dir, output_filename)
223 | print(f"Results will be saved to: {current_output_path}")
224 |
225 | # Load existing results for the current file
226 | existing_results_for_current_file = load_existing_results(current_output_path)
227 |
228 | # Load parquet data
229 | data_list = load_parquet_data(file_path)
230 |
231 | # Process data
232 | if data_list:
233 | process_data(
234 | data_list,
235 | client,
236 | model_name,
237 | temperature,
238 | max_tokens,
239 | existing_results=existing_results_for_current_file,
240 | output_path=current_output_path
241 | )
242 | print(f"File {file_path} processing completed.")
243 | else:
244 | print(f"No data found or failed to load file {file_path}.")
245 |
246 | print("=" * 50)
247 | print("All specified files have been processed.")
248 |
249 | if __name__ == "__main__":
250 | data_folder = "./datasets"
251 | api_key = "your_api_key"
252 | model_name = "gpt-4o-2024-11-20"
253 |
254 | output_dir = "./results/openai"
255 |
256 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
257 |
258 | max_tokens = 300
259 | temperature = 0
260 |
261 | main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
--------------------------------------------------------------------------------
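The script above retries a failed API call exactly once after a fixed 2-second sleep. A more general pattern, shown here as a sketch rather than repo code, wraps the same `client.chat.completions.create` call in exponential backoff:

```python
# Sketch: exponential-backoff wrapper around the OpenAI call used above.
import time

def call_with_retry(client, model_name, messages, retries=3, base_delay=2.0, **kwargs):
    """Retry client.chat.completions.create, doubling the delay each attempt."""
    for attempt in range(retries):
        try:
            return client.chat.completions.create(
                model=model_name, messages=messages, **kwargs
            )
        except Exception:
            if attempt == retries - 1:
                raise  # give up after the final attempt
            time.sleep(base_delay * (2 ** attempt))
```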
/generate_response/llava_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | LLaVA QA Processor
3 |
4 | This script processes vision-language question answering tasks using the LLaVA model.
5 | It handles multiple parquet datasets, generating responses for image-based questions.
6 |
7 | Usage:
8 | python llava_qa.py [--data_folder DATA_PATH] [--model_path MODEL_PATH]
9 | [--output_dir OUTPUT_PATH] [--categories CATEGORY1 CATEGORY2 ...]
10 | [--max_tokens MAX_TOKENS] [--temperature TEMP] [--tensor_parallel_size GPUS]
11 |
12 | Example:
13 | python llava_qa.py --data_folder "./datasets" --model_path "./models/llava-model"
14 | --output_dir "./results" --categories "Real-world_QA" "Synthetic_QA"
15 | """
16 |
17 | import os
18 | import json
19 | import base64
20 | import glob
21 | import pandas as pd
22 | from PIL import Image
23 | from io import BytesIO
24 | from vllm import LLM, SamplingParams
25 | import torch
26 |
27 | def load_parquet_data(file_path):
28 | """Load data from a single Parquet file."""
29 | try:
30 | df = pd.read_parquet(file_path)
31 | return df.to_dict('records')
32 | except Exception as e:
33 | print(f"Failed to load file {file_path}: {e}")
34 | return []
35 |
36 | def load_existing_results(output_path):
37 | """Load existing results file to skip already processed data."""
38 | if os.path.exists(output_path):
39 | try:
40 | with open(output_path, 'r', encoding='utf-8') as f:
41 | return json.load(f)
42 | except Exception as e:
43 | print(f"Failed to load existing results file {output_path}: {e}")
44 | return []
45 |
46 | def save_results_to_json(results, output_path="output.json"):
47 | """Save results to a JSON file."""
48 | try:
49 | with open(output_path, 'w', encoding='utf-8') as f:
50 | json.dump(results, f, ensure_ascii=False, indent=4)
51 | print(f"Results saved to: {output_path}")
52 | except Exception as e:
53 | print(f"Failed to save results to {output_path}: {e}")
54 |
55 | def create_prompt(prompt, question, choices):
56 | """Create prompt for the LLaVA model"""
57 | return f"<|im_start|>user \n{prompt}\n'Question:'{question}\n'Choices:'{choices}<|im_end|> <|im_start|>assistant\n"
58 |     # return f"<|im_start|>user \nPlease describe this image<|im_end|> <|im_start|>assistant\n"  # debug-only prompt
59 |
60 | def process_data(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None):
61 | """Process data list, generate model responses, and save results in real-time."""
62 | results = []
63 | sampling_params = SamplingParams(
64 | temperature=temperature,
65 | max_tokens=max_tokens,
66 | )
67 |
68 | # Create a set of processed question IDs to skip
69 | processed_ids = set()
70 | if existing_results:
71 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
72 | results = existing_results.copy()
73 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")
74 |
75 | total_items = len(data_list)
76 | newly_processed_count = 0
77 |
78 | for i, item in enumerate(data_list):
79 | question_id = item.get('Question_id')
80 |
81 | # Skip already processed questions
82 | if question_id in processed_ids:
83 | continue
84 |
85 | newly_processed_count += 1
86 |
87 | try:
88 | # Decode base64 image
89 | base64_image = item['Image']
90 | image_bytes = base64.b64decode(base64_image)
91 | image_io = BytesIO(image_bytes)
92 | # Convert to PIL image object
93 | image = Image.open(image_io)
94 |
95 | prompt = item['Prompt']
96 | question = item['Question']
97 | choices = item['Choices']
98 |
99 | # Create prompt
100 | llava_prompt = create_prompt(prompt, question, choices)
101 |
102 | # print("**************************************llava_prompt: ", llava_prompt)
103 |
104 | # Single image inference
105 | inputs = {
106 | "prompt": llava_prompt,
107 | "multi_modal_data": {
108 | "image": image
109 | },
110 | }
111 |
112 | # Generate response
113 | outputs = llm.generate([inputs], sampling_params=sampling_params)
114 | response = outputs[0].outputs[0].text.strip()
115 |
116 | result = {
117 | "Question_id": question_id,
118 | "Response": response,
119 | "Answer": item.get('Answer'),
120 | "Category": item.get('Category'),
121 | "Png_id": item.get('Png_id')
122 | }
123 |
124 | results.append(result)
125 |
126 | # Save results every 10 *new* questions or at the last *new* question
127 | if newly_processed_count > 0 and (newly_processed_count % 10 == 0 or i == total_items - 1):
128 | if output_path:
129 | save_results_to_json(results, output_path)
130 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")
131 |
132 | except Exception as e:
133 | print(f"Error processing Question_id {question_id}: {e}")
134 | result = {
135 | "Question_id": question_id,
136 | "Response": f"Processing error: {e}",
137 | "Answer": item.get('Answer'),
138 | "Category": item.get('Category'),
139 | "Png_id": item.get('Png_id')
140 | }
141 | results.append(result)
142 |
143 | # Save results when an error occurs
144 | if output_path:
145 | save_results_to_json(results, output_path)
146 | print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")
147 |
148 | # Ensure final results are saved at the end of function if any new items were processed
149 | if newly_processed_count > 0 and output_path:
150 | save_results_to_json(results, output_path)
151 | print(f"Completed processing file, final results saved to: {output_path}")
152 |
153 | return results
154 |
155 | def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
156 | """Main function to process specified Parquet datasets and generate inferences."""
157 | # Ensure output directory exists
158 | os.makedirs(output_base_dir, exist_ok=True)
159 |
160 | # Find all Parquet files
161 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
162 | files_to_process = []
163 |
164 | # Filter files by categories
165 | if categories:
166 | # Ensure category names don't include .parquet suffix for matching
167 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
168 | for file_path in all_parquet_files:
169 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
170 | if basename_no_ext in category_basenames:
171 | files_to_process.append(file_path)
172 | print(f"Will process specified category files: {files_to_process}")
173 | # Check for missing categories
174 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
175 | missing_categories = category_basenames - found_basenames
176 | if missing_categories:
177 | print(f"Warning: The following specified category files were not found: {missing_categories}")
178 | else:
179 | files_to_process = all_parquet_files
180 | print(f"Will process all .parquet files in folder {data_folder}.")
181 |
182 | if not files_to_process:
183 | print("No Parquet files found to process.")
184 | return
185 |
186 | # --- Model Loading ---
187 | print("Loading model...")
188 | try:
189 | llm = LLM(
190 | model=model_path,
191 | tensor_parallel_size=tensor_parallel_size,
192 | max_model_len=32768,
193 | limit_mm_per_prompt={"image": 1, "video": 0},
194 | enforce_eager=True
195 | )
196 | print("Model loaded successfully.")
197 | except Exception as e:
198 | print(f"Error loading model: {e}")
199 | return
200 |
201 | # --- Process each file ---
202 | for file_path in files_to_process:
203 | print("-" * 50)
204 | print(f"Starting to process file: {file_path}")
205 |
206 | # 1. Dynamically generate output filename from input filename
207 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
208 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
209 | # Combine "llava_" prefix with filename to create JSON filename
210 | output_filename = f"llava_{dataset_name}.json"
211 | current_output_path = os.path.join(output_base_dir, output_filename)
212 | print(f"Results will be saved to: {current_output_path}")
213 |
214 | # 2. Load existing results for current file
215 | existing_results_for_current_file = load_existing_results(current_output_path)
216 |
217 | # 3. Load Parquet data
218 | data_list = load_parquet_data(file_path)
219 |
220 | # 4. Process data
221 | if data_list:
222 | process_data(
223 | data_list,
224 | llm,
225 | temperature,
226 | max_tokens,
227 | existing_results=existing_results_for_current_file,
228 | output_path=current_output_path
229 | )
230 | print(f"File {file_path} processing completed.")
231 | else:
232 | print(f"No data in file {file_path} or loading failed.")
233 |
234 | print("=" * 50)
235 | print("Processing of all specified files completed.")
236 |
237 | if __name__ == "__main__":
238 | import argparse
239 |
240 | parser = argparse.ArgumentParser(description='Process vision-language QA tasks with LLaVA')
241 | parser.add_argument('--data_folder', default='./datasets', help='Path to data folder')
242 | parser.add_argument('--model_path', default='./models/llava-model', help='Path to model weights')
243 | parser.add_argument('--output_dir', default='./results', help='Directory to save results')
244 | parser.add_argument('--categories', nargs='+', default=["Real-world_QA", "Synthetic_QA", "Multi-window_QA"],
245 | help='Categories to process (default: all)')
246 | parser.add_argument('--max_tokens', type=int, default=1024, help='Maximum tokens to generate')
247 | parser.add_argument('--temperature', type=float, default=0, help='Temperature parameter')
248 | parser.add_argument('--tensor_parallel_size', type=int, default=8, help='Number of GPUs for parallel inference')
249 |
250 | args = parser.parse_args()
251 |
252 | main(args.data_folder, args.model_path, args.output_dir,
253 | args.categories, args.max_tokens, args.temperature, args.tensor_parallel_size)
--------------------------------------------------------------------------------
/generate_response/o1_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | O1_QA.py - OpenAI Vision Model Evaluation Tool
3 |
4 | Usage:
5 | 1. Place your Parquet datasets in the './mini_datasets' folder
6 | 2. Set your OpenAI API key in the main function
7 | 3. Configure model name, output directory and categories as needed
8 | 4. Run the script: python o1_qa.py
9 |
10 | This script processes image question-answering datasets in Parquet format,
11 | queries the OpenAI vision model, and saves results to JSON files.
12 | """
13 |
14 | import os
15 | import json
16 | import base64
17 | import glob
18 | import pandas as pd
19 | from PIL import Image
20 | from io import BytesIO
21 | from openai import OpenAI
22 | import time
23 |
24 | def load_parquet_data(file_path):
25 | """Load data from a single Parquet file."""
26 | try:
27 | df = pd.read_parquet(file_path)
28 | return df.to_dict('records')
29 | except Exception as e:
30 | print(f"Failed to load file {file_path}: {e}")
31 | return []
32 |
33 | def load_existing_results(output_path):
34 | """Load existing results file to skip already processed data."""
35 | if os.path.exists(output_path):
36 | try:
37 | with open(output_path, 'r', encoding='utf-8') as f:
38 | return json.load(f)
39 | except Exception as e:
40 | print(f"Failed to load existing results file {output_path}: {e}")
41 | return []
42 |
43 | def save_results_to_json(results, output_path="output.json"):
44 | """Save results to a JSON file."""
45 | try:
46 | with open(output_path, 'w', encoding='utf-8') as f:
47 | json.dump(results, f, ensure_ascii=False, indent=4)
48 | print(f"Results saved to: {output_path}")
49 | except Exception as e:
50 | print(f"Failed to save results to {output_path}: {e}")
51 |
52 |
53 | def process_data(data_list, client, model_name, existing_results=None, output_path=None):
54 | """Process data list, generate model responses, and save results in real-time."""
55 | results = []
56 |
57 | # Create a set of processed question IDs to skip
58 | processed_ids = set()
59 | if existing_results:
60 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
61 | results = existing_results.copy()
62 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip them.")
63 |
64 | total_items = len(data_list)
65 | newly_processed_count = 0
66 |
67 | for i, item in enumerate(data_list):
68 | question_id = item.get('Question_id')
69 |
70 | # Skip already processed questions
71 | if question_id in processed_ids:
72 | print(f"Skipping already processed question ID: {question_id}")
73 | continue
74 |
75 | newly_processed_count += 1
76 |
77 | try:
78 | # Get base64 image
79 | base64_image = item['Image']
80 |
81 | prompt = item['Prompt']
82 | question = item['Question']
83 | choices = item['Choices']
84 |
85 | # Build prompt text
86 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
87 |
88 | # Build API request
89 | try:
90 | response = client.chat.completions.create(
91 | model=model_name,
92 | messages=[
93 | {
94 | "role": "user",
95 | "content": [
96 | {
97 | "type": "text",
98 | "text": prompt_text
99 | },
100 | {
101 | "type": "image_url",
102 | "image_url": {
103 | "url": f"data:image/jpeg;base64,{base64_image}"
104 | }
105 | }
106 | ]
107 | }
108 | ]
109 | )
110 |
111 | # Get response text
112 | model_response = response.choices[0].message.content
113 |
114 | except Exception as api_error:
115 | print(f"API call error, trying to retry: {api_error}")
116 | # Simple retry mechanism
117 | time.sleep(2)
118 | try:
119 | response = client.chat.completions.create(
120 | model=model_name,
121 | messages=[
122 | {
123 | "role": "user",
124 | "content": [
125 | {
126 | "type": "text",
127 | "text": prompt_text
128 | },
129 | {
130 | "type": "image_url",
131 | "image_url": {
132 | "url": f"data:image/jpeg;base64,{base64_image}"
133 | }
134 | }
135 | ]
136 | }
137 | ]
138 | )
139 | model_response = response.choices[0].message.content
140 | except Exception as retry_error:
141 | raise Exception(f"Retry failed: {retry_error}")
142 |
143 | result = {
144 | "Question_id": question_id,
145 | "Response": model_response,
146 | "Answer": item.get('Answer'),
147 | "Category": item.get('Category'),
148 | "Png_id": item.get('Png_id')
149 | }
150 |
151 | results.append(result)
152 |
153 | # Save results every 5 new questions or at the last new question
154 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
155 | if output_path:
156 | save_results_to_json(results, output_path)
157 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")
158 |
159 | except Exception as e:
160 | print(f"Error processing Question_id {question_id}: {e}")
161 | result = {
162 | "Question_id": question_id,
163 | "Response": f"Processing error: {e}",
164 | "Answer": item.get('Answer'),
165 | "Category": item.get('Category'),
166 | "Png_id": item.get('Png_id')
167 | }
168 | results.append(result)
169 |
170 | # Save results when an error occurs
171 | if output_path:
172 | save_results_to_json(results, output_path)
173 | print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")
174 |
175 | # Ensure final results are saved at the end of function if any new items were processed
176 | if newly_processed_count > 0 and output_path:
177 | save_results_to_json(results, output_path)
178 | print(f"Completed processing file, final results saved to: {output_path}")
179 |
180 | return results
181 |
182 | def main(data_folder, api_key, model_name, output_base_dir, categories=None):
183 | """Main function to process Parquet datasets in the specified folder and generate separate output files for each dataset."""
184 | # Ensure output directory exists
185 | os.makedirs(output_base_dir, exist_ok=True)
186 |
187 | # Find all Parquet files
188 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
189 | files_to_process = []
190 |
191 | # Filter files by categories if specified
192 | if categories:
193 | # Ensure category names don't include .parquet extension
194 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
195 | for file_path in all_parquet_files:
196 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
197 | if basename_no_ext in category_basenames:
198 | files_to_process.append(file_path)
199 | print(f"Will process specified category files: {files_to_process}")
200 | # Check for missing categories
201 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
202 | missing_categories = category_basenames - found_basenames
203 | if missing_categories:
204 | print(f"Warning: The following specified category files were not found: {missing_categories}")
205 |
206 | else:
207 | files_to_process = all_parquet_files
208 | print(f"Will process all .parquet files in the {data_folder} folder.")
209 |
210 | if not files_to_process:
211 | print("No Parquet files found to process.")
212 | return
213 |
214 | # Initialize OpenAI API client
215 | print("Initializing OpenAI API client...")
216 | try:
217 | client = OpenAI(api_key=api_key)
218 | print("OpenAI API client initialized successfully.")
219 | except Exception as e:
220 | print(f"Error initializing OpenAI API client: {e}")
221 | return
222 |
223 | # Process each file
224 | for file_path in files_to_process:
225 | print("-" * 50)
226 | print(f"Starting to process file: {file_path}")
227 |
228 | # 1. Generate output filename from input filename
229 | base_name = os.path.basename(file_path)
230 | dataset_name = os.path.splitext(base_name)[0]
231 | output_filename = f"o1_{dataset_name}.json"
232 | current_output_path = os.path.join(output_base_dir, output_filename)
233 | print(f"Results will be saved to: {current_output_path}")
234 |
235 | # 2. Load existing results for current file
236 | existing_results_for_current_file = load_existing_results(current_output_path)
237 |
238 | # 3. Load Parquet data
239 | data_list = load_parquet_data(file_path)
240 |
241 | # 4. Process data
242 | if data_list:
243 | process_data(
244 | data_list,
245 | client,
246 | model_name,
247 | existing_results=existing_results_for_current_file,
248 | output_path=current_output_path
249 | )
250 | print(f"File {file_path} processing completed.")
251 | else:
252 | print(f"No data in file {file_path} or loading failed.")
253 |
254 | print("=" * 50)
255 | print("Processing of all specified files completed.")
256 |
257 | if __name__ == "__main__":
258 | data_folder = "./mini_datasets" # Data folder path
259 | api_key = "your_api_key" # OpenAI API key
260 | model_name = "o1" # OpenAI model name
261 |
262 | output_dir = "./results/o1" # Output directory for results
263 |
264 | categories = ["Real-world_QA_mini", "Synthetic_QA_mini", "Multi-window_QA_mini"]
265 |
266 | # Call the main function
267 | main(data_folder, api_key, model_name, output_dir, categories)
--------------------------------------------------------------------------------
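
Note on the retry pattern: o1_qa.py above and o4-mini_qa.py below both repeat the full request body inside the except branch. A small backoff helper keeps the request in one place; the sketch below is not part of the repository, and the helper name and parameters are ours:

import time

def call_with_retry(client, model_name, messages, retries=3, base_delay=2.0):
    """Call the chat completions API, retrying with exponential backoff.

    Hypothetical helper; the repository scripts inline this logic instead.
    """
    last_error = None
    for attempt in range(retries):
        try:
            response = client.chat.completions.create(model=model_name, messages=messages)
            return response.choices[0].message.content
        except Exception as e:
            last_error = e
            if attempt < retries - 1:
                delay = base_delay * (2 ** attempt)  # 2s, 4s, 8s, ...
                print(f"API call failed (attempt {attempt + 1}/{retries}), retrying in {delay:.0f}s: {e}")
                time.sleep(delay)
    raise RuntimeError(f"All {retries} attempts failed: {last_error}")
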
/generate_response/qwen_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | Qwen Vision-Language Model Question Answering Script
3 |
4 | Usage:
5 | python qwen_qa.py [--data_folder DATA_FOLDER] [--model_path MODEL_PATH]
6 | [--output_dir OUTPUT_DIR] [--categories CATEGORIES]
7 | [--max_tokens MAX_TOKENS] [--temperature TEMPERATURE]
8 | [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
9 |
10 | This script processes image-based questions from parquet files, generates responses
11 | using Qwen Vision-Language model, and saves results to JSON files.
12 | """
13 |
14 | import os
15 | import json
16 | import base64
17 | import glob
18 | import pandas as pd
19 | from PIL import Image
20 | from io import BytesIO
21 | from vllm import LLM, SamplingParams
22 | from transformers import AutoProcessor
23 | from qwen_vl_utils import process_vision_info
24 |
25 | # --- load_parquet_data, load_existing_results, save_results_to_json, and process_data remain unchanged from the other QA scripts ---
26 |
27 | def load_parquet_data(file_path):
28 | """Load data from a single Parquet file."""
29 | try:
30 | df = pd.read_parquet(file_path)
31 | return df.to_dict('records')
32 | except Exception as e:
33 | print(f"Failed to load file {file_path}: {e}")
34 | return []
35 |
36 | def load_existing_results(output_path):
37 | """Load existing results file to skip already processed data."""
38 | if os.path.exists(output_path):
39 | try:
40 | with open(output_path, 'r', encoding='utf-8') as f:
41 | return json.load(f)
42 | except Exception as e:
43 | print(f"Failed to load existing results file {output_path}: {e}")
44 | return []
45 |
46 | def save_results_to_json(results, output_path="output.json"):
47 | """Save results to a JSON file."""
48 | try:
49 | with open(output_path, 'w', encoding='utf-8') as f:
50 | json.dump(results, f, ensure_ascii=False, indent=4)
51 | print(f"Results saved to: {output_path}")
52 | except Exception as e:
53 | print(f"Failed to save results to {output_path}: {e}")
54 |
55 |
56 | def process_data(data_list, llm, processor, temperature, max_tokens, existing_results=None, output_path=None):
57 | """Process data list, generate model responses, and save results in real-time."""
58 | results = []
59 | sampling_params = SamplingParams(
60 | temperature=temperature,
61 | max_tokens=max_tokens,
62 | )
63 |
64 | # Create a set of processed question IDs to skip
65 | processed_ids = set()
66 | if existing_results:
67 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
68 | results = existing_results.copy() # Use existing results for this file
69 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip them.")
70 |
71 | total_items = len(data_list)
72 | newly_processed_count = 0
73 |
74 | for i, item in enumerate(data_list):
75 | question_id = item.get('Question_id')
76 |
77 | # Skip already processed questions
78 | if question_id in processed_ids:
79 | continue
80 |
81 | newly_processed_count += 1 # Count newly processed questions
82 |
83 | try:
84 | # Decode base64 image
85 | base64_image = item['Image']
86 | image_bytes = base64.b64decode(base64_image)
87 | image_io = BytesIO(image_bytes)
88 | # Convert to PIL Image object
89 | image = Image.open(image_io)
90 |
91 | prompt = item['Prompt']
92 | question = item['Question']
93 | choices = item['Choices']
94 |
95 | # Build messages
96 | messages = [
97 | {"role": "user", "content": [
98 | {"type": "image", "image": image},
99 | {"type": "text", "text": f"{prompt}\n'Question:'{question}\n'Choices:'+{choices}"}
100 | ]}
101 | ]
102 |
103 | # Process messages
104 | chat_prompt = processor.apply_chat_template(
105 | messages,
106 | tokenize=False,
107 | add_generation_prompt=True,
108 | )
109 |
110 | image_inputs, video_inputs = process_vision_info(messages)
111 |
112 | mm_data = {}
113 | if image_inputs is not None:
114 | mm_data["image"] = image_inputs
115 |
116 | llm_inputs = {
117 | "prompt": chat_prompt,
118 | "multi_modal_data": mm_data,
119 | }
120 |
121 | # Generate response
122 | outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
123 | response = outputs[0].outputs[0].text.strip()
124 |
125 | result = {
126 | "Question_id": question_id,
127 | "Response": response,
128 | "Answer": item.get('Answer'),
129 | "Category": item.get('Category'),
130 | "Png_id": item.get('Png_id')
131 | }
132 |
133 | results.append(result) # Add new result to the list
134 |
135 | # Save results every 10 *new* questions or at the last *new* question
136 | if newly_processed_count > 0 and (newly_processed_count % 10 == 0 or i == total_items - 1):
137 | if output_path:
138 | save_results_to_json(results, output_path)
139 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saving results to {output_path}.")
140 |
141 | except Exception as e:
142 | print(f"Error processing Question_id {question_id}: {e}")
143 | result = {
144 | "Question_id": question_id,
145 | "Response": f"Processing error: {e}", # Record error message
146 | "Answer": item.get('Answer'),
147 | "Category": item.get('Category'),
148 | "Png_id": item.get('Png_id')
149 | }
150 | results.append(result)
151 |
152 | # Save results when an error occurs
153 | if output_path:
154 | save_results_to_json(results, output_path)
155 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
156 |
157 | # Ensure final results are saved at the end of function if any new items were processed
158 | if newly_processed_count > 0 and output_path:
159 | save_results_to_json(results, output_path)
160 | print(f"File processing complete, final results saved to: {output_path}")
161 |
162 | return results # Return the complete list of results for this file
163 |
164 | def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0.1, tensor_parallel_size=8):
165 | """Main function to process specified parquet datasets and generate responses, with separate output files for each dataset."""
166 | # Ensure base output directory exists
167 | os.makedirs(output_base_dir, exist_ok=True)
168 |
169 | # Find all Parquet files
170 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
171 | files_to_process = []
172 |
173 | # Filter files by categories
174 | if categories:
175 | # Ensure category names don't include .parquet suffix for matching
176 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
177 | for file_path in all_parquet_files:
178 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
179 | if basename_no_ext in category_basenames:
180 | files_to_process.append(file_path)
181 | print(f"Will process specified category files: {files_to_process}")
182 | # Check for missing categories
183 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
184 | missing_categories = category_basenames - found_basenames
185 | if missing_categories:
186 | print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
187 | else:
188 | files_to_process = all_parquet_files
189 | print(f"Will process all .parquet files in folder {data_folder}.")
190 |
191 | if not files_to_process:
192 | print("No Parquet files found to process.")
193 | return
194 |
195 | # Load model
196 | print("Loading model...")
197 | try:
198 | llm = LLM(
199 | model=model_path,
200 | tensor_parallel_size=tensor_parallel_size,
201 | max_model_len=32768,
202 | limit_mm_per_prompt={"image": 1, "video": 0},
203 | enforce_eager=True
204 | )
205 | processor = AutoProcessor.from_pretrained(model_path)
206 | print("Model loaded successfully.")
207 | except Exception as e:
208 | print(f"Error loading model: {e}")
209 | return
210 |
211 | # Process each file
212 | for file_path in files_to_process:
213 | print("-" * 50)
214 | print(f"Starting to process file: {file_path}")
215 |
216 | # 1. Dynamically generate output filename from input filename
217 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
218 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
219 | # Combine "qwen_" prefix with filename to make a JSON filename
220 | output_filename = f"qwen_{dataset_name}.json"
221 | current_output_path = os.path.join(output_base_dir, output_filename) # Complete output path
222 | print(f"Results will be saved to: {current_output_path}")
223 |
224 | # 2. Load existing results for current file
225 | existing_results_for_current_file = load_existing_results(current_output_path)
226 |
227 | # 3. Load Parquet data
228 | data_list = load_parquet_data(file_path)
229 |
230 | # 4. Process data
231 | if data_list:
232 | # Pass dynamically generated output path and loaded existing results to process_data
233 | process_data(
234 | data_list,
235 | llm,
236 | processor,
237 | temperature,
238 | max_tokens,
239 | existing_results=existing_results_for_current_file, # Pass existing results for current file
240 | output_path=current_output_path # Pass output path for current file
241 | )
242 | print(f"File {file_path} processing complete.")
243 | else:
244 | print(f"No data in file {file_path} or file loading failed.")
245 |
246 | print("=" * 50)
247 | print("Processing complete for all specified files.")
248 |
249 | if __name__ == "__main__":
250 | # Use relative paths
251 | data_folder = "data/datasets" # Data folder path
252 | model_path = "models/Qwen2.5-VL-72B-Instruct" # Model weights path
253 | output_dir = "results" # Output directory for results
254 |
255 | # Specify categories to process (only provide base names, e.g., "Multi-window_QA")
256 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
257 | # categories = None # Set to None to process all parquet files
258 |
259 | max_tokens = 1024 # Maximum tokens to generate
260 | temperature = 0 # Temperature parameter
261 | tensor_parallel_size = 8 # Use 8 GPUs in parallel
262 |
263 | # Call main function, passing the output directory
264 | main(data_folder, model_path, output_dir, categories, max_tokens, temperature, tensor_parallel_size)
--------------------------------------------------------------------------------
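
All of the QA scripts share one input schema: each Parquet record carries a base64-encoded screenshot in the Image column alongside Question_id, Prompt, Question, Choices, Answer, Category, and Png_id. A minimal sketch of producing such a file; every concrete value below is a placeholder:

import base64
import pandas as pd

# Encode a screenshot as the base64 string the QA scripts expect in "Image".
with open("example_page.png", "rb") as f:  # placeholder screenshot path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

record = {
    "Question_id": "demo_0001",
    "Image": image_b64,  # base64 string; the scripts decode it before inference
    "Prompt": "Answer with the letter of the correct choice.",
    "Question": "What color is the navigation bar?",
    "Choices": "A. Blue  B. Red  C. Green  D. White",
    "Answer": "A",
    "Category": "Synthetic_QA",
    "Png_id": "demo_0001.png",
}

pd.DataFrame([record]).to_parquet("Synthetic_QA.parquet", index=False)
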
/generate_response/gemini_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | Gemini API Visual Question Answering Tool
3 |
4 | This script processes various QA datasets using Google's Gemini API:
5 | - Real-world QA
6 | - Synthetic QA
7 | - Multi-window QA
8 |
9 | Usage:
10 | 1. Set your API key in the main function
11 | 2. Make sure datasets are in the ./datasets folder
12 | 3. Specify which categories to process
13 | 4. Run the script: python gemini_qa.py
14 |
15 | Results will be saved to the ./results/gemini directory.
16 | """
17 |
18 | import os
19 | import json
20 | import base64
21 | import glob
22 | import pandas as pd
23 | from io import BytesIO
24 | from google import genai
25 | from google.genai import types
26 | import time
27 |
28 | def load_parquet_data(file_path):
29 | """Load data from a single Parquet file."""
30 | try:
31 | df = pd.read_parquet(file_path)
32 | return df.to_dict('records')
33 | except Exception as e:
34 | print(f"Failed to load file {file_path}: {e}")
35 | return []
36 |
37 | def load_existing_results(output_path):
38 | """Load existing results file to skip already processed data."""
39 | if os.path.exists(output_path):
40 | try:
41 | with open(output_path, 'r', encoding='utf-8') as f:
42 | return json.load(f)
43 | except Exception as e:
44 | print(f"Failed to load existing results file {output_path}: {e}")
45 | return []
46 |
47 | def save_results_to_json(results, output_path="output.json"):
48 | """Save results to JSON file."""
49 | try:
50 | with open(output_path, 'w', encoding='utf-8') as f:
51 | json.dump(results, f, ensure_ascii=False, indent=4)
52 | print(f"Results saved to: {output_path}")
53 | except Exception as e:
54 | print(f"Failed to save results to {output_path}: {e}")
55 |
56 |
57 | def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
58 | """Process data list, generate model responses, and save results in real-time."""
59 | results = []
60 |
61 | # Create a set of processed question IDs to skip
62 | processed_ids = set()
63 | if existing_results:
64 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
65 | results = existing_results.copy() # Use existing results for this file
66 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")
67 |
68 | total_items = len(data_list)
69 | newly_processed_count = 0
70 |
71 | for i, item in enumerate(data_list):
72 | question_id = item.get('Question_id')
73 |
74 | # Skip already processed questions
75 | if question_id in processed_ids:
76 | print(f"Skipping already processed question ID: {question_id}")
77 | continue
78 |
79 | newly_processed_count += 1 # Count newly processed questions
80 |
81 | try:
82 | # Decode base64 image
83 | base64_image = item['Image']
84 | image_bytes = base64.b64decode(base64_image)
85 |
86 | prompt = item['Prompt']
87 | question = item['Question']
88 | choices = item['Choices']
89 |
90 | # Build prompt text
91 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
92 |
93 | # Build API request
94 | try:
95 | response = client.models.generate_content(
96 | model=model_name,
97 | contents=[
98 | prompt_text,
99 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
100 | ],
101 | config=types.GenerateContentConfig(
102 | temperature=temperature,
103 | thinking_config=thinking_config,  # module-level global assigned in the __main__ block
104 | max_output_tokens=max_tokens
105 | )
106 | )
107 |
108 | # Get response text
109 | model_response = response.text.strip()
110 |
111 | except Exception as api_error:
112 | print(f"API call error, attempting retry: {api_error}")
113 | # Simple retry mechanism
114 | time.sleep(60)
115 | try:
116 | response = client.models.generate_content(
117 | model=model_name,
118 | contents=[
119 | prompt_text,
120 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
121 | ],
122 | config=types.GenerateContentConfig(
123 | temperature=temperature,
124 | thinking_config=thinking_config,
125 | max_output_tokens=max_tokens
126 | )
127 | )
128 | model_response = response.text.strip()
129 | except Exception as retry_error:
130 | raise Exception(f"Retry failed: {retry_error}")
131 |
132 | result = {
133 | "Question_id": question_id,
134 | "Response": model_response,
135 | "Answer": item.get('Answer'),
136 | "Category": item.get('Category'),
137 | "Png_id": item.get('Png_id')
138 | }
139 |
140 | results.append(result) # Add new result to list
141 |
142 | # Save results every 5 new questions or at the last new question
143 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
144 | if output_path:
145 | save_results_to_json(results, output_path) # Save to specified file
146 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), real-time saving to {output_path}.")
147 | time.sleep(60)
148 |
149 | except Exception as e:
150 | print(f"Error processing Question_id {question_id}: {e}")
151 | result = {
152 | "Question_id": question_id,
153 | "Response": f"Processing error: {e}", # Record error message
154 | "Answer": item.get('Answer'),
155 | "Category": item.get('Category'),
156 | "Png_id": item.get('Png_id')
157 | }
158 | results.append(result)
159 |
160 | # Also save results when an error occurs
161 | if output_path:
162 | save_results_to_json(results, output_path)
163 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
164 |
165 | # Ensure final results are saved at the end of the function if any new items were processed
166 | if newly_processed_count > 0 and output_path:
167 | save_results_to_json(results, output_path)
168 | print(f"Finished processing file, final results saved to: {output_path}")
169 |
170 | return results # Return complete result list for this file
171 |
172 | def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
173 | """Main function, processes Parquet datasets in the specified folder and generates a separate output file for each dataset."""
174 | # Ensure the base output directory exists
175 | os.makedirs(output_base_dir, exist_ok=True)
176 |
177 | # Find all Parquet files
178 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
179 | files_to_process = []
180 |
181 | # Filter files based on categories
182 | if categories:
183 | # Ensure category names don't include .parquet suffix for matching
184 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
185 | for file_path in all_parquet_files:
186 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
187 | if basename_no_ext in category_basenames:
188 | files_to_process.append(file_path)
189 | print(f"Will process specified category files: {files_to_process}")
190 | # Check for missing categories
191 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
192 | missing_categories = category_basenames - found_basenames
193 | if missing_categories:
194 | print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
195 |
196 | else:
197 | files_to_process = all_parquet_files
198 | print(f"Will process all .parquet files in the folder {data_folder}.")
199 |
200 | if not files_to_process:
201 | print("No Parquet files found to process.")
202 | return
203 |
204 | # --- Initialize Gemini API client ---
205 | print("Initializing Gemini API client...")
206 | try:
207 | client = genai.Client(api_key=api_key)
208 | print("Gemini API client initialized successfully.")
209 | except Exception as e:
210 | print(f"Error initializing Gemini API client: {e}")
211 | return
212 |
213 | # --- Process each file ---
214 | for file_path in files_to_process:
215 | print("-" * 50)
216 | print(f"Starting to process file: {file_path}")
217 |
218 | # 1. Dynamically generate output filename from input filename
219 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
220 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
221 | # Combine "gemini_" prefix with the file name to create a JSON filename
222 | output_filename = f"gemini_{dataset_name}.json"
223 | current_output_path = os.path.join(output_base_dir, output_filename) # Full output path
224 | print(f"Results will be saved to: {current_output_path}")
225 |
226 | # 2. Load existing results for current file
227 | existing_results_for_current_file = load_existing_results(current_output_path)
228 |
229 | # 3. Load Parquet data
230 | data_list = load_parquet_data(file_path)
231 |
232 | # 4. Process data
233 | if data_list:
234 | # Pass dynamically generated output path and loaded existing results to process_data
235 | process_data(
236 | data_list,
237 | client,
238 | model_name,
239 | temperature,
240 | max_tokens,
241 | existing_results=existing_results_for_current_file, # Pass existing results for current file
242 | output_path=current_output_path # Pass output path for current file
243 | )
244 | print(f"File {file_path} processing completed.")
245 | else:
246 | print(f"No data in file {file_path} or loading failed.")
247 |
248 | print("=" * 50)
249 | print("All specified files processing completed.")
250 |
251 | if __name__ == "__main__":
252 | data_folder = "./datasets" # Data folder path
253 | api_key = "your_api_key"
254 | model_name = "gemini-2.5-flash-preview-04-17" # Gemini model name
255 |
256 | output_dir = "./results/gemini" # Directory to save results
257 |
258 | # Specify which dataset categories to process (just provide base names like "Multi-window_QA")
259 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
260 |
261 | max_tokens = 300 # Maximum tokens to generate
262 | temperature = 0 # Temperature parameter
263 | thinking_config = types.ThinkingConfig(thinking_budget=0)
264 |
265 | # Call main function
266 | main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
--------------------------------------------------------------------------------
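
gemini_qa.py reads thinking_config as a module-level global that is only assigned in the __main__ block (geminipro_qa.py below does the same). A sketch of making that dependency explicit with the same google-genai types; the helper name is ours, not the repository's:

from google.genai import types

def build_generation_config(temperature, max_tokens, thinking_budget=0):
    """Bundle the generation settings the Gemini scripts pass to generate_content."""
    return types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        thinking_config=types.ThinkingConfig(thinking_budget=thinking_budget),
    )

# process_data could then accept `config` as an argument and call:
# client.models.generate_content(model=model_name, contents=..., config=config)
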
/generate_response/o4-mini_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | O4-Mini QA System
3 |
4 | This script processes image-based question answering tasks using OpenAI's o4-mini model.
5 |
6 | Usage:
7 | 1. Place your dataset parquet files in the ./datasets folder
8 | 2. Set your OpenAI API key in the script
9 | 3. Specify the categories you want to process
10 | 4. Run the script: python o4-mini_qa.py
11 |
12 | Results will be saved to the ./results/o4-mini directory.
13 | """
14 |
15 | import os
16 | import json
17 | import base64
18 | import glob
19 | import pandas as pd
20 | from PIL import Image
21 | from io import BytesIO
22 | from openai import OpenAI
23 | import time
24 |
25 | def load_parquet_data(file_path):
26 | """Load data from a single Parquet file."""
27 | try:
28 | df = pd.read_parquet(file_path)
29 | return df.to_dict('records')
30 | except Exception as e:
31 | print(f"Failed to load file {file_path}: {e}")
32 | return []
33 |
34 | def load_existing_results(output_path):
35 | """Load existing results file to skip already processed data."""
36 | if os.path.exists(output_path):
37 | try:
38 | with open(output_path, 'r', encoding='utf-8') as f:
39 | return json.load(f)
40 | except Exception as e:
41 | print(f"Failed to load existing results file {output_path}: {e}")
42 | return []
43 |
44 | def save_results_to_json(results, output_path="output.json"):
45 | """Save results to JSON file."""
46 | try:
47 | with open(output_path, 'w', encoding='utf-8') as f:
48 | json.dump(results, f, ensure_ascii=False, indent=4)
49 | print(f"Results saved to: {output_path}")
50 | except Exception as e:
51 | print(f"Failed to save results to {output_path}: {e}")
52 |
53 |
54 | def process_data(data_list, client, model_name, existing_results=None, output_path=None):
55 | """Process data list, generate model responses, and save results in real-time."""
56 | results = []
57 |
58 | # Create a set of processed question IDs to skip
59 | processed_ids = set()
60 | if existing_results:
61 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
62 | results = existing_results.copy() # Use existing results passed in
63 | print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip these questions.")
64 |
65 | total_items = len(data_list)
66 | newly_processed_count = 0
67 |
68 | for i, item in enumerate(data_list):
69 | question_id = item.get('Question_id')
70 |
71 | # Skip already processed questions
72 | if question_id in processed_ids:
73 | print(f"Skipping already processed question ID: {question_id}")
74 | continue
75 |
76 | newly_processed_count += 1 # Count newly processed questions
77 |
78 | try:
79 | # Get base64 image
80 | base64_image = item['Image']
81 |
82 | prompt = item['Prompt']
83 | question = item['Question']
84 | choices = item['Choices']
85 |
86 | # Build prompt text
87 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
88 |
89 | # Build API request
90 | try:
91 | response = client.chat.completions.create(
92 | model=model_name,
93 | messages=[
94 | {
95 | "role": "user",
96 | "content": [
97 | {
98 | "type": "text",
99 | "text": prompt_text
100 | },
101 | {
102 | "type": "image_url",
103 | "image_url": {
104 | "url": f"data:image/jpeg;base64,{base64_image}"
105 | }
106 | }
107 | ]
108 | }
109 | ]
110 | )
111 |
112 | # Get response text
113 | model_response = response.choices[0].message.content
114 |
115 | except Exception as api_error:
116 | print(f"API call error, trying again: {api_error}")
117 | # Simple retry mechanism
118 | time.sleep(2)
119 | try:
120 | response = client.chat.completions.create(
121 | model=model_name,
122 | messages=[
123 | {
124 | "role": "user",
125 | "content": [
126 | {
127 | "type": "text",
128 | "text": prompt_text
129 | },
130 | {
131 | "type": "image_url",
132 | "image_url": {
133 | "url": f"data:image/jpeg;base64,{base64_image}"
134 | }
135 | }
136 | ]
137 | }
138 | ]
139 | )
140 | model_response = response.choices[0].message.content
141 | except Exception as retry_error:
142 | raise Exception(f"Retry failed: {retry_error}")
143 |
144 | result = {
145 | "Question_id": question_id,
146 | "Response": model_response,
147 | "Answer": item.get('Answer'),
148 | "Category": item.get('Category'),
149 | "Png_id": item.get('Png_id')
150 | }
151 |
152 | results.append(result) # Add new result to the list
153 |
154 | # Save results every 5 *new* questions or at the last *new* question
155 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
156 | if output_path:
157 | save_results_to_json(results, output_path)
158 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), real-time saved results to {output_path}.")
159 |
160 | except Exception as e:
161 | print(f"Error processing Question_id {question_id}: {e}")
162 | result = {
163 | "Question_id": question_id,
164 | "Response": f"Processing error: {e}", # Record error message
165 | "Answer": item.get('Answer'),
166 | "Category": item.get('Category'),
167 | "Png_id": item.get('Png_id')
168 | }
169 | results.append(result)
170 |
171 | # Save results when error occurs
172 | if output_path:
173 | save_results_to_json(results, output_path)
174 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
175 |
176 | # Ensure final results are saved if any new items were processed
177 | if newly_processed_count > 0 and output_path:
178 | save_results_to_json(results, output_path)
179 | print(f"File processing complete, final results saved to: {output_path}")
180 |
181 | return results # Return complete results list for this file
182 |
183 | def main(data_folder, api_key, model_name, output_base_dir, categories=None):
184 | """Main function to process specified Parquet datasets and generate separate output files for each."""
185 | # Ensure base output directory exists
186 | os.makedirs(output_base_dir, exist_ok=True)
187 |
188 | # Find all Parquet files
189 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
190 | files_to_process = []
191 |
192 | # Filter files by categories
193 | if categories:
194 | # Ensure category names don't include .parquet suffix for matching
195 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
196 | for file_path in all_parquet_files:
197 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
198 | if basename_no_ext in category_basenames:
199 | files_to_process.append(file_path)
200 | print(f"Will process specified category files: {files_to_process}")
201 | # Check for any missing categories
202 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
203 | missing_categories = category_basenames - found_basenames
204 | if missing_categories:
205 | print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
206 |
207 | else:
208 | files_to_process = all_parquet_files
209 | print(f"Will process all .parquet files in folder {data_folder}.")
210 |
211 | if not files_to_process:
212 | print("No Parquet files found to process.")
213 | return
214 |
215 | # --- Initialize OpenAI API client ---
216 | print("Initializing OpenAI API client...")
217 | try:
218 | client = OpenAI(api_key=api_key)
219 | print("OpenAI API client initialized successfully.")
220 | except Exception as e:
221 | print(f"Error initializing OpenAI API client: {e}")
222 | return
223 |
224 | # --- Process each file ---
225 | for file_path in files_to_process:
226 | print("-" * 50)
227 | print(f"Starting to process file: {file_path}")
228 |
229 | # 1. Dynamically generate output filename from input filename
230 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
231 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
232 | # Combine "o4mini_" prefix with filename to create JSON filename
233 | output_filename = f"o4mini_{dataset_name}.json"
234 | current_output_path = os.path.join(output_base_dir, output_filename) # Complete output path
235 | print(f"Results will be saved to: {current_output_path}")
236 |
237 | # 2. Load existing results for current file
238 | existing_results_for_current_file = load_existing_results(current_output_path)
239 |
240 | # 3. Load Parquet data
241 | data_list = load_parquet_data(file_path)
242 |
243 | # 4. Process data
244 | if data_list:
245 | # Pass dynamically generated output path and loaded existing results to process_data
246 | process_data(
247 | data_list,
248 | client,
249 | model_name,
250 | existing_results=existing_results_for_current_file, # Pass existing results for current file
251 | output_path=current_output_path # Pass output path for current file
252 | )
253 | print(f"File {file_path} processing complete.")
254 | else:
255 | print(f"No data in file {file_path} or loading failed.")
256 |
257 | print("=" * 50)
258 | print("Processing of all specified files complete.")
259 |
260 | if __name__ == "__main__":
261 | data_folder = "./datasets" # Data folder path
262 | api_key = "your_api_key"
263 | model_name = "o4-mini-2025-04-16"
264 |
265 | output_dir = "./results/o4-mini" # Specify results directory
266 |
267 | categories = ["Synthetic_QA", "Real-world_QA", "Multi-window_QA"]
268 |
269 | # Call main function
270 | main(data_folder, api_key, model_name, output_dir, categories)
--------------------------------------------------------------------------------
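
Each QA script writes a JSON list of records with Question_id, Response, Answer, Category, and Png_id. A quick sanity-check scorer over such a file, assuming the response text contains the ground-truth choice letter; this substring rule is a simplification, not the benchmark's official metric:

import json

def simple_accuracy(result_path):
    """Fraction of records whose ground-truth answer appears in the response."""
    with open(result_path, "r", encoding="utf-8") as f:
        results = json.load(f)
    scored = [r for r in results if r.get("Answer") is not None]
    correct = sum(
        1 for r in scored
        if str(r["Answer"]).strip().lower() in str(r.get("Response", "")).lower()
    )
    return correct / len(scored) if scored else 0.0

# Example: print(simple_accuracy("./results/o4-mini/o4mini_Synthetic_QA.json"))
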
/generate_response/geminipro_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | # Gemini Pro QA Evaluation Tool
3 | #
4 | # This script evaluates Gemini Pro model on image-based QA tasks from Parquet datasets.
5 | #
6 | # Usage:
7 | # 1. Install dependencies: pip install google-generativeai pandas pillow
8 | # 2. Configure your API key and dataset parameters in the main section
9 | # 3. Run the script: python geminipro_qa.py
10 | #
11 | # The script will process image data from Parquet files and save results to JSON files.
12 | """
13 |
14 | import os
15 | import json
16 | import base64
17 | import glob
18 | import pandas as pd
19 | from PIL import Image
20 | from io import BytesIO
21 | from google import genai
22 | from google.genai import types
23 | import time
24 |
25 | def load_parquet_data(file_path):
26 | """Load data from a single Parquet file."""
27 | try:
28 | df = pd.read_parquet(file_path)
29 | return df.to_dict('records')
30 | except Exception as e:
31 | print(f"Failed to load file {file_path}: {e}")
32 | return []
33 |
34 | def load_existing_results(output_path):
35 | """Load existing results file to skip already processed data."""
36 | if os.path.exists(output_path):
37 | try:
38 | with open(output_path, 'r', encoding='utf-8') as f:
39 | return json.load(f)
40 | except Exception as e:
41 | print(f"Failed to load existing results file {output_path}: {e}")
42 | return []
43 |
44 | def save_results_to_json(results, output_path="output.json"):
45 | """Save results to a JSON file."""
46 | try:
47 | with open(output_path, 'w', encoding='utf-8') as f:
48 | json.dump(results, f, ensure_ascii=False, indent=4)
49 | print(f"Results saved to: {output_path}")
50 | except Exception as e:
51 | print(f"Failed to save results to {output_path}: {e}")
52 |
53 |
54 | def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
55 | """Process data list, generate model responses, and save results in real-time."""
56 | results = []
57 |
58 | # Create a set of processed question IDs to skip
59 | processed_ids = set()
60 | if existing_results:
61 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
62 | results = existing_results.copy() # Use existing results for this file
63 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these.")
64 |
65 | total_items = len(data_list)
66 | newly_processed_count = 0
67 |
68 | for i, item in enumerate(data_list):
69 | question_id = item.get('Question_id')
70 |
71 | # Skip already processed questions
72 | if question_id in processed_ids:
73 | print(f"Skipping already processed question ID: {question_id}")
74 | continue
75 |
76 | newly_processed_count += 1 # Count newly processed questions
77 |
78 | try:
79 | # Decode base64 image
80 | base64_image = item['Image']
81 | image_bytes = base64.b64decode(base64_image)
82 |
83 | prompt = item['Prompt']
84 | question = item['Question']
85 | choices = item['Choices']
86 |
87 | # Build prompt text
88 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
89 |
90 | # Build API request
91 | try:
92 | response = client.models.generate_content(
93 | model=model_name,
94 | contents=[
95 | prompt_text,
96 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
97 | ],
98 | config=types.GenerateContentConfig(
99 | temperature=temperature,
100 | thinking_config=thinking_config,  # module-level global assigned in the __main__ block
101 | max_output_tokens=max_tokens
102 | )
103 | )
104 |
105 | # Get response text
106 | model_response = response.text.strip()
107 |
108 | except Exception as api_error:
109 | print(f"API call error, attempting retry: {api_error}")
110 | # Simple retry mechanism
111 | time.sleep(60)
112 | try:
113 | response = client.models.generate_content(
114 | model=model_name,
115 | contents=[
116 | prompt_text,
117 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
118 | ],
119 | config=types.GenerateContentConfig(
120 | temperature=temperature,
121 | thinking_config=thinking_config,
122 | max_output_tokens=max_tokens
123 | )
124 | )
125 | model_response = response.text.strip()
126 | except Exception as retry_error:
127 | raise Exception(f"Retry failed: {retry_error}")
128 |
129 | result = {
130 | "Question_id": question_id,
131 | "Response": model_response,
132 | "Answer": item.get('Answer'),
133 | "Category": item.get('Category'),
134 | "Png_id": item.get('Png_id')
135 | }
136 |
137 | results.append(result) # Add new result to the list
138 |
139 | # Save results every 5 *new* questions or at the last *new* question
140 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
141 | if output_path:
142 | save_results_to_json(results, output_path) # Save to specified file
143 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saving results to {output_path}.")
144 | time.sleep(60)
145 |
146 | except Exception as e:
147 | print(f"Error processing Question_id {question_id}: {e}")
148 | result = {
149 | "Question_id": question_id,
150 | "Response": f"Processing error: {e}", # Record error message
151 | "Answer": item.get('Answer'),
152 | "Category": item.get('Category'),
153 | "Png_id": item.get('Png_id')
154 | }
155 | results.append(result)
156 |
157 | # Save results on error
158 | if output_path:
159 | save_results_to_json(results, output_path)
160 | print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
161 |
162 | # Ensure final results are saved if any new items were processed
163 | if newly_processed_count > 0 and output_path:
164 | save_results_to_json(results, output_path)
165 | print(f"File processing complete, final results saved to: {output_path}")
166 |
167 | return results # Return complete results list for this file
168 |
169 | def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
170 | """Main function to process Parquet datasets from a folder and run inference, generating separate output files for each dataset."""
171 | # Ensure base output directory exists
172 | os.makedirs(output_base_dir, exist_ok=True)
173 |
174 | # Find all Parquet files
175 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
176 | files_to_process = []
177 |
178 | # Filter files by categories
179 | if categories:
180 | # Ensure category names don't include .parquet suffix for matching
181 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
182 | for file_path in all_parquet_files:
183 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
184 | if basename_no_ext in category_basenames:
185 | files_to_process.append(file_path)
186 | print(f"Will process specified category files: {files_to_process}")
187 | # Check for missing categories
188 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
189 | missing_categories = category_basenames - found_basenames
190 | if missing_categories:
191 | print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
192 |
193 | else:
194 | files_to_process = all_parquet_files
195 | print(f"Will process all .parquet files in folder {data_folder}.")
196 |
197 | if not files_to_process:
198 | print("No Parquet files found to process.")
199 | return
200 |
201 | # --- Initialize Gemini API client ---
202 | print("Initializing Gemini API client...")
203 | try:
204 | client = genai.Client(api_key=api_key)
205 | print("Gemini API client initialization complete.")
206 | except Exception as e:
207 | print(f"Error initializing Gemini API client: {e}")
208 | return
209 |
210 | # --- Process each file ---
211 | for file_path in files_to_process:
212 | print("-" * 50)
213 | print(f"Starting to process file: {file_path}")
214 |
215 | # 1. Generate output filename dynamically from input filename
216 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
217 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
218 | # Combine "gemini_" prefix with filename to form JSON filename
219 | output_filename = f"gemini_{dataset_name}.json"  # Note: same prefix and default output_dir as gemini_qa.py, so results from the two scripts can overwrite each other
220 | current_output_path = os.path.join(output_base_dir, output_filename) # Complete output path
221 | print(f"Results will be saved to: {current_output_path}")
222 |
223 | # 2. Load existing results for current file
224 | existing_results_for_current_file = load_existing_results(current_output_path)
225 |
226 | # 3. Load Parquet data
227 | data_list = load_parquet_data(file_path)
228 |
229 | # 4. Process data
230 | if data_list:
231 | # Pass dynamically generated output path and loaded existing results to process_data
232 | process_data(
233 | data_list,
234 | client,
235 | model_name,
236 | temperature,
237 | max_tokens,
238 | existing_results=existing_results_for_current_file, # Pass existing results for current file
239 | output_path=current_output_path # Pass output path for current file
240 | )
241 | print(f"File {file_path} processing complete.")
242 | else:
243 | print(f"No data in file {file_path} or loading failed.")
244 |
245 | print("=" * 50)
246 | print("Processing of all specified files complete.")
247 |
248 | if __name__ == "__main__":
249 | data_folder = "./datasets" # Data folder path
250 | api_key = "your_api_key"
251 | model_name = "gemini-2.5-pro-preview-05-06" # Gemini model name
252 |
253 | # Output filenames will be generated dynamically based on dataset (e.g., gemini_Multi-window_QA.json)
254 | output_dir = "./results/gemini" # Specify directory to save results
255 |
256 | # Specify dataset categories to process (just provide base names like "Multi-window_QA")
257 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
258 |
259 | max_tokens = 300 # Maximum tokens to generate
260 | temperature = 0 # Temperature parameter
261 | thinking_config = types.ThinkingConfig(thinking_budget=0)
262 |
263 | # Call main function
264 | main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
--------------------------------------------------------------------------------
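
Both Gemini scripts send the raw dataset bytes with mime_type="image/jpeg", while the Png_id fields (and claude_qa.py's media_type of image/png below) suggest the screenshots are PNGs. The APIs generally sniff the actual bytes, but re-encoding makes the declared type accurate; a sketch with Pillow:

import base64
from io import BytesIO
from PIL import Image

def base64_png_to_jpeg_bytes(b64_png, quality=90):
    """Re-encode a base64 PNG as JPEG bytes so the declared mime type matches."""
    img = Image.open(BytesIO(base64.b64decode(b64_png)))
    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")  # JPEG has no alpha channel
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    return buf.getvalue()
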
/calculate_similarity/clip_score.py:
--------------------------------------------------------------------------------
1 | """
2 | # CLIP Score Calculator
3 | #
4 | # This script calculates semantic similarity between pairs of images using OpenAI's CLIP model.
5 | #
6 | # Usage:
7 | #    1. Install required packages: pip install torch opencv-python pillow tqdm numpy git+https://github.com/openai/CLIP.git
8 | # 2. Set the source folders in the main section:
9 | # - source_folder1: directory containing model-generated images
10 | # - source_folder2: directory containing reference/ground truth images
11 | # 3. Set the output directory where results will be saved
12 | # 4. Run the script: python clip_score.py
13 | #
14 | # The script will process matching image files in both directories and output JSON files
15 | # with similarity scores for each model.
16 | """
17 |
18 | # pip install torch opencv-python pillow tqdm numpy git+https://github.com/openai/CLIP.git
19 | import os
20 | import torch
21 | import clip
22 | from PIL import Image, ImageFile
23 | import cv2
24 | import json
25 | import numpy as np
26 | from tqdm import tqdm
27 | import warnings
28 |
29 | # Disable PIL's DecompressionBombWarning
30 | warnings.filterwarnings("ignore", category=Image.DecompressionBombWarning)
31 | # Allow PIL to load truncated images
32 | ImageFile.LOAD_TRUNCATED_IMAGES = True
33 |
34 | # Initialize CLIP model
35 | device = "cuda" if torch.cuda.is_available() else "cpu"
36 | model, preprocess = clip.load("/mnt/petrelfs/sunhaoyu/work/clip_model/ViT-B-32.pt", device=device)  # local ViT-B/32 weights; pass "ViT-B/32" instead to download them automatically
37 |
38 | def safe_open_image(image_path):
39 | """
40 | Safely open an image, compressing it if pixel count exceeds limit
41 |
42 | Args:
43 | image_path (str): Path to the image file
44 |
45 | Returns:
46 | PIL.Image: Processed image object
47 | """
48 | try:
49 | # First try to open the image
50 | img = Image.open(image_path)
51 |
52 | # Calculate total pixel count
53 | width, height = img.size
54 | pixels = width * height
55 |
56 | # If pixel count exceeds limit, compress the image
57 | # Use a smaller limit than PIL's warning threshold
58 | max_pixels = 89000000 # Slightly less than PIL warning threshold 89478485
59 |
60 | if pixels > max_pixels:
61 | # Calculate scale ratio
62 | scale = (max_pixels / pixels) ** 0.5
63 |
64 | # Calculate new dimensions
65 | new_width = int(width * scale)
66 | new_height = int(height * scale)
67 |
68 | # Resize the image
69 | img = img.resize((new_width, new_height), Image.LANCZOS)
70 | print(f"Image {os.path.basename(image_path)} has been resized: {width}x{height} -> {new_width}x{new_height}")
71 |
72 | return img
73 | except Exception as e:
74 | print(f"Cannot open image {image_path}: {str(e)}")
75 | # Try using cv2 to load the image as a fallback
76 | try:
77 | print(f"Trying to load image {image_path} with OpenCV")
78 | img_cv = cv2.imread(image_path)
79 | if img_cv is None:
80 | raise Exception("OpenCV cannot load the image")
81 | # Convert BGR to RGB
82 | img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
83 | # Convert to PIL image
84 | img = Image.fromarray(img_rgb)
85 | return img
86 | except Exception as cv_error:
87 | print(f"Failed to load image {image_path} with OpenCV: {str(cv_error)}")
88 | raise
89 |
90 | def clip_similarity(image_path1, image_path2):
91 | """
92 | Calculate CLIP semantic similarity between two images
93 |
94 | Args:
95 | image_path1 (str or PIL.Image): Path to the first image or PIL image object
96 | image_path2 (str or PIL.Image): Path to the second image or PIL image object
97 |
98 | Returns:
99 | similarity (float): Cosine similarity between the two images
100 | """
101 | # Load and preprocess images
102 | if isinstance(image_path1, str) and isinstance(image_path2, str):
103 | img1 = safe_open_image(image_path1)
104 | img2 = safe_open_image(image_path2)
105 | else:
106 | img1 = image_path1
107 | img2 = image_path2
108 |
109 | # Preprocess images and transfer to device
110 | img1 = preprocess(img1).unsqueeze(0).to(device)
111 | img2 = preprocess(img2).unsqueeze(0).to(device)
112 |
113 | # Extract features using CLIP
114 | with torch.no_grad():
115 | features1 = model.encode_image(img1)
116 | features2 = model.encode_image(img2)
117 |
118 | # Normalize features
119 | features1 = features1 / features1.norm(p=2, dim=-1, keepdim=True)
120 | features2 = features2 / features2.norm(p=2, dim=-1, keepdim=True)
121 |
122 | # Calculate cosine similarity
123 | similarity = torch.nn.functional.cosine_similarity(features1, features2)
124 |
125 | return similarity.item()
126 |
127 | def load_or_create_json(json_path):
128 | """
129 | Load or create JSON file
130 |
131 | Args:
132 | json_path (str): Path to JSON file
133 |
134 | Returns:
135 | dict: Loaded data or empty dictionary
136 | """
137 | if os.path.exists(json_path):
138 | with open(json_path, 'r') as f:
139 | try:
140 | data = json.load(f)
141 | return data
142 | except json.JSONDecodeError:
143 | print(f"JSON file {json_path} parsing error, creating new file")
144 | return {"results": []}
145 | else:
146 | return {"results": []}
147 |
148 | def save_json(data, json_path):
149 | """
150 | Save JSON data to file
151 |
152 | Args:
153 | data (dict): Data to save
154 | json_path (str): Path to JSON file
155 | """
156 | with open(json_path, 'w') as f:
157 | json.dump(data, f, indent=2)
158 | print(f"Results saved to {json_path}")
159 |
160 | def update_category_summary(data, category):
161 | """
162 | Update or create category summary item
163 |
164 | Args:
165 | data (dict): Dictionary containing results
166 | category (str): Dataset category name
167 |
168 | Returns:
169 | bool: Whether a new summary item was created
170 | """
171 | category_results = [item for item in data["results"] if item.get("Category") == category and item.get("id") != f"{category}_summary"]
172 |
173 | if not category_results:
174 | return False
175 |
176 | # Calculate average scores
177 | clip_scores = [item.get("clip_score", 0) for item in category_results if isinstance(item.get("clip_score", 0), (int, float))]
178 |
179 | avg_clip = np.mean(clip_scores) if clip_scores else 0
180 |
181 | # Check if summary item exists
182 | summary_exists = False
183 | for item in data["results"]:
184 | if item.get("id") == f"{category}_summary":
185 | item["clip_score"] = float(avg_clip)
186 | summary_exists = True
187 | break
188 |
189 | # If it doesn't exist, create new summary item
190 | if not summary_exists:
191 | data["results"].append({
192 | "id": f"{category}_summary",
193 | "Category": category,
194 | "clip_score": float(avg_clip)
195 | })
196 | return True
197 |
198 | return False
199 |
200 | def compare_model_datasets(source_folder1, source_folder2, output_dir="./results"):
201 | """
202 | Compare image similarity across multiple models and datasets
203 |
204 | Args:
205 | source_folder1 (str): Path to first source folder containing model folders
206 | source_folder2 (str): Path to second source folder containing dataset folders
207 | output_dir (str): Output directory for JSON files
208 | """
209 | # Ensure output directory exists
210 | os.makedirs(output_dir, exist_ok=True)
211 |
212 | # Model list
213 | models = ["claude", "gemini", "internvl2", "internvl3", "llava", "o1", "openai", "qwen", "o4mini","pro"]
214 | # Dataset list
215 | datasets = ["Code Refinement", "Image_to_code", "Interaction_Authoring", "Text_to_code"]
216 |
217 | # Loop through all models
218 | for model in models:
219 | model_folder = os.path.join(source_folder1, model)
220 | if not os.path.exists(model_folder):
221 | print(f"Model folder {model_folder} does not exist, skipping")
222 | continue
223 |
224 | json_path = os.path.join(output_dir, f"{model}_similarity.json")
225 |
226 | # Load or create JSON file
227 | data = load_or_create_json(json_path)
228 |
229 | # Get current maximum ID
230 | current_ids = [item.get("id", "") for item in data["results"] if not isinstance(item.get("id"), str) or not item.get("id", "").endswith("_summary")]
231 | current_max_id = max([int(i) for i in current_ids if str(i).isdigit()] + [0])
232 |
233 | # Loop through datasets
234 | for dataset in datasets:
235 | dataset_folder1 = os.path.join(model_folder, dataset)
236 | dataset_folder2 = os.path.join(source_folder2, dataset)
237 |
238 | if not os.path.exists(dataset_folder1) or not os.path.exists(dataset_folder2):
239 | print(f"Dataset folder {dataset_folder1} or {dataset_folder2} does not exist, skipping")
240 | continue
241 |
242 | # Get image filenames from both folders
243 | files1 = [f for f in os.listdir(dataset_folder1) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]
244 | files2 = [f for f in os.listdir(dataset_folder2) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]
245 |
246 | # Find common filenames between the two folders
247 | common_files = list(set(files1) & set(files2))
248 |
249 | if not common_files:
250 | print(f"Warning: No matching image files found in dataset {dataset} folders")
251 | continue
252 |
253 | processing_count = 0
254 |
255 | # Calculate similarity for each pair of images
256 | for filename in tqdm(common_files, desc=f"Processing image similarity for {model}/{dataset}"):
257 | img1_path = os.path.join(dataset_folder1, filename)
258 | img2_path = os.path.join(dataset_folder2, filename)
259 |
260 | try:
261 | # Calculate CLIP similarity
262 | clip_score = clip_similarity(img1_path, img2_path)
263 |
264 | # Add result
265 | current_max_id += 1
266 | data["results"].append({
267 | "id": current_max_id,
268 | "Category": dataset,
269 | "filename": filename,
270 | "clip_score": float(clip_score)
271 | })
272 |
273 | processing_count += 1
274 |
275 | # Save every 10 processed items
276 | if processing_count % 10 == 0:
277 | # Update category summary
278 | update_category_summary(data, dataset)
279 | save_json(data, json_path)
280 |
281 | except Exception as e:
282 | print(f"Error processing image {filename}: {str(e)}")
283 |
284 | # Update summary item for current dataset
285 | created_new = update_category_summary(data, dataset)
286 | if created_new:
287 | print(f"Created new summary item for {dataset}")
288 |
289 | # Save after processing each dataset
290 | save_json(data, json_path)
291 |
292 | if __name__ == "__main__":
293 | # Source folder 1, containing models and datasets
294 | # Can be replaced when running
295 | source_folder1 = "pro_imgs"
296 | # Source folder 2, containing only datasets
297 | source_folder2 = "label_imgs"
298 | # Output directory
299 | output_dir = "results"
300 |
301 | # Run comparison
302 | compare_model_datasets(source_folder1, source_folder2, output_dir)
303 |
--------------------------------------------------------------------------------
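
For a quick sanity check of clip_similarity outside the full folder comparison, two image paths are enough. Importing clip_score also loads the CLIP weights at module level, and the paths below are placeholders following the pro_imgs/label_imgs layout:

from clip_score import clip_similarity  # run from the calculate_similarity directory

score = clip_similarity("pro_imgs/qwen/Image_to_code/0001.png", "label_imgs/Image_to_code/0001.png")
print(f"CLIP cosine similarity: {score:.4f}")
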
/generate_response/claude_qa.py:
--------------------------------------------------------------------------------
1 | """
2 | Claude QA: A tool for evaluating Claude's performance on image-based question answering tasks.
3 |
4 | Usage:
5 | 1. Place your parquet datasets in the ./datasets folder
6 | 2. Configure your API key and other parameters in the main section
7 | 3. Run: python claude_qa.py
8 | 4. Results will be saved in the ./results/claude directory
9 |
10 | The script processes image-based QA datasets, sends queries to Claude API, and saves the responses.
11 | It has checkpoint functionality to resume from interrupted runs.
12 | """
13 |
14 | import os
15 | import json
16 | import base64
17 | import glob
18 | import pandas as pd
19 | from PIL import Image
20 | from io import BytesIO
21 | from anthropic import Anthropic
22 | import time
23 | import math
24 |
25 | def load_parquet_data(file_path):
26 | """Load data from a single Parquet file."""
27 | try:
28 | df = pd.read_parquet(file_path)
29 | return df.to_dict('records')
30 | except Exception as e:
31 | print(f"Failed to load file {file_path}: {e}")
32 | return []
33 |
34 | def load_existing_results(output_path):
35 | """Load existing results file to skip already processed data."""
36 | if os.path.exists(output_path):
37 | try:
38 | with open(output_path, 'r', encoding='utf-8') as f:
39 | return json.load(f)
40 | except Exception as e:
41 | print(f"Failed to load existing results file {output_path}: {e}")
42 | return []
43 |
44 | def save_results_to_json(results, output_path="output.json"):
45 | """Save results to a JSON file."""
46 | try:
47 | with open(output_path, 'w', encoding='utf-8') as f:
48 | json.dump(results, f, ensure_ascii=False, indent=4)
49 | print(f"Results saved to: {output_path}")
50 | except Exception as e:
51 | print(f"Failed to save results to {output_path}: {e}")
52 |
53 |
54 | def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
55 | """Process data list, generate model responses, and save results in real-time."""
56 | results = []
57 |
58 | # Create a set of processed question IDs to skip
59 | processed_ids = set()
60 | if existing_results:
61 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')}
62 | results = existing_results.copy() # Use existing results for this file
63 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")
64 |
65 | total_items = len(data_list)
66 | newly_processed_count = 0
67 |
68 | for i, item in enumerate(data_list):
69 | question_id = item.get('Question_id')
70 |
71 | # Skip already processed questions
72 | if question_id in processed_ids:
73 | print(f"Skipping already processed question ID: {question_id}")
74 | continue
75 |
76 | newly_processed_count += 1 # Count newly processed questions
77 |
78 | try:
79 |             # Base64-encoded image string from the dataset (decoded only when the JPEG fallback below is needed)
80 |             base64_image = item['Image']
81 |
82 | prompt = item['Prompt']
83 | question = item['Question']
84 | choices = item['Choices']
85 |
86 | # Build prompt text
87 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
88 |
89 | # Build image content object
90 | image_content = {
91 | "type": "image",
92 | "source": {
93 | "type": "base64",
94 | "media_type": "image/png",
95 | "data": base64_image
96 | }
97 | }
98 |
99 | # Construct API request
100 | try:
101 | response = client.messages.create(
102 | model=model_name,
103 | max_tokens=max_tokens,
104 | temperature=temperature,
105 | messages=[
106 | {
107 | "role": "user",
108 | "content": [
109 | {"type": "text", "text": prompt_text},
110 | image_content
111 | ]
112 | }
113 | ]
114 | )
115 |
116 | # Get response text
117 | model_response = response.content[0].text
118 |
119 | except Exception as api_error:
120 | error_msg = str(api_error)
121 |
122 | # If image size exceeds limit, convert to JPEG and retry
123 | if "image exceeds 5 MB maximum" in error_msg:
124 | print(f"Image exceeds size limit, converting to JPEG: {question_id}")
125 | try:
126 | # Convert to JPEG
127 | img_data = base64.b64decode(base64_image)
128 | img = Image.open(BytesIO(img_data))
129 |
130 |                         # Convert to RGB (JPEG supports neither alpha nor palette modes)
131 |                         if img.mode != 'RGB':
132 |                             img = img.convert('RGB')
133 |
134 | # Save as JPEG
135 | buffer = BytesIO()
136 | img.save(buffer, format='JPEG', quality=85)
137 | buffer.seek(0)
138 |
139 | # Convert to base64
140 | jpeg_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
141 |
142 | # Update image content
143 | image_content = {
144 | "type": "image",
145 | "source": {
146 | "type": "base64",
147 | "media_type": "image/jpeg",
148 | "data": jpeg_base64
149 | }
150 | }
151 |
152 | # Retry API call
153 | response = client.messages.create(
154 | model=model_name,
155 | max_tokens=max_tokens,
156 | temperature=temperature,
157 | messages=[
158 | {
159 | "role": "user",
160 | "content": [
161 | {"type": "text", "text": prompt_text},
162 | image_content
163 | ]
164 | }
165 | ]
166 | )
167 |
168 | model_response = response.content[0].text
169 | print(f"Successfully converted to JPEG and completed request: {question_id}")
170 |
171 | except Exception as jpeg_error:
172 | print(f"JPEG conversion failed: {jpeg_error}")
173 | model_response = f"Processing error: Image exceeds size limit and JPEG conversion failed - {jpeg_error}"
174 |
175 | else:
176 | print(f"API call error, attempting retry: {api_error}")
177 | # Simple retry mechanism
178 | time.sleep(2)
179 | try:
180 | response = client.messages.create(
181 | model=model_name,
182 | max_tokens=max_tokens,
183 | temperature=temperature,
184 | messages=[
185 | {
186 | "role": "user",
187 | "content": [
188 | {"type": "text", "text": prompt_text},
189 | image_content
190 | ]
191 | }
192 | ]
193 | )
194 | model_response = response.content[0].text
195 | except Exception as retry_error:
196 | raise Exception(f"Retry failed: {retry_error}")
197 |
198 | result = {
199 | "Question_id": question_id,
200 | "Response": model_response,
201 | "Answer": item.get('Answer'),
202 | "Category": item.get('Category'),
203 | "Png_id": item.get('Png_id')
204 | }
205 |
206 | results.append(result) # Add new result to the list
207 |
208 | # Save results every 5 *new* questions or at the end of the file
209 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1):
210 | if output_path:
211 | save_results_to_json(results, output_path)
212 | print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")
213 |
214 | except Exception as e:
215 | print(f"Error processing Question_id {question_id}: {e}")
216 | result = {
217 | "Question_id": question_id,
218 | "Response": f"Processing error: {e}", # Record error message
219 | "Answer": item.get('Answer'),
220 | "Category": item.get('Category'),
221 | "Png_id": item.get('Png_id')
222 | }
223 | results.append(result)
224 |
225 | # Save results when error occurs
226 | if output_path:
227 | save_results_to_json(results, output_path)
228 | print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")
229 |
230 | # Ensure final results are saved at the end of function if any new items were processed
231 | if newly_processed_count > 0 and output_path:
232 | save_results_to_json(results, output_path)
233 | print(f"File processing complete, final results saved to: {output_path}")
234 |
235 | return results # Return the complete results list for the current file
236 |
237 | def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
238 | """Main function to process specified parquet datasets and generate outputs in separate files."""
239 | # Ensure base output directory exists
240 | os.makedirs(output_base_dir, exist_ok=True)
241 |
242 | # Find all Parquet files
243 | all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
244 | files_to_process = []
245 |
246 | # Filter files by categories
247 | if categories:
248 | # Make sure category names don't include .parquet suffix
249 | category_basenames = {cat.replace(".parquet", "") for cat in categories}
250 | for file_path in all_parquet_files:
251 | basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
252 | if basename_no_ext in category_basenames:
253 | files_to_process.append(file_path)
254 | print(f"Will process specified category files: {files_to_process}")
255 | # Check for missing categories
256 | found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
257 | missing_categories = category_basenames - found_basenames
258 | if missing_categories:
259 | print(f"Warning: The following specified category files were not found: {missing_categories}")
260 |
261 | else:
262 | files_to_process = all_parquet_files
263 | print(f"Will process all .parquet files in the {data_folder} folder.")
264 |
265 | if not files_to_process:
266 | print("No Parquet files found to process.")
267 | return
268 |
269 | # Initialize Claude API client
270 | print("Initializing Claude API client...")
271 | try:
272 | client = Anthropic(api_key=api_key)
273 | print("Claude API client initialization complete.")
274 | except Exception as e:
275 | print(f"Error initializing Claude API client: {e}")
276 | return
277 |
278 | # Process each file
279 | for file_path in files_to_process:
280 | print("-" * 50)
281 | print(f"Starting to process file: {file_path}")
282 |
283 | # 1. Dynamically generate output filename from input filename
284 | base_name = os.path.basename(file_path) # e.g., "Multi-window_QA.parquet"
285 | dataset_name = os.path.splitext(base_name)[0] # e.g., "Multi-window_QA"
286 | # Use "claude_" prefix and dataset name to create JSON filename
287 | output_filename = f"claude_{dataset_name}.json"
288 | current_output_path = os.path.join(output_base_dir, output_filename) # Full output path
289 | print(f"Results will be saved to: {current_output_path}")
290 |
291 | # 2. Load existing results for current file
292 | existing_results_for_current_file = load_existing_results(current_output_path)
293 |
294 | # 3. Load Parquet data
295 | data_list = load_parquet_data(file_path)
296 |
297 | # 4. Process data
298 | if data_list:
299 | # Pass dynamically generated output path and loaded existing results to process_data
300 | process_data(
301 | data_list,
302 | client,
303 | model_name,
304 | temperature,
305 | max_tokens,
306 | existing_results=existing_results_for_current_file, # Pass existing results
307 | output_path=current_output_path # Pass output path
308 | )
309 | print(f"File {file_path} processing complete.")
310 | else:
311 | print(f"No data in file {file_path} or loading failed.")
312 |
313 | print("=" * 50)
314 | print("Processing of all specified files complete.")
315 |
316 | if __name__ == "__main__":
317 | data_folder = "./datasets" # Data folder path
318 | api_key = "your_key"
319 | model_name = "claude-3-7-sonnet-20250219"
320 |
321 | output_dir = "./results/claude" # Specify result directory
322 |
323 | categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
324 |
325 | max_tokens = 300 # Maximum tokens to generate
326 | temperature = 0 # Temperature parameter
327 |
328 | # Call main function
329 | main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
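330 |
331 | # Illustrative only: the record layout this script expects, inferred from the
332 | # fields read in process_data(). Building a one-row test dataset might look
333 | # like this (file names and values are hypothetical; to_parquet requires
334 | # pyarrow or fastparquet to be installed):
335 | #
336 | #   with open("page.png", "rb") as f:
337 | #       img_b64 = base64.b64encode(f.read()).decode("utf-8")
338 | #   pd.DataFrame([{
339 | #       "Question_id": "demo_0001",   # unique id, used for checkpoint skipping
340 | #       "Image": img_b64,             # base64-encoded PNG screenshot
341 | #       "Prompt": "Answer with the letter of the correct option only.",
342 | #       "Question": "What color is the page header?",
343 | #       "Choices": "A. Red  B. Blue  C. Green  D. White",
344 | #       "Answer": "B",
345 | #       "Category": "Real-world_QA",
346 | #       "Png_id": "demo_0001",
347 | #   }]).to_parquet("datasets/Real-world_QA.parquet")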
--------------------------------------------------------------------------------
/calculate_similarity/gemini_evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | # Gemini Image Comparison Evaluator
3 |
4 | This script evaluates the similarity between pairs of webpage images using Google's Gemini model.
5 | It compares model-generated images with ground truth images across multiple dimensions and saves
6 | the evaluation scores as JSON files.
7 |
8 | ## Usage:
9 | 1. Place model-generated images in the 'o4mini_imgs' directory
10 | 2. Place ground truth images in the 'label_imgs' directory
11 | 3. Add your Gemini API key(s) to the api_keys list
12 | 4. Run the script: python gemini_evaluate.py
13 |
14 | ## Directory Structure:
15 | - o4mini_imgs/[model]/[dataset]/*.png - Model generated images
16 | - label_imgs/[dataset]/*.png - Ground truth images
17 | - results/ - Output directory for evaluation results
18 | """
19 |
20 | from google import genai
21 | from google.genai import types
22 | import os
23 | import glob
24 | import json
25 | import re
26 | import time
27 | from pathlib import Path
28 | from datetime import datetime
29 |
30 | # API key configuration
31 | api_keys = [
32 | "your keys"
33 | ]
34 |
35 | current_api_key_index = 0
36 | api_key = api_keys[current_api_key_index]
37 | client = genai.Client(api_key=api_key)
38 |
39 | def switch_to_next_api_key():
40 | global current_api_key_index, client, api_key
41 | current_api_key_index = (current_api_key_index + 1) % len(api_keys)
42 | api_key = api_keys[current_api_key_index]
43 | client = genai.Client(api_key=api_key)
44 | print(f"Switching to new API key: {current_api_key_index + 1}/{len(api_keys)}")
45 | write_log(f"Switched to new API key index {current_api_key_index + 1}/{len(api_keys)}")
46 | return api_key
47 |
48 | # Source folders and output path
49 | source_folder1 = "o4mini_imgs" # Model generated images
50 | source_folder2 = "label_imgs" # Ground truth images
51 | json_save_path = "./results" # JSON save path
52 |
53 | os.makedirs(json_save_path, exist_ok=True)
54 |
55 | models = ["llava", "o1", "openai", "gemini", "internvl2", "internvl3", "qwen", "claude", "o4mini", "pro"]  # model output subfolders expected under source_folder1
56 | datasets = ["Code_Refinement", "Image_to_code", "Interaction_Authoring", "Text_to_code"]
57 |
58 | prompt = '''
59 | Your task is to assess two webpage images and output a score between 0 and 10 for each of the following 10 questions, reflecting the degree of similarity between the webpages. A score of 10 indicates perfect similarity (identical in every aspect), while a score of 0 indicates no similarity. For partial similarities, assign a score between 1 and 9, where higher scores reflect greater similarity. Only output a comma-separated list of 10 numbers enclosed in square brackets, e.g., [10,8,6,4,2,0,0,0,0,0]. Do not assign a score of 10 unless the two images are identical in the respective category.
60 |
61 | 1. **Element Reproduction (Score: 0-10):** Are key elements (text, images, buttons) fully present and styled identically to the original design? (e.g., 10 for identical elements, 5 for missing or slightly altered elements, 0 for completely different elements.)
62 | 2. **Proportion and Size Consistency (Score: 0-10):** Do the sizes and proportions of elements (text, images, buttons) match the original design, maintaining visual harmony? (e.g., 10 for exact proportions, 6 for minor size differences, 0 for significant discrepancies.)
63 | 3. **Layout and Typography Fidelity (Score: 0-10):** Does the overall layout (headers, footers, navigation bars, sidebars) faithfully replicate the original design's typography and structure? (e.g., 10 for identical layouts, 5 for similar but not exact placements, 0 for entirely different layouts.)
64 | 4. **Alignment and Spacing Accuracy (Score: 0-10):** Are elements aligned and spaced (margins, padding) as in the original design? (e.g., 10 for perfect alignment and spacing, 6 for minor misalignments, 0 for major misalignments.)
65 | 5. **Visual Hierarchy Clarity (Score: 0-10):** Does the webpage maintain the same visual hierarchy as the original, allowing users to quickly identify key information? (e.g., 10 for identical hierarchy, 5 for slightly altered emphasis, 0 for unclear or different hierarchy.)
66 | 6. **Color Consistency (Score: 0-10):** Do the overall color scheme, hues, and tones match the original design? (e.g., 10 for identical colors, 6 for similar palette with minor variations, 0 for completely different colors.)
67 | 7. **Style Consistency (Score: 0-10):** Does the webpage's overall aesthetic style (e.g., modern, minimalistic) align with the original design? (e.g., 10 for identical style, 4 for similar but distinguishable style, 0 for entirely different style.)
68 | 8. **Text Style Consistency (Score: 0-10):** Are text attributes (font type, size, line spacing, paragraph spacing, alignment) consistent with the original design? (e.g., 10 for identical text styles, 5 for similar fonts with spacing issues, 0 for completely different text styles.)
69 | 9. **Text Content Accuracy (Score: 0-10):** Does the webpage accurately reproduce the main textual content of the original design? (e.g., 10 for identical text, 5 for partial matches, 0 for entirely different text.)
70 | 10. **Overall Content Representation (Score: 0-10):** Does the webpage convey the same content information and intent as the original design? (e.g., 10 for identical content representation, 6 for similar but incomplete content, 0 for entirely different content.)
71 |
72 | **Output Format:** [score1,score2,score3,score4,score5,score6,score7,score8,score9,score10]
73 | '''
74 |
75 | log_file = os.path.join(json_save_path, "log.txt")
76 |
77 | def write_log(message):
78 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
79 | with open(log_file, "a", encoding="utf-8") as f:
80 | f.write(f"[{timestamp}] {message}\n")
81 |
82 | def update_category_summary(model, category, model_data):
83 | category_items = [item for item in model_data[model]["items"]
84 | if item.get("Category") == category and not item.get("id", "").endswith("_summary")]
85 |
86 | if not category_items:
87 | return
88 |
89 | all_scores = [item["scores"] for item in category_items]
90 |
91 | avg_scores = []
92 | for i in range(10):
93 | dimension_scores = [scores[i] for scores in all_scores]
94 | avg_scores.append(round(sum(dimension_scores) / len(dimension_scores), 2))
95 |
96 | summary_id = f"{category}_summary"
97 | summary_exists = False
98 |
99 | for i, item in enumerate(model_data[model]["items"]):
100 | if item.get("id") == summary_id:
101 | model_data[model]["items"][i]["scores"] = avg_scores
102 | summary_exists = True
103 | break
104 |
105 | if not summary_exists:
106 | summary_item = {
107 | "id": summary_id,
108 | "Category": category,
109 | "scores": avg_scores
110 | }
111 | model_data[model]["items"].append(summary_item)
112 |
113 | model_data = {}
114 | for model in models:
115 | json_file = os.path.join(json_save_path, f"{model}.json")
116 | if os.path.exists(json_file):
117 | with open(json_file, 'r', encoding='utf-8') as f:
118 | try:
119 | model_data[model] = json.load(f)
120 | except json.JSONDecodeError:
121 | model_data[model] = {"items": []}
122 | else:
123 | model_data[model] = {"items": []}
124 |
125 | total_pairs = 0
126 | processed_pairs = 0
127 | model_processed_counts = {model: 0 for model in models}
128 |
129 | try:
130 | for model in models:
131 | for dataset in datasets:
132 | dataset_path1 = os.path.join(source_folder1, model, dataset)
133 | dataset_path2 = os.path.join(source_folder2, dataset)
134 |
135 | if not os.path.exists(dataset_path1) or not os.path.exists(dataset_path2):
136 | message = f"Path does not exist: {dataset_path1} or {dataset_path2}"
137 | print(message)
138 | write_log(message)
139 | continue
140 |
141 | images1 = []
142 | for ext in ['*.png', '*.jpg', '*.jpeg']:
143 | images1.extend(glob.glob(os.path.join(dataset_path1, ext)))
144 |
145 | for img1_path in images1:
146 | img_name = os.path.basename(img1_path)
147 | img2_path = os.path.join(dataset_path2, img_name)
148 |
149 | if os.path.exists(img2_path):
150 | img_id = f"{dataset}_{Path(img_name).stem}"
151 | already_processed = False
152 |
153 | for item in model_data[model]["items"]:
154 | if item.get("id") == img_id:
155 | already_processed = True
156 | break
157 |
158 | if already_processed:
159 | print(f"Skipping already processed image: {model}/{dataset}/{img_name}")
160 | continue
161 |
162 | total_pairs += 1
163 |
164 | with open(img1_path, 'rb') as f1, open(img2_path, 'rb') as f2:
165 | img1_bytes = f1.read()
166 | img2_bytes = f2.read()
167 |
168 | success = False
169 | retry_count = 0
170 | max_retries = 2
171 | error_messages = []
172 |
173 | while not success and retry_count <= max_retries:
174 | try:
175 | if retry_count > 0:
176 | print(f"Retry #{retry_count} processing image: {model}/{dataset}/{img_name}")
177 |
178 | response = client.models.generate_content(
179 | model="gemini-2.5-flash-preview-04-17",
180 | contents=[
181 | prompt,
182 | types.Part.from_bytes(data=img1_bytes, mime_type='image/png'),
183 | types.Part.from_bytes(data=img2_bytes, mime_type='image/png')
184 | ]
185 | )
186 |
187 | response_text = response.text
188 | scores_match = re.search(r'\[(.*?)\]', response_text)
189 | if scores_match:
190 | scores_text = scores_match.group(1)
191 | scores = [float(s.strip()) for s in scores_text.split(',')]
192 | if len(scores) == 10:
193 | item_data = {
194 | "id": img_id,
195 | "Category": dataset,
196 | "scores": scores
197 | }
198 | model_data[model]["items"].append(item_data)
199 |
200 | update_category_summary(model, dataset, model_data)
201 |
202 | processed_pairs += 1
203 | model_processed_counts[model] += 1
204 | print(f"Processed image pair {processed_pairs}/{total_pairs}: {model}/{dataset}/{img_name}")
205 |
206 |                                     # Save after every processed pair (raise this interval to write less often)
207 |                                     json_file = os.path.join(json_save_path, f"{model}.json")
208 |                                     with open(json_file, 'w', encoding='utf-8') as f:
209 |                                         json.dump(model_data[model], f, indent=2, ensure_ascii=False)
210 |                                     print(f"Saved {model}.json (processed {model_processed_counts[model]} images)")
211 | success = True
212 | else:
213 | error_msg = f"Did not get 10 scores: {scores}"
214 | error_messages.append(error_msg)
215 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}")
216 | else:
217 | error_msg = f"Could not extract score list from response: {response_text}"
218 | error_messages.append(error_msg)
219 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}")
220 |
221 | except Exception as e:
222 | error_msg = f"Error processing image: {str(e)}"
223 | error_messages.append(error_msg)
224 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}")
225 |
226 | error_str = str(e)
227 | if "RESOURCE_EXHAUSTED" in error_str and "exceeded your current quota" in error_str:
228 | switch_to_next_api_key()
229 | continue
230 |
231 | retry_count += 1
232 | if not success and retry_count <= max_retries:
233 | time.sleep(2)
234 |
235 | if not success:
236 | error_log = f"Failed to process image (after {max_retries+1} attempts): {model}/{dataset}/{img_name}\n"
237 | for i, msg in enumerate(error_messages):
238 | error_log += f" Attempt {i+1} error: {msg}\n"
239 | write_log(error_log)
240 |
241 | except Exception as e:
242 |     message = f"Error occurred, saving current progress: {str(e)}"
243 | print(message)
244 | write_log(message)
245 | for model in models:
246 | if model in model_data:
247 | for dataset in datasets:
248 | update_category_summary(model, dataset, model_data)
249 |
250 | json_file = os.path.join(json_save_path, f"{model}.json")
251 | with open(json_file, 'w', encoding='utf-8') as f:
252 | json.dump(model_data[model], f, indent=2, ensure_ascii=False)
253 |
254 | for model in models:
255 | if model in model_data:
256 | for dataset in datasets:
257 | update_category_summary(model, dataset, model_data)
258 |
259 | json_file = os.path.join(json_save_path, f"{model}.json")
260 | with open(json_file, 'w', encoding='utf-8') as f:
261 | json.dump(model_data[model], f, indent=2, ensure_ascii=False)
262 |
263 | print(f"Complete! Processed {processed_pairs}/{total_pairs} image pairs.")
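264 |
265 | # Illustrative only: one way to reduce the saved per-image scores to a single
266 | # 0-10 average per model (reads the JSON files written above and skips the
267 | # "*_summary" entries appended by update_category_summary):
268 | #
269 | #   for model in models:
270 | #       path = os.path.join(json_save_path, f"{model}.json")
271 | #       if not os.path.exists(path):
272 | #           continue
273 | #       with open(path, 'r', encoding='utf-8') as f:
274 | #           items = json.load(f)["items"]
275 | #       scored = [it for it in items if not str(it.get("id", "")).endswith("_summary")]
276 | #       if scored:
277 | #           overall = sum(sum(it["scores"]) / 10 for it in scored) / len(scored)
278 | #           print(f"{model}: {overall:.2f}")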
--------------------------------------------------------------------------------
/generate_response/internvl2_5_code.py:
--------------------------------------------------------------------------------
1 | """
2 | InternVL2.5 Code Generation Tool
3 |
4 | Usage:
5 | python internvl2_5_code.py [--data_folder DATA_FOLDER] [--model_path MODEL_PATH]
6 | [--output_dir OUTPUT_DIR] [--categories CATEGORIES]
7 | [--max_tokens MAX_TOKENS] [--temperature TEMPERATURE]
8 | [--tensor_parallel_size TP_SIZE]
9 | """
10 |
11 | import os
12 | import json
13 | import base64
14 | import re
15 | import pandas as pd
16 | from io import BytesIO
17 | from PIL import Image
18 | from vllm import LLM, SamplingParams
19 |
20 | def load_parquet_data(file_path):
21 | try:
22 | df = pd.read_parquet(file_path)
23 | return df.to_dict('records')
24 | except Exception as e:
25 | print(f"Failed to load file {file_path}: {e}")
26 | return []
27 |
28 | def load_existing_results(output_path):
29 | if os.path.exists(output_path):
30 | try:
31 | with open(output_path, 'r', encoding='utf-8') as f:
32 | return json.load(f)
33 | except Exception as e:
34 | print(f"Failed to load existing results file {output_path}: {e}")
35 | return []
36 |
37 | def save_results_to_json(results, output_path="output.json"):
38 | try:
39 | with open(output_path, 'w', encoding='utf-8') as f:
40 | json.dump(results, f, ensure_ascii=False, indent=4)
41 | print(f"Results saved to: {output_path}")
42 | except Exception as e:
43 | print(f"Failed to save results to {output_path}: {e}")
44 |
45 | def extract_html_code(text):
46 | html_pattern = r'(?:<\!DOCTYPE\s+html>.*?<\/html>)'
47 | matches = re.findall(html_pattern, text, re.DOTALL | re.IGNORECASE)
48 | if matches:
49 | return matches[0]
50 |     if '<html' in text.lower() and '</html>' in text.lower():