├── asserts ├── fig1_teaser.png └── fig2_overview.png ├── generate_response ├── run_llava_qa.sh ├── run_qwen_code.sh ├── run_llava_code.sh ├── run_internvl3_qa.sh ├── run_qwen_qa.sh ├── run_internvl2_5_qa.sh ├── run_internvl3_code.sh ├── run_internvl2_5_code.sh ├── internvl2_5_qa.py ├── internvl3_qa.py ├── openai_qa.py ├── llava_qa.py ├── o1_qa.py ├── qwen_qa.py ├── gemini_qa.py ├── o4-mini_qa.py ├── geminipro_qa.py ├── claude_qa.py ├── internvl2_5_code.py ├── o1_code.py └── geminipro_code.py ├── requirements.txt ├── README.md └── calculate_similarity ├── render_img.py ├── clip_score.py └── gemini_evaluate.py /asserts/fig1_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikivishy/FullFront/HEAD/asserts/fig1_teaser.png -------------------------------------------------------------------------------- /asserts/fig2_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mikivishy/FullFront/HEAD/asserts/fig2_overview.png -------------------------------------------------------------------------------- /generate_response/run_llava_qa.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=llava_qa \ 8 | -c 128 \ 9 | --gres=gpu:8 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python llava_qa.py -------------------------------------------------------------------------------- /generate_response/run_qwen_code.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=qwen_code \ 8 | -c 42 \ 9 | --gres=gpu:4 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 
12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python qwen_code.py -------------------------------------------------------------------------------- /generate_response/run_llava_code.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=llava_code \ 8 | -c 42 \ 9 | --gres=gpu:4 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=auto \ 14 | python llava_code.py -------------------------------------------------------------------------------- /generate_response/run_internvl3_qa.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=internvl_3 \ 8 | -c 128 \ 9 | --gres=gpu:8 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python internvl3_qa.py -------------------------------------------------------------------------------- /generate_response/run_qwen_qa.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=qwen_qa \ 8 | -c 128 \ 9 | --gres=gpu:8 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python qwen_qa.py \ 15 | -- -------------------------------------------------------------------------------- /generate_response/run_internvl2_5_qa.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=intervl2_5_qa \ 8 | -c 128 \ 9 | --gres=gpu:8 \ 10 | --nodes=1 \ 11 | 
--ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python internvl2_5_qa.py -------------------------------------------------------------------------------- /generate_response/run_internvl3_code.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=internvl3_code \ 8 | -c 42 \ 9 | --gres=gpu:4 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python internvl3_code.py -------------------------------------------------------------------------------- /generate_response/run_internvl2_5_code.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_HCA=mlx5_0 3 | 4 | srun \ 5 | --partition=MoE \ 6 | --mpi=pmi2 \ 7 | --job-name=internvl2_5_code \ 8 | -c 42 \ 9 | --gres=gpu:4 \ 10 | --nodes=1 \ 11 | --ntasks-per-node=1 \ 12 | --kill-on-bad-exit=1 \ 13 | --quotatype=spot \ 14 | python internvl2_5_code.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # FullFront Project Dependencies 2 | # Model Response Generation 3 | anthropic>=0.8.0 4 | openai>=1.0.0 5 | google-generativeai>=0.3.0 6 | transformers>=4.35.0 7 | pillow>=9.0.0 8 | pandas>=2.0.0 9 | numpy>=1.24.0 10 | tqdm>=4.65.0 11 | requests>=2.30.0 12 | 13 | # HTML Rendering 14 | playwright>=1.40.0 15 | 16 | # Similarity Calculation 17 | torch>=2.0.0 18 | torchvision>=0.15.0 19 | clip-openai>=1.0 20 | opencv-python>=4.8.0 21 | sentence-transformers>=2.2.2 22 | bs4>=0.0.1 23 | beautifulsoup4>=4.12.0 24 | html-similarity>=0.3.3 25 | lxml>=4.9.0 26 | difflib>=0.1.0 27 | 28 | # Data Processing 29 | pyarrow>=14.0.0 30 | fastparquet>=2023.10.0 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FullFront 2 | 3 |
4 | 5 |
6 | 7 | FullFront is a comprehensive benchmark for evaluating Multimodal Large Language Models (MLLMs) across the entire front-end engineering workflow. This project provides code generation, page understanding, and evaluation tools to measure MLLMs' performance at various stages of front-end development. 8 | 9 |
10 | 11 |
12 | 13 | ## Project Overview 14 | 15 | The FullFront benchmark covers three core tasks in front-end engineering: 16 | 1. **Webpage Design** - Assessing the model's ability to organize and structure visual elements 17 | 2. **Webpage Perception QA** - Evaluating the model's understanding of visual organization, element characteristics, and spatial relationships 18 | 3. **Webpage Code Generation** - Focusing on the ability to accurately translate visual designs into functional code 19 | 20 | ## Key Features 21 | 22 | - Supports evaluation of multiple mainstream multimodal models (Claude, OpenAI, Gemini, etc.) 23 | - Provides a complete code generation and evaluation workflow 24 | - Includes image similarity and code quality assessment metrics 25 | - Automatically renders HTML into images for evaluation 26 | 27 | ## Installation 28 | 29 | 1. Clone this repository: 30 | ```bash 31 | git clone https://github.com/Mikivishy/FullFront.git 32 | cd FullFront 33 | ``` 34 | 35 | 2. Install dependencies: 36 | ```bash 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ## Usage Guide 41 | 42 | ### Generating Model Responses 43 | 44 | The `generate_response` directory contains scripts for generating responses from different models: 45 | 46 | 1. **Set API Keys**: Based on the model you're using, set the API key in the corresponding script. 47 | 48 | 2. **Run Generation Scripts**: 49 | ```bash 50 | cd generate_response 51 | python claude_code.py # Generate code using the Claude model 52 | python openai_code.py # Generate code using the OpenAI model 53 | python gemini_code.py # Generate code using the Gemini model 54 | ``` 55 | 56 | 3. **Use Shell Scripts for Batch Processing**: 57 | ```bash 58 | bash run_llava_code.sh # Run the LLaVA model code generation tasks 59 | bash run_qwen_qa.sh # Run the Qwen model QA tasks 60 | ``` 61 | 62 | The generated results will be saved in the `generate_response/results/{model_name}` directory. 
63 | 64 | ### Rendering HTML to Images 65 | 66 | Use `calculate_similarity/render_img.py` to render generated HTML into images: 67 | 68 | ```bash 69 | python calculate_similarity/render_img.py 70 | ``` 71 | 72 | You can modify the input and output directories in this script: 73 | ```python 74 | html_folder = "./path/to/your/html/files" 75 | screenshot_folder = "./path/to/save/screenshots" 76 | ``` 77 | 78 | ### Calculating Similarity Scores 79 | 80 | 1. **CLIP Similarity**: Evaluate semantic similarity between generated images and target images 81 | ```bash 82 | python calculate_similarity/clip_score.py 83 | ``` 84 | 85 | 2. **Code Similarity**: Evaluate structure and content similarity between generated code and standard code 86 | ```bash 87 | python calculate_similarity/code_score.py 88 | ``` 89 | 90 | 3. **Gemini Evaluation**: Use Gemini model to evaluate generated content 91 | ```bash 92 | python calculate_similarity/gemini_evaluate.py 93 | ``` 94 | 95 | ### Result Analysis 96 | 97 | Evaluation results will be saved in the `calculate_similarity/results/` directory, containing the following metrics: 98 | - CLIP similarity score 99 | - Code structure similarity 100 | - Code content similarity -------------------------------------------------------------------------------- /calculate_similarity/render_img.py: -------------------------------------------------------------------------------- 1 | """ 2 | HTML to Image Renderer 3 | 4 | Usage: 5 | 1. Run the script directly: python render_img.py 6 | 2. Or import the capture_screenshot function: 7 | from render_img import capture_screenshot 8 | capture_screenshot(html_file_path, screenshot_folder) 9 | 10 | This script renders HTML files to PNG images using Playwright. 11 | It waits for all images to load, handles lazy-loaded content, and 12 | captures full-page screenshots. 
# JavaScript run in the page: resolves once every not-yet-complete <img> has
# fired `load` or `error`, or after `timeoutMs` milliseconds, whichever comes
# first. The timeout matters because an image that never fires either event
# would otherwise keep the promise (and therefore page.evaluate) pending.
_WAIT_FOR_IMAGES_JS = """
(timeoutMs) => {
    return new Promise((resolve) => {
        const pending = Array.from(document.querySelectorAll('img'))
            .filter(img => !img.complete);

        // Nothing left to wait for
        if (pending.length === 0) {
            return resolve();
        }

        let remaining = pending.length;
        const settle = () => {
            remaining--;
            if (remaining === 0) {
                resolve();
            }
        };

        pending.forEach(img => {
            img.addEventListener('load', settle, { once: true });
            img.addEventListener('error', settle, { once: true });
        });

        // Safety net: never wait longer than timeoutMs
        setTimeout(resolve, timeoutMs);
    });
}
"""

# JavaScript run in the page: scrolls down in 100px steps every 100ms until
# the bottom (or 10000px) is reached, then scrolls back to the top.
_SCROLL_PAGE_JS = """
() => {
    return new Promise((resolve) => {
        let totalHeight = 0;
        const distance = 100;
        const timer = setInterval(() => {
            const scrollHeight = document.body.scrollHeight;
            window.scrollBy(0, distance);
            totalHeight += distance;

            // Stop if reached bottom or scrolled enough
            if (totalHeight >= scrollHeight || totalHeight > 10000) {
                clearInterval(timer);
                // Scroll back to top
                window.scrollTo(0, 0);
                setTimeout(resolve, 100); // Give time for page to stabilize
            }
        }, 100);
    });
}
"""


def capture_screenshot(html_file_path, screenshot_folder, image_timeout_ms=10000):
    """
    Capture a full-page PNG screenshot of a local HTML file.

    The page is loaded in headless Chromium, images are given up to
    ``image_timeout_ms`` to finish loading, the page is scrolled to trigger
    lazy-loaded content, images are awaited once more, and the screenshot is
    written to ``<screenshot_folder>/<html basename>.png``.

    Args:
        html_file_path: Path to the HTML file
        screenshot_folder: Path to save screenshots (created if missing)
        image_timeout_ms: Upper bound, in milliseconds, for each wait on
            pending images

    Returns:
        bool: True if successful, False otherwise (the error is printed,
        never raised)
    """
    # Local import: this block does not modify the module's import section.
    from pathlib import Path

    try:
        # Ensure HTML file exists
        if not os.path.exists(html_file_path):
            raise FileNotFoundError(f"HTML file not found: {html_file_path}")

        # as_uri() percent-encodes characters such as spaces, '#' and '?',
        # which a plain "file://" + path concatenation passes through raw.
        file_url = Path(os.path.abspath(html_file_path)).as_uri()

        # Generate screenshot filename
        html_file_name = os.path.splitext(os.path.basename(html_file_path))[0]
        screenshot_path = os.path.join(screenshot_folder, f"{html_file_name}.png")
        os.makedirs(screenshot_folder, exist_ok=True)

        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page()
                page.goto(file_url)
                page.wait_for_load_state("load")

                # Bounded wait for images present after the load event
                page.evaluate(_WAIT_FOR_IMAGES_JS, image_timeout_ms)
                # Scroll through the page to trigger lazy-loaded elements
                page.evaluate(_SCROLL_PAGE_JS)
                # Bounded wait for images whose loading the scroll started
                page.evaluate(_WAIT_FOR_IMAGES_JS, image_timeout_ms)

                page.screenshot(path=screenshot_path, full_page=True)
            finally:
                # Close the browser even when a step above raised
                browser.close()

        print(f"Screenshot saved to: {screenshot_path}")
        return True
    except Exception as e:
        print(f"Error processing file {html_file_path}: {str(e)}")
        return False


def _render_folder(html_folder, screenshot_folder):
    """Render every ``.html`` file in *html_folder* that has no PNG yet and print a summary."""
    os.makedirs(screenshot_folder, exist_ok=True)

    # All HTML files, processed in alphabetical order
    html_files = sorted(f for f in os.listdir(html_folder) if f.endswith(".html"))
    total = len(html_files)

    # Names (without extension) of screenshots that already exist
    existing_screenshots = {
        os.path.splitext(name)[0]
        for name in os.listdir(screenshot_folder)
        if name.endswith(".png")
    }

    failed_indices = []
    skipped_count = 0

    for index, html_file in enumerate(html_files):
        html_file_name = os.path.splitext(html_file)[0]

        # Skip if screenshot already exists
        if html_file_name in existing_screenshots:
            print(f"Screenshot exists, skipping [{index+1}/{total}]: {html_file}")
            skipped_count += 1
            continue

        print(f"Processing [{index+1}/{total}]: {html_file}")
        html_file_path = os.path.join(html_folder, html_file)
        if not capture_screenshot(html_file_path, screenshot_folder):
            failed_indices.append(index)

    # Print processing statistics
    print("\nProcessing Statistics:")
    print(f"Total files: {total}")
    print(f"Skipped: {skipped_count}")
    print(f"Newly processed: {total - skipped_count - len(failed_indices)}")

    # Print failed files
    if failed_indices:
        print("\nFailed files indices:")
        for idx in failed_indices:
            print(f"Index {idx}: {html_files[idx]}")
        print(f"Failed count: {len(failed_indices)}")
    else:
        print("\nAll files processed successfully!")


if __name__ == "__main__":
    start_time = time.time()
    # HTML files directory
    html_folder = "./input/html"  # Replace with your HTML files directory
    # Screenshots output directory
    screenshot_folder = "./output/images"  # Replace with your screenshots directory

    _render_folder(html_folder, screenshot_folder)

    end_time = time.time()
    print(f"Total processing time: {end_time - start_time} seconds")
def load_parquet_data(file_path):
    """Load a single Parquet file and return its rows as a list of dicts.

    Returns an empty list (after printing the error) if the file cannot be read.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load a previous results file so already processed questions can be skipped.

    Returns the stored list, or an empty list when the file is missing,
    unreadable, or does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results from {output_path}: {e}")
        return []
    # process_data iterates and copies this value as a list of result dicts.
    if not isinstance(loaded, list):
        print(f"Failed to load existing results from {output_path}: "
              f"expected a JSON list, got {type(loaded).__name__}")
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Save results to a JSON file without ever leaving a half-written file.

    The data is written to ``<output_path>.tmp`` and then moved over
    ``output_path`` with ``os.replace``, so an interrupted or failed write
    leaves the previous checkpoint intact. Errors are printed, not raised.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        # Best-effort cleanup of the partial temp file; it may not exist.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def create_prompt(prompt, question, choices):
    """Create a prompt for InternVL2_5.

    NOTE(review): the ``<image>`` placeholder is restored here on the
    assumption that it was stripped from the source text together with other
    angle-bracket tags; vLLM needs an image placeholder in the prompt when
    ``multi_modal_data`` carries an image - confirm against the upstream file.
    """
    return f"USER: <image>\n{prompt}\n'Question:'{question}\n'Choices:'{choices}\nASSISTANT:"


def _build_result(item, response):
    """Assemble the output record for one dataset row and its response text."""
    return {
        "Question_id": item.get('Question_id'),
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id')
    }


def process_data(data_list, llm, temperature, max_tokens, existing_results=None,
                 output_path=None, save_every=10):
    """Process data list, generate model responses, and save results in real-time.

    Args:
        data_list: Rows (dicts) with 'Question_id', 'Image' (base64 string),
            'Prompt', 'Question' and 'Choices' keys.
        llm: Object exposing ``generate(inputs, sampling_params=...)``.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Generation limit passed to ``SamplingParams``.
        existing_results: Previously saved records; rows whose 'Question_id'
            appears here are skipped.
        output_path: JSON checkpoint path; nothing is written when falsy.
        save_every: Checkpoint after this many newly processed rows.

    Returns:
        list: existing records followed by the newly produced ones. A row
        that raises is recorded with an "Error: ..." response.
    """
    results = list(existing_results) if existing_results else []

    # Compare against None explicitly: a truthiness test would drop valid
    # falsy IDs such as 0 and re-process (and duplicate) them on every run.
    processed_ids = {
        done.get('Question_id') for done in results
        if done.get('Question_id') is not None
    }
    if existing_results:
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these.")

    total_items = len(data_list)
    pending = [
        (i, item) for i, item in enumerate(data_list)
        if item.get('Question_id') not in processed_ids
    ]
    if not pending:
        # Nothing new: no sampling parameters or model call needed.
        return results

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    for newly_processed_count, (i, item) in enumerate(pending, start=1):
        question_id = item.get('Question_id')
        try:
            # Decode the base64 payload into a PIL image
            image = Image.open(BytesIO(base64.b64decode(item['Image'])))
            internvl_prompt = create_prompt(item['Prompt'], item['Question'], item['Choices'])

            # Single image inference
            inputs = {
                "prompt": internvl_prompt,
                "multi_modal_data": {
                    "image": image
                },
            }
            outputs = llm.generate([inputs], sampling_params=sampling_params)
            response = outputs[0].outputs[0].text.strip()
        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, f"Error: {e}"))
            # Save results when an error occurs
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")
            continue

        results.append(_build_result(item, response))

        # Periodic checkpoint; the final save below covers the remainder.
        if output_path and newly_processed_count % save_every == 0:
            save_results_to_json(results, output_path)
            print(f"Processed {i + 1}/{total_items} questions (new: {newly_processed_count}), saving results to {output_path}.")

    if output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def _strip_parquet_suffix(name):
    """Return *name* without a trailing ``.parquet`` extension, if present."""
    suffix = ".parquet"
    return name[:-len(suffix)] if name.endswith(suffix) else name


def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
    """Process Parquet datasets and generate responses, one output file per dataset.

    Args:
        data_folder: Folder searched for ``*.parquet`` files.
        model_path: Model path handed to ``LLM``.
        output_base_dir: Folder for ``internvl2_5_<dataset>.json`` files.
        categories: Dataset names (with or without ``.parquet``) to process;
            all files are processed when falsy.
        max_tokens: Generation limit per response.
        temperature: Sampling temperature.
        tensor_parallel_size: Passed to ``LLM`` unchanged.
    """
    # Ensure output directory exists
    os.makedirs(output_base_dir, exist_ok=True)

    # Sorted so datasets are processed in a stable order
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))

    if categories:
        category_basenames = {_strip_parquet_suffix(cat) for cat in categories}
        files_to_process = [
            path for path in all_parquet_files
            if os.path.splitext(os.path.basename(path))[0] in category_basenames
        ]
        print(f"Will process specified category files: {files_to_process}")
        # Report requested categories that have no matching file
        found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
        missing_categories = category_basenames - found_basenames
        if missing_categories:
            print(f"Warning: The following specified category files were not found: {missing_categories}")
    else:
        files_to_process = all_parquet_files
        print(f"Will process all .parquet files in {data_folder} folder.")

    if not files_to_process:
        print("No Parquet files found to process.")
        return

    # --- Model Loading ---
    print("Loading model...")
    try:
        llm = LLM(
            model=model_path,
            max_model_len=32768,
            tensor_parallel_size=tensor_parallel_size,
            enforce_eager=True
        )
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # --- Process each file ---
    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # Output filename is derived from the input filename,
        # e.g. "Multi-window_QA.parquet" -> "internvl2_5_Multi-window_QA.json"
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"internvl2_5_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                llm,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing completed.")
        else:
            print(f"No data in file {file_path} or loading failed.")

    print("=" * 50)
    print("Processing of all specified files completed.")


if __name__ == "__main__":
    # Relative paths
    data_folder = "data"  # Data folder path
    model_path = "models/InternVL2_5-78B"  # Model weights path
    output_dir = "results"  # Output directory for results

    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]  # Specify dataset categories to process
    # categories = None  # Set to None to process all parquet files

    max_tokens = 1024  # Maximum tokens to generate
    temperature = 0  # Temperature parameter
    tensor_parallel_size = 8  # Use 8 GPUs in parallel

    # Call main function with output directory
    main(data_folder, model_path, output_dir, categories, max_tokens, temperature, tensor_parallel_size)
def load_parquet_data(file_path):
    """Load a single Parquet file and return its rows as a list of dicts.

    Returns an empty list (after printing the error) if the file cannot be read.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load a previous results file so already processed questions can be skipped.

    Returns the stored list, or an empty list when the file is missing,
    unreadable, or does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results file {output_path}: {e}")
        return []
    # process_data iterates and copies this value as a list of result dicts.
    if not isinstance(loaded, list):
        print(f"Failed to load existing results file {output_path}: "
              f"expected a JSON list, got {type(loaded).__name__}")
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Save results to a JSON file without ever leaving a half-written file.

    The data is written to ``<output_path>.tmp`` and then moved over
    ``output_path`` with ``os.replace``, so an interrupted or failed write
    leaves the previous checkpoint intact. Errors are printed, not raised.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        # Best-effort cleanup of the partial temp file; it may not exist.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def create_prompt(prompt, question, choices):
    """Create prompt for InternVL3.

    NOTE(review): the ``<image>`` placeholder is restored here on the
    assumption that it was stripped from the source text together with other
    angle-bracket tags; vLLM needs an image placeholder in the prompt when
    ``multi_modal_data`` carries an image - confirm against the upstream file.
    """
    return f"USER: <image>\n{prompt}\n'Question:'{question}\n'Choices:'{choices}\nASSISTANT:"


def _build_result(item, response):
    """Assemble the output record for one dataset row and its response text."""
    return {
        "Question_id": item.get('Question_id'),
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id')
    }


def process_data(data_list, llm, temperature, max_tokens, existing_results=None,
                 output_path=None, save_every=10):
    """Process data list, generate model responses, and save results in real-time.

    Args:
        data_list: Rows (dicts) with 'Question_id', 'Image' (base64 string),
            'Prompt', 'Question' and 'Choices' keys.
        llm: Object exposing ``generate(inputs, sampling_params=...)``.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Generation limit passed to ``SamplingParams``.
        existing_results: Previously saved records; rows whose 'Question_id'
            appears here are skipped.
        output_path: JSON checkpoint path; nothing is written when falsy.
        save_every: Checkpoint after this many newly processed rows.

    Returns:
        list: existing records followed by the newly produced ones. A row
        that raises is recorded with a "Processing error: ..." response.
    """
    results = list(existing_results) if existing_results else []

    # Compare against None explicitly: a truthiness test would drop valid
    # falsy IDs such as 0 and re-process (and duplicate) them on every run.
    processed_ids = {
        done.get('Question_id') for done in results
        if done.get('Question_id') is not None
    }
    if existing_results:
        print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip these.")

    total_items = len(data_list)
    pending = [
        (i, item) for i, item in enumerate(data_list)
        if item.get('Question_id') not in processed_ids
    ]
    if not pending:
        # Nothing new: no sampling parameters or model call needed.
        return results

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    for newly_processed_count, (i, item) in enumerate(pending, start=1):
        question_id = item.get('Question_id')
        try:
            # Decode the base64 payload into a PIL image
            image = Image.open(BytesIO(base64.b64decode(item['Image'])))
            internvl_prompt = create_prompt(item['Prompt'], item['Question'], item['Choices'])

            # Single image inference
            inputs = {
                "prompt": internvl_prompt,
                "multi_modal_data": {
                    "image": image
                },
            }
            outputs = llm.generate([inputs], sampling_params=sampling_params)
            response = outputs[0].outputs[0].text.strip()
        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, f"Processing error: {e}"))
            # Save results when an error occurs
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")
            continue

        results.append(_build_result(item, response))

        # Periodic checkpoint; the final save below covers the remainder.
        if output_path and newly_processed_count % save_every == 0:
            save_results_to_json(results, output_path)
            print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved to {output_path}.")

    if output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def _strip_parquet_suffix(name):
    """Return *name* without a trailing ``.parquet`` extension, if present."""
    suffix = ".parquet"
    return name[:-len(suffix)] if name.endswith(suffix) else name


def main(data_folder="./data", model_path="./model", output_base_dir="./results",
         categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
    """Process Parquet datasets and generate responses, one output file per dataset.

    Args:
        data_folder: Folder searched for ``*.parquet`` files.
        model_path: Model path handed to ``LLM``.
        output_base_dir: Folder for ``internvl3_<dataset>.json`` files.
        categories: Dataset names (with or without ``.parquet``) to process;
            all files are processed when falsy.
        max_tokens: Generation limit per response.
        temperature: Sampling temperature.
        tensor_parallel_size: Passed to ``LLM`` unchanged.
    """
    # Ensure output directory exists
    os.makedirs(output_base_dir, exist_ok=True)

    # Sorted so datasets are processed in a stable order
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))

    if categories:
        category_basenames = {_strip_parquet_suffix(cat) for cat in categories}
        files_to_process = [
            path for path in all_parquet_files
            if os.path.splitext(os.path.basename(path))[0] in category_basenames
        ]
        print(f"Will process specific category files: {files_to_process}")
        # Report requested categories that have no matching file
        found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
        missing_categories = category_basenames - found_basenames
        if missing_categories:
            print(f"Warning: The following specified categories were not found: {missing_categories}")
    else:
        files_to_process = all_parquet_files
        print(f"Will process all .parquet files in {data_folder}.")

    if not files_to_process:
        print("No Parquet files found to process.")
        return

    # Load model
    print("Loading model...")
    try:
        llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=32768,
            enforce_eager=True
        )
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Process each file
    for file_path in files_to_process:
        print("-" * 50)
        print(f"Processing file: {file_path}")

        # Output filename is derived from the input filename
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"internvl3_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                llm,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing complete.")
        else:
            print(f"No data in file {file_path} or loading failed.")

    print("=" * 50)
    print("All specified files have been processed.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="InternVL3 QA Inference")
    parser.add_argument("--data_folder", default="./data", help="Path to data folder containing parquet files")
    parser.add_argument("--model_path", default="./model", help="Path to model weights")
    parser.add_argument("--output_dir", default="./results", help="Directory to save results")
    parser.add_argument("--categories", nargs="+", default=["Real-world_QA", "Synthetic_QA", "Multi-window_QA"],
                        help="Categories to process (parquet file names without extension)")
    parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=0, help="Temperature parameter")
    parser.add_argument("--tensor_parallel_size", type=int, default=8, help="Number of GPUs for tensor parallelism")

    args = parser.parse_args()

    main(
        args.data_folder,
        args.model_path,
        args.output_dir,
        args.categories,
        args.max_tokens,
        args.temperature,
        args.tensor_parallel_size
    )
import glob
import json
import os
import time

import pandas as pd


def load_parquet_data(file_path):
    """Read a Parquet file and return its rows as a list of dicts.

    Returns an empty list (after printing the error) if the file cannot be read.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load results written by an earlier run so those questions can be skipped.

    Returns an empty list when the file is missing, unreadable, or does not
    contain a JSON list (process_data copies and appends to the value).
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results from {output_path}: {e}")
        return []
    if not isinstance(loaded, list):
        print(f"Failed to load existing results from {output_path}: expected a JSON list")
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Write results to output_path as indented UTF-8 JSON.

    The data is written to a sibling ".tmp" file and then moved into place, so
    an interruption mid-write cannot leave a truncated file that would break
    resuming. Failures are printed, not raised.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")


def _build_result(item, question_id, response):
    """Assemble the record stored for one question."""
    return {
        "Question_id": question_id,
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id')
    }


def _request_answer(client, model_name, prompt_text, base64_image,
                    temperature, max_tokens, retry_delay=2):
    """Send one text+image chat request and return the reply text.

    The request is retried once after retry_delay seconds; if the retry also
    fails, an Exception("Retry failed: ...") is raised.
    """
    def send():
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content

    try:
        return send()
    except Exception as api_error:
        print(f"API call error, attempting retry: {api_error}")
        time.sleep(retry_delay)
        try:
            return send()
        except Exception as retry_error:
            raise Exception(f"Retry failed: {retry_error}") from retry_error


def process_data(data_list, client, model_name, temperature, max_tokens,
                 existing_results=None, output_path=None, retry_delay=2):
    """Query the model for every unanswered question and return all results.

    Args:
        data_list: rows with 'Question_id', 'Image' (base64 text), 'Prompt',
            'Question', 'Choices' and optionally 'Answer', 'Category', 'Png_id'.
        client: object exposing chat.completions.create(...).
        model_name: model identifier passed to the API.
        temperature: sampling temperature passed to the API.
        max_tokens: generation limit passed to the API.
        existing_results: results from a previous run; their ids are skipped.
        output_path: JSON file updated every 5 new answers, after every error
            and at the end; nothing is written when it is None.
        retry_delay: seconds to wait before the single retry of a failed call.

    Returns:
        The existing results followed by one record per newly handled question.
        A question that fails is recorded with a "Processing error: ..." response.
    """
    results = []

    processed_ids = set()
    if existing_results:
        # Compare against None so falsy ids such as 0 still count as processed.
        processed_ids = {
            item.get('Question_id') for item in existing_results
            if item.get('Question_id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip them.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            print(f"Skipping already processed question ID: {question_id}")
            continue

        newly_processed_count += 1

        try:
            base64_image = item['Image']
            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"

            model_response = _request_answer(
                client, model_name, prompt_text, base64_image,
                temperature, max_tokens, retry_delay=retry_delay
            )
            results.append(_build_result(item, question_id, model_response))

            # Checkpoint every 5 new questions or at the last item.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), results saved to {output_path}.")

        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, question_id, f"Processing error: {e}"))

            # Checkpoint immediately so the failure is on disk.
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")

    # Final save covers runs whose last row was skipped.
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
    """Run the model over the selected Parquet datasets, one output JSON per dataset.

    Args:
        data_folder: directory searched for "*.parquet" files.
        api_key: OpenAI API key.
        model_name: model identifier passed to the API.
        output_base_dir: directory for "openai_<dataset>.json" files (created if missing).
        categories: dataset names (with or without ".parquet") to process;
            all files in data_folder are processed when empty or None.
        max_tokens: generation limit per answer.
        temperature: sampling temperature.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    # Sorted so datasets are processed in the same order on every run.
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))
    files_to_process = []

    if categories:
        category_basenames = {cat.replace(".parquet", "") for cat in categories}
        for file_path in all_parquet_files:
            basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
            if basename_no_ext in category_basenames:
                files_to_process.append(file_path)
        print(f"Will process specified category files: {files_to_process}")

        found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
        missing_categories = category_basenames - found_basenames
        if missing_categories:
            print(f"Warning: The following specified categories were not found: {missing_categories}")
    else:
        files_to_process = all_parquet_files
        print(f"Will process all .parquet files in the {data_folder} folder.")

    if not files_to_process:
        print("No parquet files found to process.")
        return

    print("Initializing OpenAI API client...")
    try:
        # Imported here so the data helpers above stay importable on machines
        # that do not have the OpenAI SDK installed.
        from openai import OpenAI
        client = OpenAI(api_key=api_key)
        print("OpenAI API client initialized successfully.")
    except Exception as e:
        print(f"Error initializing OpenAI API client: {e}")
        return

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # Output name is derived from the dataset file name.
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"openai_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                client,
                model_name,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing completed.")
        else:
            print(f"No data found or failed to load file {file_path}.")

    print("=" * 50)
    print("All specified files have been processed.")


if __name__ == "__main__":
    data_folder = "./datasets"
    # Prefer the environment so a real key never has to be committed.
    api_key = os.environ.get("OPENAI_API_KEY", "your_api_key")
    model_name = "gpt-4o-2024-11-20"

    output_dir = "./results/openai"

    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]

    max_tokens = 300
    temperature = 0

    main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
# File: /generate_response/llava_qa.py
"""
LLaVA QA Processor

This script processes vision-language question answering tasks using the LLaVA model.
It handles multiple parquet datasets, generating responses for image-based questions.

Usage:
    python llava_qa.py [--data_folder DATA_PATH] [--model_path MODEL_PATH]
                       [--output_dir OUTPUT_PATH] [--categories CATEGORY1 CATEGORY2 ...]
                       [--max_tokens MAX_TOKENS] [--temperature TEMP] [--tensor_parallel_size GPUS]

Example:
    python llava_qa.py --data_folder "./datasets" --model_path "./models/llava-model"
                       --output_dir "./results" --categories "Real-world_QA" "Synthetic_QA"
"""

import os
import json
import base64
import glob
import pandas as pd
from PIL import Image
from io import BytesIO
from vllm import LLM, SamplingParams
import torch


def load_parquet_data(file_path):
    """Load one Parquet file and return its rows as a list of dicts ([] on failure)."""
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load results saved by an earlier run; returns [] if missing or unreadable."""
    if os.path.exists(output_path):
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Failed to load existing results file {output_path}: {e}")
    return []


def save_results_to_json(results, output_path="output.json"):
    """Write results to output_path as indented UTF-8 JSON.

    The data goes to a sibling ".tmp" file first and is then moved into place,
    so an interrupted write cannot truncate the file used for resuming.
    Failures are printed, not raised.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")


def create_prompt(prompt, question, choices):
    """Build the chat-formatted prompt for one question.

    The "<image>" placeholder marks where the picture passed through
    "multi_modal_data" is inserted; the prompt needs one placeholder per image.
    """
    return (
        f"<|im_start|>user <image>\n{prompt}\n'Question:'{question}\n'Choices:'{choices}"
        "<|im_end|> <|im_start|>assistant\n"
    )


def _build_result(item, question_id, response):
    """Assemble the record stored for one question."""
    return {
        "Question_id": question_id,
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id')
    }


def process_data(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None):
    """Generate a response for every unanswered question and return all results.

    Args:
        data_list: rows with 'Question_id', 'Image' (base64), 'Prompt',
            'Question', 'Choices' and optionally 'Answer', 'Category', 'Png_id'.
        llm: vLLM engine used for generation.
        temperature: sampling temperature.
        max_tokens: generation limit per answer.
        existing_results: results from a previous run; their ids are skipped.
        output_path: JSON file updated every 10 new answers, after every error
            and at the end; nothing is written when it is None.

    Returns:
        The existing results followed by one record per newly handled question.
        A question that fails is recorded with a "Processing error: ..." response.
    """
    results = []
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    processed_ids = set()
    if existing_results:
        # Compare against None so falsy ids such as 0 still count as processed.
        processed_ids = {
            item.get('Question_id') for item in existing_results
            if item.get('Question_id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            continue

        newly_processed_count += 1

        try:
            # Decode the base64 payload into a PIL image.
            image = Image.open(BytesIO(base64.b64decode(item['Image'])))

            llava_prompt = create_prompt(item['Prompt'], item['Question'], item['Choices'])

            # Single-image request.
            inputs = {
                "prompt": llava_prompt,
                "multi_modal_data": {
                    "image": image
                },
            }

            outputs = llm.generate([inputs], sampling_params=sampling_params)
            response = outputs[0].outputs[0].text.strip()

            results.append(_build_result(item, question_id, response))

            # Checkpoint every 10 new questions or at the last item.
            if newly_processed_count % 10 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")

        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, question_id, f"Processing error: {e}"))

            # Checkpoint immediately so the failure is on disk.
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")

    # Final save covers runs whose last row was skipped.
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"Completed processing file, final results saved to: {output_path}")

    return results


def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
    """Run the model over the selected Parquet datasets, one output JSON per dataset.

    Args:
        data_folder: directory searched for "*.parquet" files.
        model_path: model weights passed to vLLM.
        output_base_dir: directory for "llava_<dataset>.json" files (created if missing).
        categories: dataset names (with or without ".parquet") to process;
            all files in data_folder are processed when empty or None.
        max_tokens: generation limit per answer.
        temperature: sampling temperature.
        tensor_parallel_size: number of GPUs used for tensor parallelism.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))
    files_to_process = []

    if categories:
        # Category names are matched without the .parquet suffix.
        category_basenames = {cat.replace(".parquet", "") for cat in categories}
        for file_path in all_parquet_files:
            basename_no_ext = os.path.splitext(os.path.basename(file_path))[0]
            if basename_no_ext in category_basenames:
                files_to_process.append(file_path)
        print(f"Will process specified category files: {files_to_process}")

        found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
        missing_categories = category_basenames - found_basenames
        if missing_categories:
            print(f"Warning: The following specified category files were not found: {missing_categories}")
    else:
        files_to_process = all_parquet_files
        print(f"Will process all .parquet files in folder {data_folder}.")

    if not files_to_process:
        print("No Parquet files found to process.")
        return

    print("Loading model...")
    try:
        llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=32768,
            limit_mm_per_prompt={"image": 1, "video": 0},
            enforce_eager=True
        )
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # "Multi-window_QA.parquet" -> "llava_Multi-window_QA.json"
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"llava_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                llm,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing completed.")
        else:
            print(f"No data in file {file_path} or loading failed.")

    print("=" * 50)
    print("Processing of all specified files completed.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Process vision-language QA tasks with LLaVA')
    parser.add_argument('--data_folder', default='./datasets', help='Path to data folder')
    parser.add_argument('--model_path', default='./models/llava-model', help='Path to model weights')
    parser.add_argument('--output_dir', default='./results', help='Directory to save results')
    parser.add_argument('--categories', nargs='+', default=["Real-world_QA", "Synthetic_QA", "Multi-window_QA"],
                        help='Categories to process (default: all)')
    parser.add_argument('--max_tokens', type=int, default=1024, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0, help='Temperature parameter')
    parser.add_argument('--tensor_parallel_size', type=int, default=8, help='Number of GPUs for parallel inference')

    args = parser.parse_args()

    main(args.data_folder, args.model_path, args.output_dir,
         args.categories, args.max_tokens, args.temperature, args.tensor_parallel_size)
parser.add_argument('--tensor_parallel_size', type=int, default=8, help='Number of GPUs for parallel inference') 249 | 250 | args = parser.parse_args() 251 | 252 | main(args.data_folder, args.model_path, args.output_dir, 253 | args.categories, args.max_tokens, args.temperature, args.tensor_parallel_size) -------------------------------------------------------------------------------- /generate_response/o1_qa.py: -------------------------------------------------------------------------------- 1 | """ 2 | O1_QA.py - OpenAI Vision Model Evaluation Tool 3 | 4 | Usage: 5 | 1. Place your Parquet datasets in the './mini_datasets' folder 6 | 2. Set your OpenAI API key in the main function 7 | 3. Configure model name, output directory and categories as needed 8 | 4. Run the script: python o1_qa.py 9 | 10 | This script processes image question-answering datasets in Parquet format, 11 | queries the OpenAI vision model, and saves results to JSON files. 12 | """ 13 | 14 | import os 15 | import json 16 | import base64 17 | import glob 18 | import pandas as pd 19 | from PIL import Image 20 | from io import BytesIO 21 | from openai import OpenAI 22 | import time 23 | 24 | def load_parquet_data(file_path): 25 | """Load data from a single Parquet file.""" 26 | try: 27 | df = pd.read_parquet(file_path) 28 | return df.to_dict('records') 29 | except Exception as e: 30 | print(f"Failed to load file {file_path}: {e}") 31 | return [] 32 | 33 | def load_existing_results(output_path): 34 | """Load existing results file to skip already processed data.""" 35 | if os.path.exists(output_path): 36 | try: 37 | with open(output_path, 'r', encoding='utf-8') as f: 38 | return json.load(f) 39 | except Exception as e: 40 | print(f"Failed to load existing results file {output_path}: {e}") 41 | return [] 42 | 43 | def save_results_to_json(results, output_path="output.json"): 44 | """Save results to a JSON file.""" 45 | try: 46 | with open(output_path, 'w', encoding='utf-8') as f: 47 | json.dump(results, 
def _ask_model(client, model_name, prompt_text, base64_image, retry_delay=2):
    """Send one text+image chat request and return the reply text.

    The request is retried once after retry_delay seconds; if the retry also
    fails, an Exception("Retry failed: ...") is raised.
    """
    def send():
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ]
        )
        return response.choices[0].message.content

    try:
        return send()
    except Exception as api_error:
        print(f"API call error, trying to retry: {api_error}")
        time.sleep(retry_delay)
        try:
            return send()
        except Exception as retry_error:
            raise Exception(f"Retry failed: {retry_error}") from retry_error


def process_data(data_list, client, model_name, existing_results=None, output_path=None, retry_delay=2):
    """Query the model for every unanswered question and return all results.

    Args:
        data_list: rows with 'Question_id', 'Image' (base64 text), 'Prompt',
            'Question', 'Choices' and optionally 'Answer', 'Category', 'Png_id'.
        client: object exposing chat.completions.create(...).
        model_name: model identifier passed to the API.
        existing_results: results from a previous run; their ids are skipped.
        output_path: JSON file updated every 5 new answers, after every error
            and at the end; nothing is written when it is None.
        retry_delay: seconds to wait before the single retry of a failed call.

    Returns:
        The existing results followed by one record per newly handled question.
        A question that fails is recorded with a "Processing error: ..." response.
    """
    def build_result(item, question_id, response):
        return {
            "Question_id": question_id,
            "Response": response,
            "Answer": item.get('Answer'),
            "Category": item.get('Category'),
            "Png_id": item.get('Png_id')
        }

    results = []

    processed_ids = set()
    if existing_results:
        # Compare against None so falsy ids such as 0 still count as processed.
        processed_ids = {
            item.get('Question_id') for item in existing_results
            if item.get('Question_id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip them.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            print(f"Skipping already processed question ID: {question_id}")
            continue

        newly_processed_count += 1

        try:
            base64_image = item['Image']
            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"

            model_response = _ask_model(
                client, model_name, prompt_text, base64_image, retry_delay=retry_delay
            )
            results.append(build_result(item, question_id, model_response))

            # Checkpoint every 5 new questions or at the last item.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")

        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(build_result(item, question_id, f"Processing error: {e}"))

            # Checkpoint immediately so the failure is on disk.
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")

    # Final save covers runs whose last row was skipped.
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"Completed processing file, final results saved to: {output_path}")

    return results
{file_path}") 227 | 228 | # 1. Generate output filename from input filename 229 | base_name = os.path.basename(file_path) 230 | dataset_name = os.path.splitext(base_name)[0] 231 | output_filename = f"o1_{dataset_name}.json" 232 | current_output_path = os.path.join(output_base_dir, output_filename) 233 | print(f"Results will be saved to: {current_output_path}") 234 | 235 | # 2. Load existing results for current file 236 | existing_results_for_current_file = load_existing_results(current_output_path) 237 | 238 | # 3. Load Parquet data 239 | data_list = load_parquet_data(file_path) 240 | 241 | # 4. Process data 242 | if data_list: 243 | process_data( 244 | data_list, 245 | client, 246 | model_name, 247 | existing_results=existing_results_for_current_file, 248 | output_path=current_output_path 249 | ) 250 | print(f"File {file_path} processing completed.") 251 | else: 252 | print(f"No data in file {file_path} or loading failed.") 253 | 254 | print("=" * 50) 255 | print("Processing of all specified files completed.") 256 | 257 | if __name__ == "__main__": 258 | data_folder = "./mini_datasets" # Data folder path 259 | api_key = "your_api_key" # OpenAI API key 260 | model_name = "o1" # OpenAI model name 261 | 262 | output_dir = "./results/o1" # Output directory for results 263 | 264 | categories = ["Real-world_QA_mini", "Synthetic_QA_mini", "Multi-window_QA_mini"] 265 | 266 | # Call the main function 267 | main(data_folder, api_key, model_name, output_dir, categories) -------------------------------------------------------------------------------- /generate_response/qwen_qa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Qwen Vision-Language Model Question Answering Script 3 | 4 | Usage: 5 | python qwen_qa.py [--data_folder DATA_FOLDER] [--model_path MODEL_PATH] 6 | [--output_dir OUTPUT_DIR] [--categories CATEGORIES] 7 | [--max_tokens MAX_TOKENS] [--temperature TEMPERATURE] 8 | [--tensor_parallel_size TENSOR_PARALLEL_SIZE] 9 
def load_parquet_data(file_path):
    """Load one Parquet file and return its rows as a list of dicts.

    Args:
        file_path: Path of the Parquet file to read.

    Returns:
        One dict per row (``DataFrame.to_dict('records')``), or an empty list
        when the file cannot be read. The failure is printed, not raised.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        # Deliberate best-effort: main() reports the file as empty and moves on.
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load previously saved results so already-answered questions can be skipped.

    Args:
        output_path: Path of the JSON results file written by
            ``save_results_to_json``.

    Returns:
        The stored list of result dicts, or an empty list when the file does
        not exist, cannot be parsed, or does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results file {output_path}: {e}")
        return []
    if not isinstance(loaded, list):
        # process_data iterates this value and calls .get() on each element;
        # anything but a list of dicts would crash the run.
        print(
            f"Failed to load existing results file {output_path}: "
            f"expected a JSON list, got {type(loaded).__name__}"
        )
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Write *results* to *output_path* as JSON without corrupting a prior file.

    The data is first written to ``<output_path>.tmp`` and then moved over the
    target with ``os.replace``. If serialisation fails or the process dies
    mid-write, the previous results file is left untouched, so a later run can
    still resume from it.

    Args:
        results: JSON-serialisable list of result dicts.
        output_path: Destination file path.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (the temp file was never created).
            pass


def _processed_question_ids(existing_results):
    """Return the set of ``Question_id`` values present in *existing_results*."""
    # Compare against None rather than truthiness so an ID of 0 (or "") still
    # counts as processed and is not re-run on every resume.
    return {
        entry.get('Question_id')
        for entry in existing_results
        if entry.get('Question_id') is not None
    }


def _build_result(item, question_id, response):
    """Assemble the record stored in the results file for one question."""
    return {
        "Question_id": question_id,
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id'),
    }


def _select_parquet_files(data_folder, categories=None):
    """Return the Parquet files in *data_folder* that match *categories*."""
    # Sorted so the processing order does not depend on directory listing order.
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))
    if not categories:
        print(f"Will process all .parquet files in folder {data_folder}.")
        return all_parquet_files

    suffix = ".parquet"
    # Strip only a trailing ".parquet"; the names are matched against file stems.
    category_basenames = {
        cat[:-len(suffix)] if cat.endswith(suffix) else cat for cat in categories
    }
    files_to_process = [
        path for path in all_parquet_files
        if os.path.splitext(os.path.basename(path))[0] in category_basenames
    ]
    print(f"Will process specified category files: {files_to_process}")

    found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
    missing_categories = category_basenames - found_basenames
    if missing_categories:
        print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
    return files_to_process


def process_data(data_list, llm, processor, temperature, max_tokens, existing_results=None, output_path=None):
    """Generate a model response for every unanswered question in *data_list*.

    Args:
        data_list: Row dicts; each row is read for ``Question_id``, ``Image``
            (base64-encoded image data), ``Prompt``, ``Question`` and
            ``Choices``, plus optional ``Answer``, ``Category``, ``Png_id``.
        llm: Object whose ``generate`` method is called once per question.
        processor: Object whose ``apply_chat_template`` builds the prompt text.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Generation limit passed to ``SamplingParams``.
        existing_results: Results from an earlier run; their question IDs are
            skipped and the entries are carried over into the return value.
        output_path: When set, results are written there every 10 new
            questions, after any error, and once more at the end.

    Returns:
        The carried-over results followed by one record per newly processed
        question. A failed question is recorded with a
        ``"Processing error: ..."`` response instead of raising.
    """
    results = []
    processed_ids = set()
    if existing_results:
        processed_ids = _processed_question_ids(existing_results)
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip them.")

    # Nothing left to answer: return before doing any model-side setup.
    if all(item.get('Question_id') in processed_ids for item in data_list):
        return results

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')
        if question_id in processed_ids:
            continue

        newly_processed_count += 1

        try:
            image = Image.open(BytesIO(base64.b64decode(item['Image'])))

            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']

            # Same layout as the other QA scripts in this repository. The
            # previous literal embedded stray quote characters and a "+" left
            # over from string concatenation.
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"
            messages = [
                {"role": "user", "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt_text},
                ]}
            ]

            chat_prompt = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            image_inputs, _video_inputs = process_vision_info(messages)

            mm_data = {}
            if image_inputs is not None:
                mm_data["image"] = image_inputs

            llm_inputs = {
                "prompt": chat_prompt,
                "multi_modal_data": mm_data,
            }

            outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
            response = outputs[0].outputs[0].text.strip()

            results.append(_build_result(item, question_id, response))

            # Checkpoint every 10 new questions, and on the last row of the file.
            if newly_processed_count % 10 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saving results to {output_path}.")

        except Exception as e:
            # Record the failure as this question's response and keep going.
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, question_id, f"Processing error: {e}"))

            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")

    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def main(data_folder, model_path, output_base_dir, categories=None, max_tokens=256, temperature=0.1, tensor_parallel_size=8):
    """Run the model over the selected Parquet datasets, one output file each.

    Args:
        data_folder: Directory searched for ``*.parquet`` files.
        model_path: Model location passed to ``LLM`` and ``AutoProcessor``.
        output_base_dir: Directory for the ``qwen_<dataset>.json`` files;
            created if missing.
        categories: Dataset base names to process (a trailing ``.parquet`` is
            ignored). ``None`` or empty means every Parquet file in the folder.
        max_tokens: Generation limit per question.
        temperature: Sampling temperature.
        tensor_parallel_size: Passed through to ``LLM``.

    Returns:
        None. Returns early when no file matches or the model fails to load.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    files_to_process = _select_parquet_files(data_folder, categories)
    if not files_to_process:
        print("No Parquet files found to process.")
        return

    print("Loading model...")
    try:
        llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=32768,
            limit_mm_per_prompt={"image": 1, "video": 0},
            enforce_eager=True
        )
        processor = AutoProcessor.from_pretrained(model_path)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # "Multi-window_QA.parquet" -> "qwen_Multi-window_QA.json"
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"qwen_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                llm,
                processor,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing complete.")
        else:
            print(f"No data in file {file_path} or file loading failed.")

    print("=" * 50)
    print("Processing complete for all specified files.")


if __name__ == "__main__":
    # Use relative paths
    data_folder = "data/datasets"  # Data folder path
    model_path = "models/Qwen2.5-VL-72B-Instruct"  # Model weights path
    output_dir = "results"  # Output directory for results

    # Specify categories to process (only provide base names, e.g., "Multi-window_QA")
    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]
    # categories = None  # Set to None to process all parquet files

    max_tokens = 1024  # Maximum tokens to generate
    temperature = 0  # Temperature parameter
    tensor_parallel_size = 8  # Use 8 GPUs in parallel

    # Call main function, passing the output directory
    main(data_folder, model_path, output_dir, categories, max_tokens, temperature, tensor_parallel_size)
# Seconds to wait before the single retry of a failed API call.
_RETRY_DELAY_SECONDS = 60
# Seconds to pause after each checkpoint (every 5 new questions).
_CHECKPOINT_PAUSE_SECONDS = 60


def load_parquet_data(file_path):
    """Load one Parquet file and return its rows as a list of dicts.

    Args:
        file_path: Path of the Parquet file to read.

    Returns:
        One dict per row (``DataFrame.to_dict('records')``), or an empty list
        when the file cannot be read. The failure is printed, not raised.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        # Deliberate best-effort: main() reports the file as empty and moves on.
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load previously saved results so already-answered questions can be skipped.

    Args:
        output_path: Path of the JSON results file written by
            ``save_results_to_json``.

    Returns:
        The stored list of result dicts, or an empty list when the file does
        not exist, cannot be parsed, or does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results file {output_path}: {e}")
        return []
    if not isinstance(loaded, list):
        # process_data iterates this value and calls .get() on each element.
        print(
            f"Failed to load existing results file {output_path}: "
            f"expected a JSON list, got {type(loaded).__name__}"
        )
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Write *results* to *output_path* as JSON without corrupting a prior file.

    The data is first written to ``<output_path>.tmp`` and then moved over the
    target with ``os.replace``, so a failed or interrupted write leaves the
    previous results file intact for the next resume.

    Args:
        results: JSON-serialisable list of result dicts.
        output_path: Destination file path.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (the temp file was never created).
            pass


def _processed_question_ids(existing_results):
    """Return the set of ``Question_id`` values present in *existing_results*."""
    # Compare against None rather than truthiness so an ID of 0 (or "") still
    # counts as processed and is not re-run on every resume.
    return {
        entry.get('Question_id')
        for entry in existing_results
        if entry.get('Question_id') is not None
    }


def _build_result(item, question_id, response):
    """Assemble the record stored in the results file for one question."""
    return {
        "Question_id": question_id,
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id'),
    }


def _select_parquet_files(data_folder, categories=None):
    """Return the Parquet files in *data_folder* that match *categories*."""
    # Sorted so the processing order does not depend on directory listing order.
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))
    if not categories:
        print(f"Will process all .parquet files in the folder {data_folder}.")
        return all_parquet_files

    suffix = ".parquet"
    # Strip only a trailing ".parquet"; the names are matched against file stems.
    category_basenames = {
        cat[:-len(suffix)] if cat.endswith(suffix) else cat for cat in categories
    }
    files_to_process = [
        path for path in all_parquet_files
        if os.path.splitext(os.path.basename(path))[0] in category_basenames
    ]
    print(f"Will process specified category files: {files_to_process}")

    found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
    missing_categories = category_basenames - found_basenames
    if missing_categories:
        print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
    return files_to_process


def _request_response_text(client, model_name, prompt_text, image_bytes, temperature, max_tokens, thinking_config):
    """Send one prompt plus image to the Gemini API and return the stripped text."""
    response = client.models.generate_content(
        model=model_name,
        contents=[
            prompt_text,
            # NOTE(review): the MIME type is hard-coded to JPEG; confirm the
            # dataset images are JPEG-encoded.
            types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
        ],
        config=types.GenerateContentConfig(
            temperature=temperature,
            thinking_config=thinking_config,
            max_output_tokens=max_tokens
        )
    )
    return response.text.strip()


def _request_with_retry(client, model_name, prompt_text, image_bytes, temperature, max_tokens, thinking_config):
    """Call the API, retrying once after a fixed delay if the first call fails."""
    request_args = (client, model_name, prompt_text, image_bytes, temperature, max_tokens, thinking_config)
    try:
        return _request_response_text(*request_args)
    except Exception as api_error:
        print(f"API call error, attempting retry: {api_error}")
        time.sleep(_RETRY_DELAY_SECONDS)
        try:
            return _request_response_text(*request_args)
        except Exception as retry_error:
            raise RuntimeError(f"Retry failed: {retry_error}") from retry_error


def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None, thinking_config=None):
    """Query the Gemini API for every unanswered question in *data_list*.

    Args:
        data_list: Row dicts; each row is read for ``Question_id``, ``Image``
            (base64-encoded image data), ``Prompt``, ``Question`` and
            ``Choices``, plus optional ``Answer``, ``Category``, ``Png_id``.
        client: Gemini client; ``client.models.generate_content`` is called.
        model_name: Model identifier sent with each request.
        temperature: Sampling temperature for the request config.
        max_tokens: ``max_output_tokens`` for the request config.
        existing_results: Results from an earlier run; their question IDs are
            skipped and the entries are carried over into the return value.
        output_path: When set, results are written there every 5 new
            questions, after any error, and once more at the end.
        thinking_config: Value forwarded as ``thinking_config`` in the request
            config. Previously this was read from a module-level global that
            only existed when the file was run as a script.

    Returns:
        The carried-over results followed by one record per newly processed
        question. A failed question is recorded with a
        ``"Processing error: ..."`` response instead of raising.
    """
    results = []
    processed_ids = set()
    if existing_results:
        processed_ids = _processed_question_ids(existing_results)
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            print(f"Skipping already processed question ID: {question_id}")
            continue

        newly_processed_count += 1

        try:
            image_bytes = base64.b64decode(item['Image'])

            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"

            model_response = _request_with_retry(
                client, model_name, prompt_text, image_bytes,
                temperature, max_tokens, thinking_config,
            )
            results.append(_build_result(item, question_id, model_response))

            # Checkpoint every 5 new questions, and on the last row of the file.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), real-time saving to {output_path}.")
                # NOTE(review): the pause is applied at every checkpoint,
                # whether or not an output path is set; confirm that matches
                # the intended rate limiting.
                time.sleep(_CHECKPOINT_PAUSE_SECONDS)

        except Exception as e:
            # Record the failure as this question's response and keep going.
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, question_id, f"Processing error: {e}"))

            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")

    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"Finished processing file, final results saved to: {output_path}")

    return results


def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0, thinking_config=None):
    """Run the Gemini model over the selected Parquet datasets, one output file each.

    Args:
        data_folder: Directory searched for ``*.parquet`` files.
        api_key: Key passed to ``genai.Client``.
        model_name: Model identifier sent with each request.
        output_base_dir: Directory for the ``gemini_<dataset>.json`` files;
            created if missing.
        categories: Dataset base names to process (a trailing ``.parquet`` is
            ignored). ``None`` or empty means every Parquet file in the folder.
        max_tokens: Output token limit per question.
        temperature: Sampling temperature.
        thinking_config: Forwarded unchanged to ``process_data``.

    Returns:
        None. Returns early when no file matches or the client cannot be
        created.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    files_to_process = _select_parquet_files(data_folder, categories)
    if not files_to_process:
        print("No Parquet files found to process.")
        return

    print("Initializing Gemini API client...")
    try:
        client = genai.Client(api_key=api_key)
        print("Gemini API client initialized successfully.")
    except Exception as e:
        print(f"Error initializing Gemini API client: {e}")
        return

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # "Multi-window_QA.parquet" -> "gemini_Multi-window_QA.json"
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"gemini_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                client,
                model_name,
                temperature,
                max_tokens,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path,
                thinking_config=thinking_config,
            )
            print(f"File {file_path} processing completed.")
        else:
            print(f"No data in file {file_path} or loading failed.")

    print("=" * 50)
    print("All specified files processing completed.")


if __name__ == "__main__":
    data_folder = "./datasets"  # Data folder path
    # Prefer the environment so the key does not have to be committed.
    api_key = os.environ.get("GEMINI_API_KEY", "your_api_key")
    model_name = "gemini-2.5-flash-preview-04-17"  # Gemini model name

    output_dir = "./results/gemini"  # Directory to save results

    # Specify which dataset categories to process (just provide base names like "Multi-window_QA")
    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]

    max_tokens = 300  # Maximum tokens to generate
    temperature = 0  # Temperature parameter
    thinking_config = types.ThinkingConfig(thinking_budget=0)

    # Call main function
    main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature,
         thinking_config=thinking_config)
# Seconds to wait before the single retry of a failed API call.
_RETRY_DELAY_SECONDS = 2


def load_parquet_data(file_path):
    """Load one Parquet file and return its rows as a list of dicts.

    Args:
        file_path: Path of the Parquet file to read.

    Returns:
        One dict per row (``DataFrame.to_dict('records')``), or an empty list
        when the file cannot be read. The failure is printed, not raised.
    """
    try:
        df = pd.read_parquet(file_path)
        return df.to_dict('records')
    except Exception as e:
        # Deliberate best-effort: main() reports the file as empty and moves on.
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load previously saved results so already-answered questions can be skipped.

    Args:
        output_path: Path of the JSON results file written by
            ``save_results_to_json``.

    Returns:
        The stored list of result dicts, or an empty list when the file does
        not exist, cannot be parsed, or does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results file {output_path}: {e}")
        return []
    if not isinstance(loaded, list):
        # process_data iterates this value and calls .get() on each element.
        print(
            f"Failed to load existing results file {output_path}: "
            f"expected a JSON list, got {type(loaded).__name__}"
        )
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Write *results* to *output_path* as JSON without corrupting a prior file.

    The data is first written to ``<output_path>.tmp`` and then moved over the
    target with ``os.replace``, so a failed or interrupted write leaves the
    previous results file intact for the next resume.

    Args:
        results: JSON-serialisable list of result dicts.
        output_path: Destination file path.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (the temp file was never created).
            pass


def _processed_question_ids(existing_results):
    """Return the set of ``Question_id`` values present in *existing_results*."""
    # Compare against None rather than truthiness so an ID of 0 (or "") still
    # counts as processed and is not re-run on every resume.
    return {
        entry.get('Question_id')
        for entry in existing_results
        if entry.get('Question_id') is not None
    }


def _build_result(item, question_id, response):
    """Assemble the record stored in the results file for one question."""
    return {
        "Question_id": question_id,
        "Response": response,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id'),
    }


def _select_parquet_files(data_folder, categories=None):
    """Return the Parquet files in *data_folder* that match *categories*."""
    # Sorted so the processing order does not depend on directory listing order.
    all_parquet_files = sorted(glob.glob(os.path.join(data_folder, "*.parquet")))
    if not categories:
        print(f"Will process all .parquet files in folder {data_folder}.")
        return all_parquet_files

    suffix = ".parquet"
    # Strip only a trailing ".parquet"; the names are matched against file stems.
    category_basenames = {
        cat[:-len(suffix)] if cat.endswith(suffix) else cat for cat in categories
    }
    files_to_process = [
        path for path in all_parquet_files
        if os.path.splitext(os.path.basename(path))[0] in category_basenames
    ]
    print(f"Will process specified category files: {files_to_process}")

    found_basenames = {os.path.splitext(os.path.basename(f))[0] for f in files_to_process}
    missing_categories = category_basenames - found_basenames
    if missing_categories:
        print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
    return files_to_process


def _request_completion(client, model_name, prompt_text, base64_image):
    """Send one text-plus-image chat request and return the message content."""
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt_text
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # NOTE(review): the data URL declares JPEG; confirm
                            # the dataset images are JPEG-encoded.
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ]
    )
    return response.choices[0].message.content


def _request_with_retry(client, model_name, prompt_text, base64_image):
    """Call the API, retrying once after a short delay if the first call fails."""
    try:
        return _request_completion(client, model_name, prompt_text, base64_image)
    except Exception as api_error:
        print(f"API call error, trying again: {api_error}")
        time.sleep(_RETRY_DELAY_SECONDS)
        try:
            return _request_completion(client, model_name, prompt_text, base64_image)
        except Exception as retry_error:
            raise RuntimeError(f"Retry failed: {retry_error}") from retry_error


def process_data(data_list, client, model_name, existing_results=None, output_path=None):
    """Query the chat-completions API for every unanswered question in *data_list*.

    Args:
        data_list: Row dicts; each row is read for ``Question_id``, ``Image``
            (base64 text embedded in a data URL), ``Prompt``, ``Question`` and
            ``Choices``, plus optional ``Answer``, ``Category``, ``Png_id``.
        client: Client object; ``client.chat.completions.create`` is called.
        model_name: Model identifier sent with each request.
        existing_results: Results from an earlier run; their question IDs are
            skipped and the entries are carried over into the return value.
        output_path: When set, results are written there every 5 new
            questions, after any error, and once more at the end.

    Returns:
        The carried-over results followed by one record per newly processed
        question. A failed question is recorded with a
        ``"Processing error: ..."`` response instead of raising.
    """
    results = []
    processed_ids = set()
    if existing_results:
        processed_ids = _processed_question_ids(existing_results)
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} previously processed questions (from {output_path}), will skip these questions.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            print(f"Skipping already processed question ID: {question_id}")
            continue

        newly_processed_count += 1

        try:
            base64_image = item['Image']

            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"

            model_response = _request_with_retry(client, model_name, prompt_text, base64_image)
            results.append(_build_result(item, question_id, model_response))

            # Checkpoint every 5 new questions, and on the last row of the file.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), real-time saved results to {output_path}.")

        except Exception as e:
            # Record the failure as this question's response and keep going.
            print(f"Error processing Question_id {question_id}: {e}")
            results.append(_build_result(item, question_id, f"Processing error: {e}"))

            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, current results saved to {output_path}.")

    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def main(data_folder, api_key, model_name, output_base_dir, categories=None):
    """Run the model over the selected Parquet datasets, one output file each.

    Args:
        data_folder: Directory searched for ``*.parquet`` files.
        api_key: Key passed to ``OpenAI``.
        model_name: Model identifier sent with each request.
        output_base_dir: Directory for the ``o4mini_<dataset>.json`` files;
            created if missing.
        categories: Dataset base names to process (a trailing ``.parquet`` is
            ignored). ``None`` or empty means every Parquet file in the folder.

    Returns:
        None. Returns early when no file matches or the client cannot be
        created.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    files_to_process = _select_parquet_files(data_folder, categories)
    if not files_to_process:
        print("No Parquet files found to process.")
        return

    print("Initializing OpenAI API client...")
    try:
        client = OpenAI(api_key=api_key)
        print("OpenAI API client initialized successfully.")
    except Exception as e:
        print(f"Error initializing OpenAI API client: {e}")
        return

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # "Multi-window_QA.parquet" -> "o4mini_Multi-window_QA.json"
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        current_output_path = os.path.join(output_base_dir, f"o4mini_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if data_list:
            process_data(
                data_list,
                client,
                model_name,
                existing_results=existing_results_for_current_file,
                output_path=current_output_path
            )
            print(f"File {file_path} processing complete.")
        else:
            print(f"No data in file {file_path} or loading failed.")

    print("=" * 50)
    print("Processing of all specified files complete.")


if __name__ == "__main__":
    data_folder = "./datasets"  # Data folder path
    # Prefer the environment so the key does not have to be committed.
    api_key = os.environ.get("OPENAI_API_KEY", "your_api_key")
    model_name = "o4-mini-2025-04-16"

    output_dir = "./results/o4-mini"  # Specify results directory

    categories = ["Synthetic_QA", "Real-world_QA", "Multi-window_QA"]

    # Call main function
    main(data_folder, api_key, model_name, output_dir, categories)
12 | """ 13 | 14 | import os 15 | import json 16 | import base64 17 | import glob 18 | import pandas as pd 19 | from PIL import Image 20 | from io import BytesIO 21 | from google import genai 22 | from google.genai import types 23 | import time 24 | 25 | def load_parquet_data(file_path): 26 | """Load data from a single Parquet file.""" 27 | try: 28 | df = pd.read_parquet(file_path) 29 | return df.to_dict('records') 30 | except Exception as e: 31 | print(f"Failed to load file {file_path}: {e}") 32 | return [] 33 | 34 | def load_existing_results(output_path): 35 | """Load existing results file to skip already processed data.""" 36 | if os.path.exists(output_path): 37 | try: 38 | with open(output_path, 'r', encoding='utf-8') as f: 39 | return json.load(f) 40 | except Exception as e: 41 | print(f"Failed to load existing results file {output_path}: {e}") 42 | return [] 43 | 44 | def save_results_to_json(results, output_path="output.json"): 45 | """Save results to a JSON file.""" 46 | try: 47 | with open(output_path, 'w', encoding='utf-8') as f: 48 | json.dump(results, f, ensure_ascii=False, indent=4) 49 | print(f"Results saved to: {output_path}") 50 | except Exception as e: 51 | print(f"Failed to save results to {output_path}: {e}") 52 | 53 | 54 | def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None): 55 | """Process data list, generate model responses, and save results in real-time.""" 56 | results = [] 57 | 58 | # Create a set of processed question IDs to skip 59 | processed_ids = set() 60 | if existing_results: 61 | processed_ids = {item.get('Question_id') for item in existing_results if item.get('Question_id')} 62 | results = existing_results.copy() # Use existing results for this file 63 | print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these.") 64 | 65 | total_items = len(data_list) 66 | newly_processed_count = 0 67 | 68 | for i, item in 
enumerate(data_list): 69 | question_id = item.get('Question_id') 70 | 71 | # Skip already processed questions 72 | if question_id in processed_ids: 73 | print(f"Skipping already processed question ID: {question_id}") 74 | continue 75 | 76 | newly_processed_count += 1 # Count newly processed questions 77 | 78 | try: 79 | # Decode base64 image 80 | base64_image = item['Image'] 81 | image_bytes = base64.b64decode(base64_image) 82 | 83 | prompt = item['Prompt'] 84 | question = item['Question'] 85 | choices = item['Choices'] 86 | 87 | # Build prompt text 88 | prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}" 89 | 90 | # Build API request 91 | try: 92 | response = client.models.generate_content( 93 | model=model_name, 94 | contents=[ 95 | prompt_text, 96 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg") 97 | ], 98 | config=types.GenerateContentConfig( 99 | temperature=temperature, 100 | thinking_config=thinking_config, 101 | max_output_tokens=max_tokens 102 | ) 103 | ) 104 | 105 | # Get response text 106 | model_response = response.text.strip() 107 | 108 | except Exception as api_error: 109 | print(f"API call error, attempting retry: {api_error}") 110 | # Simple retry mechanism 111 | time.sleep(60) 112 | try: 113 | response = client.models.generate_content( 114 | model=model_name, 115 | contents=[ 116 | prompt_text, 117 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg") 118 | ], 119 | config=types.GenerateContentConfig( 120 | temperature=temperature, 121 | thinking_config=thinking_config, 122 | max_output_tokens=max_tokens 123 | ) 124 | ) 125 | model_response = response.text.strip() 126 | except Exception as retry_error: 127 | raise Exception(f"Retry failed: {retry_error}") 128 | 129 | result = { 130 | "Question_id": question_id, 131 | "Response": model_response, 132 | "Answer": item.get('Answer'), 133 | "Category": item.get('Category'), 134 | "Png_id": item.get('Png_id') 135 | } 136 | 137 | results.append(result) # Add 
def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
    """Run Gemini inference over the Parquet datasets in ``data_folder``.

    Each selected ``<name>.parquet`` file is processed on its own and its
    results are written to ``<output_base_dir>/gemini_<name>.json``. Results
    already present in that file are loaded first and handed to
    ``process_data`` so that finished questions can be skipped.

    Args:
        data_folder: Directory searched (non-recursively) for ``*.parquet``.
        api_key: Key passed to ``genai.Client``.
        model_name: Gemini model identifier forwarded to ``process_data``.
        output_base_dir: Directory for the JSON outputs; created if missing.
        categories: Optional dataset base names (a ``.parquet`` suffix is
            tolerated). When falsy, every Parquet file is processed.
        max_tokens: Output-token cap forwarded to ``process_data``.
        temperature: Sampling temperature forwarded to ``process_data``.

    Returns:
        None. The function also returns early when no file matches or the
        API client cannot be created.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    def stem_of(path):
        # "dir/Multi-window_QA.parquet" -> "Multi-window_QA"
        return os.path.splitext(os.path.basename(path))[0]

    discovered = glob.glob(os.path.join(data_folder, "*.parquet"))

    if categories:
        # Requested names are compared without the .parquet suffix.
        wanted = {cat.replace(".parquet", "") for cat in categories}
        files_to_process = [path for path in discovered if stem_of(path) in wanted]
        print(f"Will process specified category files: {files_to_process}")

        # Report requested categories that have no file on disk.
        missing_categories = wanted - {stem_of(path) for path in files_to_process}
        if missing_categories:
            print(f"Warning: The following specified category files were not found in the folder: {missing_categories}")
    else:
        files_to_process = discovered
        print(f"Will process all .parquet files in folder {data_folder}.")

    if not files_to_process:
        print("No Parquet files found to process.")
        return

    # --- Initialize Gemini API client ---
    print("Initializing Gemini API client...")
    try:
        client = genai.Client(api_key=api_key)
    except Exception as e:
        print(f"Error initializing Gemini API client: {e}")
        return
    print("Gemini API client initialization complete.")

    # --- Process each file ---
    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # Output name is derived from the dataset name, e.g.
        # "Multi-window_QA.parquet" -> "gemini_Multi-window_QA.json".
        dataset_name = stem_of(file_path)
        current_output_path = os.path.join(output_base_dir, f"gemini_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        # Earlier results for this dataset let process_data resume.
        previous_results = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if not data_list:
            print(f"No data in file {file_path} or loading failed.")
            continue

        process_data(
            data_list,
            client,
            model_name,
            temperature,
            max_tokens,
            existing_results=previous_results,
            output_path=current_output_path,
        )
        print(f"File {file_path} processing complete.")

    print("=" * 50)
    print("Processing of all specified files complete.")


if __name__ == "__main__":
    data_folder = "./datasets"  # Folder holding the .parquet datasets
    api_key = "your_api_key"
    model_name = "gemini-2.5-pro-preview-05-06"  # Gemini model name

    # One JSON file per dataset is written here (e.g. gemini_Multi-window_QA.json).
    output_dir = "./results/gemini"

    # Dataset base names to process (no .parquet suffix needed).
    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]

    max_tokens = 300  # Maximum tokens to generate
    temperature = 0  # Temperature parameter
    # NOTE(review): not passed to main(); presumably process_data reads this
    # module-level name directly -- confirm against process_data's signature.
    thinking_config = types.ThinkingConfig(thinking_budget=0)

    main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
def safe_open_image(image_path):
    """
    Safely open an image, downscaling it if the pixel count exceeds the limit

    Pillow is tried first; if anything in that path fails, the file is loaded
    with OpenCV and converted to a PIL image instead.

    Args:
        image_path (str): Path to the image file

    Returns:
        PIL.Image: Processed image object

    Raises:
        Exception: Re-raised from the OpenCV fallback when it fails too.
    """
    # Stay slightly below PIL's warning threshold of 89478485 pixels.
    max_pixels = 89000000

    try:
        img = Image.open(image_path)
        width, height = img.size
        pixels = width * height

        if pixels > max_pixels:
            # Uniform scale so that new_width * new_height <= max_pixels.
            scale = (max_pixels / pixels) ** 0.5
            # Clamp to 1 so an extreme aspect ratio cannot produce a
            # zero-sized dimension.
            new_width = max(1, int(width * scale))
            new_height = max(1, int(height * scale))
            img = img.resize((new_width, new_height), Image.LANCZOS)
            print(f"Image {os.path.basename(image_path)} has been resized: {width}x{height} -> {new_width}x{new_height}")

        return img
    except Exception as e:
        print(f"Cannot open image {image_path}: {str(e)}")
        # Fall back to OpenCV for files Pillow could not handle.
        try:
            print(f"Trying to load image {image_path} with OpenCV")
            img_cv = cv2.imread(image_path)
            if img_cv is None:
                raise Exception("OpenCV cannot load the image")
            # OpenCV decodes to BGR; PIL expects RGB.
            return Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
        except Exception as cv_error:
            print(f"Failed to load image {image_path} with OpenCV: {str(cv_error)}")
            raise


def clip_similarity(image_path1, image_path2):
    """
    Calculate CLIP semantic similarity between two images

    Each argument is resolved independently, so a path and an already opened
    image may be mixed in one call.

    Args:
        image_path1 (str or os.PathLike or PIL.Image): First image or its path
        image_path2 (str or os.PathLike or PIL.Image): Second image or its path

    Returns:
        similarity (float): Cosine similarity between the two images
    """
    def as_image(source):
        # Only path-like inputs need loading; anything else is used as given.
        if isinstance(source, (str, os.PathLike)):
            return safe_open_image(source)
        return source

    # Preprocess images and transfer to device
    img1 = preprocess(as_image(image_path1)).unsqueeze(0).to(device)
    img2 = preprocess(as_image(image_path2)).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        features1 = model.encode_image(img1)
        features2 = model.encode_image(img2)

        # Normalize features
        features1 = features1 / features1.norm(p=2, dim=-1, keepdim=True)
        features2 = features2 / features2.norm(p=2, dim=-1, keepdim=True)

        similarity = torch.nn.functional.cosine_similarity(features1, features2)

    return similarity.item()


def load_or_create_json(json_path):
    """
    Load a results JSON file, or start a fresh result set

    A fresh ``{"results": []}`` is returned when the file is missing, cannot
    be parsed, or does not have the expected ``{"results": [...]}`` shape.

    Args:
        json_path (str): Path to JSON file

    Returns:
        dict: Loaded data or an empty result set
    """
    if not os.path.exists(json_path):
        return {"results": []}

    with open(json_path, 'r', encoding='utf-8') as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            print(f"JSON file {json_path} parsing error, creating new file")
            return {"results": []}

    # Callers index data["results"] and append to it, so reject other shapes
    # here instead of failing later.
    if not isinstance(data, dict) or not isinstance(data.get("results"), list):
        print(f"JSON file {json_path} has an unexpected structure, creating new file")
        return {"results": []}

    return data


def save_json(data, json_path):
    """
    Save JSON data to file

    Args:
        data (dict): Data to save
        json_path (str): Path to JSON file
    """
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(f"Results saved to {json_path}")


def update_category_summary(data, category):
    """
    Update or create the ``<category>_summary`` item holding the mean score

    Args:
        data (dict): Dictionary containing results
        category (str): Dataset category name

    Returns:
        bool: True only when a new summary item was appended; False when an
            existing summary was updated or the category has no results
    """
    summary_id = f"{category}_summary"
    category_results = [
        item for item in data["results"]
        if item.get("Category") == category and item.get("id") != summary_id
    ]
    if not category_results:
        return False

    # Non-numeric scores are ignored; a missing score counts as 0.
    clip_scores = [
        item.get("clip_score", 0)
        for item in category_results
        if isinstance(item.get("clip_score", 0), (int, float))
    ]
    avg_clip = float(np.mean(clip_scores)) if clip_scores else 0.0

    # Refresh the existing summary in place when there is one.
    for item in data["results"]:
        if item.get("id") == summary_id:
            item["clip_score"] = avg_clip
            return False

    data["results"].append({
        "id": summary_id,
        "Category": category,
        "clip_score": avg_clip
    })
    return True


# Defaults used when compare_model_datasets is called without explicit lists.
_DEFAULT_MODELS = ("claude", "gemini", "internvl2", "internvl3", "llava", "o1", "openai", "qwen", "o4mini", "pro")
# NOTE(review): "Code Refinement" contains a space while the other names use
# underscores -- confirm this matches the folder name on disk.
_DEFAULT_DATASETS = ("Code Refinement", "Image_to_code", "Interaction_Authoring", "Text_to_code")
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')


def _list_images(folder):
    """Return the set of image file names (by extension) found in ``folder``."""
    return {name for name in os.listdir(folder) if name.lower().endswith(_IMAGE_EXTENSIONS)}


def compare_model_datasets(source_folder1, source_folder2, output_dir="./results", models=None, datasets=None):
    """
    Compare image similarity across multiple models and datasets

    For every model a ``<model>_similarity.json`` file is kept in
    ``output_dir``. Image pairs already recorded there (same category and
    file name) are skipped, so re-running the comparison resumes instead of
    appending duplicate entries that would distort the category averages.

    Args:
        source_folder1 (str): Path to first source folder containing model folders
        source_folder2 (str): Path to second source folder containing dataset folders
        output_dir (str): Output directory for JSON files
        models (iterable of str, optional): Model folder names; defaults to
            the built-in model list
        datasets (iterable of str, optional): Dataset folder names; defaults
            to the built-in dataset list
    """
    os.makedirs(output_dir, exist_ok=True)

    model_names = _DEFAULT_MODELS if models is None else models
    dataset_names = _DEFAULT_DATASETS if datasets is None else datasets

    for model_name in model_names:
        model_folder = os.path.join(source_folder1, model_name)
        if not os.path.exists(model_folder):
            print(f"Model folder {model_folder} does not exist, skipping")
            continue

        json_path = os.path.join(output_dir, f"{model_name}_similarity.json")
        data = load_or_create_json(json_path)

        # Continue numbering after the highest numeric id already stored
        # (summary ids are strings such as "<dataset>_summary").
        numeric_ids = [int(item["id"]) for item in data["results"] if str(item.get("id", "")).isdigit()]
        current_max_id = max(numeric_ids, default=0)

        # Pairs scored by an earlier run.
        done = {
            (item.get("Category"), item.get("filename"))
            for item in data["results"]
            if item.get("filename") is not None
        }

        for dataset in dataset_names:
            dataset_folder1 = os.path.join(model_folder, dataset)
            dataset_folder2 = os.path.join(source_folder2, dataset)

            if not os.path.exists(dataset_folder1) or not os.path.exists(dataset_folder2):
                print(f"Dataset folder {dataset_folder1} or {dataset_folder2} does not exist, skipping")
                continue

            # Sorted so that ids are assigned in a reproducible order.
            common_files = sorted(_list_images(dataset_folder1) & _list_images(dataset_folder2))
            if not common_files:
                print(f"Warning: No matching image files found in dataset {dataset} folders")
                continue

            pending = [name for name in common_files if (dataset, name) not in done]
            skipped = len(common_files) - len(pending)
            if skipped:
                print(f"Skipping {skipped} already processed images for {model_name}/{dataset}")

            processing_count = 0
            # Only build a progress bar when there is work to show.
            progress = tqdm(pending, desc=f"Processing image similarity for {model_name}/{dataset}") if pending else ()

            for filename in progress:
                img1_path = os.path.join(dataset_folder1, filename)
                img2_path = os.path.join(dataset_folder2, filename)

                try:
                    clip_score = clip_similarity(img1_path, img2_path)

                    current_max_id += 1
                    data["results"].append({
                        "id": current_max_id,
                        "Category": dataset,
                        "filename": filename,
                        "clip_score": float(clip_score)
                    })
                    done.add((dataset, filename))
                    processing_count += 1

                    # Checkpoint every 10 processed items.
                    if processing_count % 10 == 0:
                        update_category_summary(data, dataset)
                        save_json(data, json_path)

                except Exception as e:
                    # One bad image must not abort the whole dataset.
                    print(f"Error processing image {filename}: {str(e)}")

            created_new = update_category_summary(data, dataset)
            if created_new:
                print(f"Created new summary item for {dataset}")

            # Save after processing each dataset
            save_json(data, json_path)
"__main__": 293 | # Source folder 1, containing models and datasets 294 | # Can be replaced when running 295 | source_folder1 = "pro_imgs" 296 | # Source folder 2, containing only datasets 297 | source_folder2 = "label_imgs" 298 | # Output directory 299 | output_dir = "results" 300 | 301 | # Run comparison 302 | compare_model_datasets(source_folder1, source_folder2, output_dir) 303 | -------------------------------------------------------------------------------- /generate_response/claude_qa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Claude QA: A tool for evaluating Claude's performance on image-based question answering tasks. 3 | 4 | Usage: 5 | 1. Place your parquet datasets in the ./datasets folder 6 | 2. Configure your API key and other parameters in the main section 7 | 3. Run: python claude_qa.py 8 | 4. Results will be saved in ./results/claude directory 9 | 10 | The script processes image-based QA datasets, sends queries to Claude API, and saves the responses. 11 | It has checkpoint functionality to resume from interrupted runs. 
def load_parquet_data(file_path):
    """Load one Parquet file and return its rows as a list of dicts ([] on failure)."""
    try:
        return pd.read_parquet(file_path).to_dict('records')
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        return []


def load_existing_results(output_path):
    """Load a previously saved results list so processed questions can be skipped.

    Returns an empty list when the file is missing, unreadable, or does not
    contain a JSON list (``process_data`` iterates the result as a list of
    dicts).
    """
    if not os.path.exists(output_path):
        return []
    try:
        with open(output_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        print(f"Failed to load existing results file {output_path}: {e}")
        return []
    if not isinstance(loaded, list):
        print(f"Existing results file {output_path} does not contain a list, ignoring it.")
        return []
    return loaded


def save_results_to_json(results, output_path="output.json"):
    """Save results to a JSON file without ever leaving a half-written checkpoint.

    The data is written to ``<output_path>.tmp`` and then moved over the
    target with ``os.replace``. If the run is interrupted mid-write, the
    previous checkpoint stays intact and a later resume does not lose it.
    Failures are reported, not raised.
    """
    temp_path = f"{output_path}.tmp"
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        os.replace(temp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        # Best-effort cleanup of the partial temporary file.
        try:
            if os.path.exists(temp_path):
                os.remove(temp_path)
        except OSError:
            pass


# File signatures for the image formats checked by _guess_media_type.
_MAGIC_MEDIA_TYPES = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
)


def _guess_media_type(base64_image, default="image/png"):
    """Sniff the image format from the first decoded bytes; fall back to ``default``."""
    try:
        # 24 base64 characters decode to 18 bytes, enough for every signature.
        header = base64.b64decode(base64_image[:24])
    except (ValueError, TypeError):
        return default
    for magic, media_type in _MAGIC_MEDIA_TYPES:
        if header.startswith(magic):
            return media_type
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "image/webp"
    return default


def _image_block(base64_data, media_type):
    """Build the base64 image content block for a Claude message."""
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": base64_data
        }
    }


def _ask_claude(client, model_name, max_tokens, temperature, prompt_text, image_content):
    """Send one text+image user message and return the text of the first content block."""
    response = client.messages.create(
        model=model_name,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    image_content
                ]
            }
        ]
    )
    return response.content[0].text


def _to_jpeg_base64(base64_image, quality=85):
    """Re-encode a base64 image as base64 JPEG to reduce its size."""
    img = Image.open(BytesIO(base64.b64decode(base64_image)))
    # Pillow cannot save palette or alpha modes (P, LA, RGBA, ...) as JPEG,
    # so convert everything except plain RGB / greyscale.
    if img.mode not in ("RGB", "L"):
        img = img.convert('RGB')
    buffer = BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def _generate_response(client, model_name, max_tokens, temperature, prompt_text,
                       base64_image, question_id, retry_delay):
    """Query Claude for one question, with the size and retry fallbacks.

    An "image exceeds 5 MB maximum" error triggers one attempt with a JPEG
    re-encoding (its failure is returned as a "Processing error" string).
    Any other error triggers one plain retry after ``retry_delay`` seconds,
    and a second failure is raised.
    """
    image_content = _image_block(base64_image, _guess_media_type(base64_image))
    try:
        return _ask_claude(client, model_name, max_tokens, temperature, prompt_text, image_content)
    except Exception as api_error:
        if "image exceeds 5 MB maximum" in str(api_error):
            print(f"Image exceeds size limit, converting to JPEG: {question_id}")
            try:
                jpeg_content = _image_block(_to_jpeg_base64(base64_image), "image/jpeg")
                answer = _ask_claude(client, model_name, max_tokens, temperature, prompt_text, jpeg_content)
                print(f"Successfully converted to JPEG and completed request: {question_id}")
                return answer
            except Exception as jpeg_error:
                print(f"JPEG conversion failed: {jpeg_error}")
                return f"Processing error: Image exceeds size limit and JPEG conversion failed - {jpeg_error}"

        print(f"API call error, attempting retry: {api_error}")
        # Simple retry mechanism: wait, then try exactly once more.
        time.sleep(retry_delay)
        try:
            return _ask_claude(client, model_name, max_tokens, temperature, prompt_text, image_content)
        except Exception as retry_error:
            raise Exception(f"Retry failed: {retry_error}") from retry_error


def _result_record(item, question_id, response_text):
    """Build the output record stored for one question."""
    return {
        "Question_id": question_id,
        "Response": response_text,
        "Answer": item.get('Answer'),
        "Category": item.get('Category'),
        "Png_id": item.get('Png_id')
    }


def process_data(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None, retry_delay=2):
    """Process data list, generate model responses, and save results in real-time.

    Args:
        data_list: Row dicts with 'Image' (base64), 'Prompt', 'Question' and
            'Choices'; 'Question_id', 'Answer', 'Category' and 'Png_id' are
            copied into the output when present.
        client: Object exposing ``messages.create`` (an Anthropic client).
        model_name: Model identifier passed to the API.
        temperature: Sampling temperature passed to the API.
        max_tokens: Maximum number of tokens to generate.
        existing_results: Records from an earlier run; their question IDs are
            skipped and the records are kept in the returned list.
        output_path: JSON file used for checkpoints; nothing is written when
            it is None.
        retry_delay: Seconds to wait before the single retry of a failed call.

    Returns:
        The existing records followed by one record per newly processed row.
        Rows that fail get a record whose Response starts with
        "Processing error:".
    """
    results = []

    # IDs already answered in a previous run. Compare against None so that
    # falsy but valid IDs such as 0 are still recognised as processed.
    processed_ids = set()
    if existing_results:
        processed_ids = {
            item.get('Question_id')
            for item in existing_results
            if item.get('Question_id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed questions (from {output_path}), will skip these questions.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        question_id = item.get('Question_id')

        if question_id in processed_ids:
            print(f"Skipping already processed question ID: {question_id}")
            continue

        newly_processed_count += 1

        try:
            base64_image = item['Image']
            prompt = item['Prompt']
            question = item['Question']
            choices = item['Choices']
            prompt_text = f"{prompt}\nQuestion: {question}\nChoices: {choices}"

            model_response = _generate_response(
                client, model_name, max_tokens, temperature,
                prompt_text, base64_image, question_id, retry_delay,
            )
            results.append(_result_record(item, question_id, model_response))

            # Checkpoint every 5 new questions and at the last row of the file.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} questions (added {newly_processed_count} new), saved results to {output_path}.")

        except Exception as e:
            print(f"Error processing Question_id {question_id}: {e}")
            # Keep a record of the failure so the output stays aligned with the input.
            results.append(_result_record(item, question_id, f"Processing error: {e}"))

            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Question_id {question_id}, saved current results to {output_path}.")

    # Final save in case the last new item did not land on a checkpoint.
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results


def main(data_folder, api_key, model_name, output_base_dir, categories=None, max_tokens=1024, temperature=0):
    """Process the selected Parquet datasets with Claude, one output file per dataset.

    Results for ``<name>.parquet`` go to ``<output_base_dir>/claude_<name>.json``.
    Existing results in that file are loaded first so finished questions are
    skipped. When ``categories`` is falsy, every Parquet file in
    ``data_folder`` is processed.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    def stem_of(path):
        # "dir/Multi-window_QA.parquet" -> "Multi-window_QA"
        return os.path.splitext(os.path.basename(path))[0]

    all_parquet_files = glob.glob(os.path.join(data_folder, "*.parquet"))

    if categories:
        # Category names are compared without the .parquet suffix.
        category_basenames = {cat.replace(".parquet", "") for cat in categories}
        files_to_process = [path for path in all_parquet_files if stem_of(path) in category_basenames]
        print(f"Will process specified category files: {files_to_process}")

        missing_categories = category_basenames - {stem_of(path) for path in files_to_process}
        if missing_categories:
            print(f"Warning: The following specified category files were not found: {missing_categories}")
    else:
        files_to_process = all_parquet_files
        print(f"Will process all .parquet files in the {data_folder} folder.")

    if not files_to_process:
        print("No Parquet files found to process.")
        return

    print("Initializing Claude API client...")
    try:
        client = Anthropic(api_key=api_key)
    except Exception as e:
        print(f"Error initializing Claude API client: {e}")
        return
    print("Claude API client initialization complete.")

    for file_path in files_to_process:
        print("-" * 50)
        print(f"Starting to process file: {file_path}")

        # e.g. "Multi-window_QA.parquet" -> "claude_Multi-window_QA.json"
        dataset_name = stem_of(file_path)
        current_output_path = os.path.join(output_base_dir, f"claude_{dataset_name}.json")
        print(f"Results will be saved to: {current_output_path}")

        existing_results_for_current_file = load_existing_results(current_output_path)
        data_list = load_parquet_data(file_path)

        if not data_list:
            print(f"No data in file {file_path} or loading failed.")
            continue

        process_data(
            data_list,
            client,
            model_name,
            temperature,
            max_tokens,
            existing_results=existing_results_for_current_file,
            output_path=current_output_path,
        )
        print(f"File {file_path} processing complete.")

    print("=" * 50)
    print("Processing of all specified files complete.")


if __name__ == "__main__":
    data_folder = "./datasets"  # Data folder path
    api_key = "your_key"
    model_name = "claude-3-7-sonnet-20250219"

    output_dir = "./results/claude"  # Specify result directory

    categories = ["Real-world_QA", "Synthetic_QA", "Multi-window_QA"]

    max_tokens = 300  # Maximum tokens to generate
    temperature = 0  # Temperature parameter

    # Call main function
    main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)
def switch_to_next_api_key():
    """Rotate to the next configured API key and rebuild the Gemini client.

    Advances the module-level ``current_api_key_index`` (wrapping around at
    the end of ``api_keys``), rebinds the module-level ``api_key`` and
    ``client``, and records the switch on stdout and in the log file.

    Returns:
        The API key that is now active.
    """
    global current_api_key_index, client, api_key

    # Step forward one slot, wrapping back to the first key after the last.
    next_index = current_api_key_index + 1
    current_api_key_index = next_index % len(api_keys)

    api_key = api_keys[current_api_key_index]
    client = genai.Client(api_key=api_key)

    print(f"Switching to new API key: {current_api_key_index + 1}/{len(api_keys)}")
    write_log(f"Switched to new API key index {current_api_key_index + 1}/{len(api_keys)}")
    return api_key
For partial similarities, assign a score between 1 and 9, where higher scores reflect greater similarity. Only output a comma-separated list of 10 numbers enclosed in square brackets, e.g., [10,8,6,4,2,0,0,0,0,0]. Do not assign a score of 10 unless the two images are identical in the respective category. 60 | 61 | 1. **Element Reproduction (Score: 0-10):** Are key elements (text, images, buttons) fully present and styled identically to the original design? (e.g., 10 for identical elements, 5 for missing or slightly altered elements, 0 for completely different elements.) 62 | 2. **Proportion and Size Consistency (Score: 0-10):** Do the sizes and proportions of elements (text, images, buttons) match the original design, maintaining visual harmony? (e.g., 10 for exact proportions, 6 for minor size differences, 0 for significant discrepancies.) 63 | 3. **Layout and Typography Fidelity (Score: 0-10):** Does the overall layout (headers, footers, navigation bars, sidebars) faithfully replicate the original design's typography and structure? (e.g., 10 for identical layouts, 5 for similar but not exact placements, 0 for entirely different layouts.) 64 | 4. **Alignment and Spacing Accuracy (Score: 0-10):** Are elements aligned and spaced (margins, padding) as in the original design? (e.g., 10 for perfect alignment and spacing, 6 for minor misalignments, 0 for major misalignments.) 65 | 5. **Visual Hierarchy Clarity (Score: 0-10):** Does the webpage maintain the same visual hierarchy as the original, allowing users to quickly identify key information? (e.g., 10 for identical hierarchy, 5 for slightly altered emphasis, 0 for unclear or different hierarchy.) 66 | 6. **Color Consistency (Score: 0-10):** Do the overall color scheme, hues, and tones match the original design? (e.g., 10 for identical colors, 6 for similar palette with minor variations, 0 for completely different colors.) 67 | 7. 
**Style Consistency (Score: 0-10):** Does the webpage's overall aesthetic style (e.g., modern, minimalistic) align with the original design? (e.g., 10 for identical style, 4 for similar but distinguishable style, 0 for entirely different style.) 68 | 8. **Text Style Consistency (Score: 0-10):** Are text attributes (font type, size, line spacing, paragraph spacing, alignment) consistent with the original design? (e.g., 10 for identical text styles, 5 for similar fonts with spacing issues, 0 for completely different text styles.) 69 | 9. **Text Content Accuracy (Score: 0-10):** Does the webpage accurately reproduce the main textual content of the original design? (e.g., 10 for identical text, 5 for partial matches, 0 for entirely different text.) 70 | 10. **Overall Content Representation (Score: 0-10):** Does the webpage convey the same content information and intent as the original design? (e.g., 10 for identical content representation, 6 for similar but incomplete content, 0 for entirely different content.) 
def write_log(message):
    """Append ``message`` to the log file, prefixed with the current local time."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as log_handle:
        log_handle.write(f"[{timestamp}] {message}\n")


def update_category_summary(model, category, model_data):
    """Recompute the per-dimension average scores of one category for one model.

    The averages (rounded to 2 decimals, over the 10 score dimensions) are
    stored in the item whose id is ``<category>_summary``. That item is
    updated in place when it exists and appended otherwise. Nothing happens
    when the category has no scored items yet.

    Args:
        model: Key into ``model_data``.
        category: Dataset category whose items are averaged.
        model_data: Mapping ``model -> {"items": [...]}``; mutated in place.
    """
    entries = model_data[model]["items"]

    # Scores of the real items of this category (summary rows excluded).
    score_rows = [
        entry["scores"]
        for entry in entries
        if entry.get("Category") == category and not entry.get("id", "").endswith("_summary")
    ]
    if not score_rows:
        return

    # One average per score dimension.
    avg_scores = [
        round(sum(row[dim] for row in score_rows) / len(score_rows), 2)
        for dim in range(10)
    ]

    summary_id = f"{category}_summary"
    for entry in entries:
        if entry.get("id") == summary_id:
            entry["scores"] = avg_scores
            break
    else:
        # No summary row yet for this category: create it.
        entries.append({
            "id": summary_id,
            "Category": category,
            "scores": avg_scores
        })
model, dataset) 133 | dataset_path2 = os.path.join(source_folder2, dataset) 134 | 135 | if not os.path.exists(dataset_path1) or not os.path.exists(dataset_path2): 136 | message = f"Path does not exist: {dataset_path1} or {dataset_path2}" 137 | print(message) 138 | write_log(message) 139 | continue 140 | 141 | images1 = [] 142 | for ext in ['*.png', '*.jpg', '*.jpeg']: 143 | images1.extend(glob.glob(os.path.join(dataset_path1, ext))) 144 | 145 | for img1_path in images1: 146 | img_name = os.path.basename(img1_path) 147 | img2_path = os.path.join(dataset_path2, img_name) 148 | 149 | if os.path.exists(img2_path): 150 | img_id = f"{dataset}_{Path(img_name).stem}" 151 | already_processed = False 152 | 153 | for item in model_data[model]["items"]: 154 | if item.get("id") == img_id: 155 | already_processed = True 156 | break 157 | 158 | if already_processed: 159 | print(f"Skipping already processed image: {model}/{dataset}/{img_name}") 160 | continue 161 | 162 | total_pairs += 1 163 | 164 | with open(img1_path, 'rb') as f1, open(img2_path, 'rb') as f2: 165 | img1_bytes = f1.read() 166 | img2_bytes = f2.read() 167 | 168 | success = False 169 | retry_count = 0 170 | max_retries = 2 171 | error_messages = [] 172 | 173 | while not success and retry_count <= max_retries: 174 | try: 175 | if retry_count > 0: 176 | print(f"Retry #{retry_count} processing image: {model}/{dataset}/{img_name}") 177 | 178 | response = client.models.generate_content( 179 | model="gemini-2.5-flash-preview-04-17", 180 | contents=[ 181 | prompt, 182 | types.Part.from_bytes(data=img1_bytes, mime_type='image/png'), 183 | types.Part.from_bytes(data=img2_bytes, mime_type='image/png') 184 | ] 185 | ) 186 | 187 | response_text = response.text 188 | scores_match = re.search(r'\[(.*?)\]', response_text) 189 | if scores_match: 190 | scores_text = scores_match.group(1) 191 | scores = [float(s.strip()) for s in scores_text.split(',')] 192 | if len(scores) == 10: 193 | item_data = { 194 | "id": img_id, 195 | 
"Category": dataset, 196 | "scores": scores 197 | } 198 | model_data[model]["items"].append(item_data) 199 | 200 | update_category_summary(model, dataset, model_data) 201 | 202 | processed_pairs += 1 203 | model_processed_counts[model] += 1 204 | print(f"Processed image pair {processed_pairs}/{total_pairs}: {model}/{dataset}/{img_name}") 205 | 206 | if model_processed_counts[model] % 1 == 0: 207 | json_file = os.path.join(json_save_path, f"{model}.json") 208 | with open(json_file, 'w', encoding='utf-8') as f: 209 | json.dump(model_data[model], f, indent=2, ensure_ascii=False) 210 | print(f"Saved {model}.json (processed {model_processed_counts[model]} images)") 211 | success = True 212 | else: 213 | error_msg = f"Did not get 10 scores: {scores}" 214 | error_messages.append(error_msg) 215 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}") 216 | else: 217 | error_msg = f"Could not extract score list from response: {response_text}" 218 | error_messages.append(error_msg) 219 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}") 220 | 221 | except Exception as e: 222 | error_msg = f"Error processing image: {str(e)}" 223 | error_messages.append(error_msg) 224 | print(f"Warning: {model}/{dataset}/{img_name} {error_msg}") 225 | 226 | error_str = str(e) 227 | if "RESOURCE_EXHAUSTED" in error_str and "exceeded your current quota" in error_str: 228 | switch_to_next_api_key() 229 | continue 230 | 231 | retry_count += 1 232 | if not success and retry_count <= max_retries: 233 | time.sleep(2) 234 | 235 | if not success: 236 | error_log = f"Failed to process image (after {max_retries+1} attempts): {model}/{dataset}/{img_name}\n" 237 | for i, msg in enumerate(error_messages): 238 | error_log += f" Attempt {i+1} error: {msg}\n" 239 | write_log(error_log) 240 | 241 | except Exception as e: 242 | message = f"Error occurred, stopping script: {str(e)}" 243 | print(message) 244 | write_log(message) 245 | for model in models: 246 | if model in model_data: 247 | for 
def load_existing_results(output_path):
    """Return the results previously saved at *output_path*.

    The file is parsed as UTF-8 JSON and the parsed value is returned as-is.
    An empty list is returned when the file does not exist or cannot be
    read/parsed; in the latter case the failure is printed.
    """
    if not os.path.exists(output_path):
        return []

    try:
        with open(output_path, 'r', encoding='utf-8') as handle:
            previous = json.load(handle)
    except Exception as err:
        print(f"Failed to load existing results file {output_path}: {err}")
        return []
    return previous
def extract_html_code(text):
    """Pull the HTML document out of a model response.

    Lookup order (both case-insensitive, spanning newlines):

    1. the first ``<!DOCTYPE html> ... </html>`` span;
    2. otherwise the first ``<html ... </html>`` span;
    3. otherwise *text* itself, unchanged.

    The previous fallback searched for the empty string, which is found at
    index 0 of every string, so any response without a DOCTYPE was cut down
    to ``""`` and step 3 could never be reached.
    """
    flags = re.DOTALL | re.IGNORECASE

    doctype_match = re.search(r'<!DOCTYPE\s+html>.*?</html>', text, flags)
    if doctype_match:
        return doctype_match.group(0)

    # Matching on the original text (rather than on text.lower()) keeps the
    # slice indices valid even where lowercasing would change string length.
    html_match = re.search(r'<html.*?</html>', text, flags)
    if html_match:
        return html_match.group(0)

    return text
def process_image_to_code(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None):
    """Generate HTML for every screenshot in *data_list* with the loaded model.

    Items whose ``Id`` already appears in *existing_results* are skipped so an
    interrupted run can be resumed. Results are written to *output_path* after
    every fifth newly processed item, after the last item of *data_list*,
    after any per-item failure, and once more when the loop has finished.

    Args:
        data_list: Records providing ``Id``, ``Image`` (base64-encoded image
            bytes) and ``Prompt``; ``Label_html``, ``Category`` and ``Png_id``
            are copied into the output when present.
        llm: Object exposing ``generate(inputs, sampling_params=...)``.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Generation limit passed to ``SamplingParams``.
        existing_results: Previously saved result dicts, or ``None``.
        output_path: JSON file to save to; nothing is saved when falsy.

    Returns:
        The list of result dicts (existing ones followed by new ones). An item
        that failed is recorded with ``"Processing error: ..."`` as its
        ``Response`` instead of aborting the run.
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    results = []
    processed_ids = set()
    if existing_results:
        # Compare against None so a legitimate Id of 0 is still recognised as
        # processed; a plain truthiness test re-ran such items on every resume.
        processed_ids = {item.get('Id') for item in existing_results if item.get('Id') is not None}
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        item_id = item.get('Id')

        if item_id in processed_ids:
            continue

        newly_processed_count += 1

        try:
            image = Image.open(BytesIO(base64.b64decode(item['Image'])))
            prompt = item['Prompt']

            # The image passed in multi_modal_data needs a matching "<image>"
            # placeholder in the prompt text; the previous prompt had none.
            internvl_prompt = f"USER: <image>\n{prompt}\nASSISTANT:"

            inputs = {
                "prompt": internvl_prompt,
                "multi_modal_data": {
                    "image": image
                },
            }

            outputs = llm.generate([inputs], sampling_params=sampling_params)
            original_response = outputs[0].outputs[0].text.strip()

            html_response = extract_html_code(original_response)

            if html_response != original_response:
                print(f"ID {item_id}: Successfully extracted HTML code")

            results.append({
                "Id": item_id,
                "Response": html_response,
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            # Checkpoint every five new items and on the final record.
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} items (new: {newly_processed_count}), saved results to {output_path}.")

        except Exception as e:
            # Best-effort batch job: record the failure and keep going.
            print(f"Error processing Id {item_id}: {e}")
            results.append({
                "Id": item_id,
                "Response": f"Processing error: {e}",
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Id {item_id}, saved current results to {output_path}.")

    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"File processing complete, final results saved to: {output_path}")

    return results
except Exception as e: 289 | print(f"Error processing Id {item_id}: {e}") 290 | result = { 291 | "Id": item_id, 292 | "Response": f"Processing error: {e}", 293 | "Label_html": item.get('Label_html'), 294 | "Category": item.get('Category'), 295 | "Png_id": item.get('Png_id') 296 | } 297 | results.append(result) 298 | 299 | if output_path: 300 | save_results_to_json(results, output_path) 301 | print(f"Error processing Id {item_id}, saved current results to {output_path}.") 302 | 303 | if newly_processed_count > 0 and output_path: 304 | save_results_to_json(results, output_path) 305 | print(f"File processing complete, final results saved to: {output_path}") 306 | 307 | return results 308 | 309 | def process_interaction_to_code(data_list, llm, temperature, max_tokens, existing_results=None, output_path=None): 310 | results = [] 311 | sampling_params = SamplingParams( 312 | temperature=temperature, 313 | max_tokens=max_tokens, 314 | ) 315 | 316 | processed_ids = set() 317 | if existing_results: 318 | processed_ids = {item.get('Id') for item in existing_results if item.get('Id')} 319 | results = existing_results.copy() 320 | print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.") 321 | 322 | total_items = len(data_list) 323 | newly_processed_count = 0 324 | 325 | for i, item in enumerate(data_list): 326 | item_id = item.get('Id') 327 | 328 | if item_id in processed_ids: 329 | continue 330 | 331 | newly_processed_count += 1 332 | 333 | try: 334 | before_image = decode_base64_image(item['Before_image']) 335 | after_image = decode_base64_image(item['After_image']) 336 | 337 | if before_image is None or after_image is None: 338 | raise ValueError("Failed to decode images") 339 | 340 | prompt = item['Prompt'] 341 | 342 | internvl_prompt = f"USER: \n\n{prompt}\nASSISTANT:" 343 | 344 | inputs = { 345 | "prompt": internvl_prompt, 346 | "multi_modal_data": { 347 | "image": [before_image, after_image] 348 | }, 349 | } 350 | 351 | 
def main(data_folder="./data", model_path="./models/InternVL2_5", output_base_dir="./results",
         categories=None, max_tokens=256, temperature=0, tensor_parallel_size=8):
    """Run the selected code-generation tasks with one shared vLLM model.

    For each category, ``<data_folder>/<category>.parquet`` is loaded, handed
    to the matching ``process_*`` function, and the results are written to
    ``<output_base_dir>/internvl2_5_<category>.json``. Existing result files
    are loaded first so already processed items are skipped.

    Args:
        data_folder: Directory containing one parquet file per category.
        model_path: Model path passed to ``LLM``.
        output_base_dir: Directory for result files; created if missing.
        categories: Category names to run; ``None`` runs every handler key.
        max_tokens: Generation limit forwarded to the processing functions.
        temperature: Sampling temperature forwarded likewise.
        tensor_parallel_size: Forwarded to ``LLM``.

    Returns:
        ``None``. Problems (unknown category, missing file, model load
        failure) are printed rather than raised.
    """
    # Ensure the base output directory exists
    os.makedirs(output_base_dir, exist_ok=True)

    # All supported task types and their corresponding processing functions
    task_handlers = {
        "Text_to_code": process_text_to_code,
        "Image_to_code": process_image_to_code,
        "Code Refinement": process_refinement_to_code,
        "Interaction_Authoring": process_interaction_to_code
    }

    # The refinement task is spelled "Code Refinement" in the table above but
    # "Code_Refinemnet" in this script's __main__ category list, so it used to
    # be rejected as unsupported. Aliases select the handler; the category
    # string itself is kept untouched because it also names the files.
    category_aliases = {
        "Code_Refinement": "Code Refinement",
        "Code_Refinemnet": "Code Refinement",
    }

    def resolve_handler(category):
        return task_handlers.get(category_aliases.get(category, category))

    # If no categories are specified, default to processing all supported categories
    if categories is None:
        categories = list(task_handlers.keys())

    # Validate that all specified categories are supported
    unsupported_categories = [cat for cat in categories if resolve_handler(cat) is None]
    if unsupported_categories:
        print(f"Warning: The following categories are not supported: {unsupported_categories}")
        categories = [cat for cat in categories if resolve_handler(cat) is not None]
    if not categories:
        print("No valid categories to process.")
        return

    print(f"Will process the following categories: {categories}")

    # --- Model Loading ---
    print("Loading model...")
    try:
        llm_config = {
            "model": model_path,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": 32768,
            "enforce_eager": True
        }

        # Decide the image budget from the handlers in use, not from category
        # spellings: only the interaction task sends two images per prompt.
        # The old name-based check applied this limit for every spelling it
        # recognised, so it is now simply always set.
        handlers_in_use = {resolve_handler(cat) for cat in categories}
        max_images = 2 if process_interaction_to_code in handlers_in_use else 1
        llm_config["limit_mm_per_prompt"] = {"image": max_images, "video": 0}

        llm = LLM(**llm_config)
        print("Model loading complete.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # --- Process each selected category ---
    for category in categories:
        print("=" * 50)
        print(f"Starting to process category: {category}")

        # Find the corresponding parquet file for the category
        target_file = os.path.join(data_folder, f"{category}.parquet")

        if not os.path.exists(target_file):
            print(f"Could not find file for category {category}: {target_file}")
            continue

        print(f"Found category file: {target_file}")

        # Generate output filename
        output_filename = f"internvl2_5_{category}.json"
        current_output_path = os.path.join(output_base_dir, output_filename)
        print(f"Results will be saved to: {current_output_path}")

        # Load existing results for the current file
        existing_results = load_existing_results(current_output_path)

        # Load data
        data_list = load_parquet_data(target_file)

        if not data_list:
            print(f"No data in file {target_file} or loading failed.")
            continue

        # Get the corresponding processing function
        process_func = resolve_handler(category)

        # Process the data
        process_func(
            data_list,
            llm,
            temperature,
            max_tokens,
            existing_results=existing_results,
            output_path=current_output_path
        )

        print(f"Processing for category {category} completed.")

    print("=" * 50)
    print("Processing workflow for all specified categories is complete.")
"Code_Refinemnet", "Interaction_Authoring"] 510 | 511 | max_tokens = 30000 # Maximum number of tokens to generate 512 | temperature = 0 # Temperature parameter 513 | tensor_parallel_size = 4 # Number of GPUs for parallelism 514 | 515 | # Call the main function 516 | main(data_folder, model_path, output_dir, categories, max_tokens, temperature, tensor_parallel_size) 517 | -------------------------------------------------------------------------------- /generate_response/o1_code.py: -------------------------------------------------------------------------------- 1 | """ 2 | O1 Model HTML Code Generation Tool 3 | 4 | Usage: 5 | 1. Set your OpenAI API key in the main() function 6 | 2. Prepare your data in the './mini_datasets' folder (parquet files) 7 | 3. Specify categories to process in the main() function 8 | 4. Run the script: python o1_code.py 9 | 10 | This script processes text and image inputs to generate HTML code using OpenAI's models. 11 | """ 12 | 13 | import os 14 | import json 15 | import base64 16 | import re 17 | import pandas as pd 18 | from openai import OpenAI 19 | import time 20 | from typing import List, Dict, Any, Optional 21 | 22 | # Utility functions 23 | def load_parquet_data(file_path): 24 | """Load data from a single Parquet file.""" 25 | try: 26 | df = pd.read_parquet(file_path) 27 | return df.to_dict('records') 28 | except Exception as e: 29 | print(f"Failed to load file {file_path}: {e}") 30 | return [] 31 | 32 | def load_existing_results(output_path): 33 | """Load existing results file to skip already processed data.""" 34 | if os.path.exists(output_path): 35 | try: 36 | with open(output_path, 'r', encoding='utf-8') as f: 37 | return json.load(f) 38 | except Exception as e: 39 | print(f"Failed to load existing results file {output_path}: {e}") 40 | return [] 41 | 42 | def save_results_to_json(results, output_path="output.json"): 43 | """Save results to a JSON file.""" 44 | try: 45 | with open(output_path, 'w', encoding='utf-8') as f: 46 | 
def generate_openai_response(client, model_name, messages, retries=2):
    """Request a chat completion and return the reply text, retrying on failure.

    Args:
        client: Object exposing ``chat.completions.create(model=..., messages=...)``.
        model_name: Model identifier passed as ``model``.
        messages: Chat messages passed as ``messages``.
        retries: Number of additional attempts after the first one. Values
            below zero are treated as zero, so the API is always called at
            least once (previously a negative value returned ``None``
            without making any call).

    Returns:
        ``response.choices[0].message.content`` of the first successful call.

    Raises:
        RuntimeError: When every attempt failed. The last underlying error is
            chained as ``__cause__`` so its traceback is not lost.
    """
    attempts = max(retries, 0) + 1
    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
            )
            return response.choices[0].message.content
        except Exception as e:
            if attempt + 1 < attempts:
                # Exponential backoff: 2s, 4s, 8s, ... (same 2s/4s waits as
                # before for the default of two retries).
                wait_time = 2 ** (attempt + 1)
                print(f"API call failed, retrying in {wait_time} seconds: {e}")
                time.sleep(wait_time)
            else:
                raise RuntimeError(f"API call failed after {attempts} attempts: {e}") from e
if existing_results: 98 | processed_ids = {item.get('Id') for item in existing_results if item.get('Id')} 99 | results = existing_results.copy() # Use existing results for this file 100 | print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.") 101 | 102 | total_items = len(data_list) 103 | newly_processed_count = 0 104 | 105 | for i, item in enumerate(data_list): 106 | item_id = item.get('Id') 107 | 108 | # Skip already processed questions 109 | if item_id in processed_ids: 110 | continue 111 | 112 | newly_processed_count += 1 # Count newly processed questions 113 | 114 | try: 115 | prompt = item['Prompt'] 116 | input_text = item['Input_text'] 117 | 118 | # Build prompt text 119 | prompt_text = f"{prompt}\nDescription:{input_text}" 120 | 121 | # Build API request messages 122 | messages = [ 123 | { 124 | "role": "user", 125 | "content": [ 126 | { 127 | "type": "text", 128 | "text": prompt_text 129 | } 130 | ] 131 | } 132 | ] 133 | 134 | # Generate reply 135 | original_response = generate_openai_response( 136 | client, 137 | model_name, 138 | messages=messages, 139 | ) 140 | 141 | # Extract HTML code 142 | html_response = extract_html_code(original_response) 143 | 144 | # If HTML code was successfully extracted, log it 145 | if html_response != original_response: 146 | print(f"ID {item_id}: Successfully extracted HTML code") 147 | 148 | result = { 149 | "Id": item_id, 150 | "Response": html_response, 151 | "Label_html": item.get('Label_html'), 152 | "Category": item.get('Category'), 153 | "Png_id": item.get('Png_id') 154 | } 155 | 156 | results.append(result) # Add new result to list 157 | 158 | # Save results after processing each new item or at the end 159 | if newly_processed_count > 0 and (newly_processed_count % 1 == 0 or i == total_items - 1): 160 | if output_path: 161 | save_results_to_json(results, output_path) 162 | print(f"Processed {i + 1}/{total_items} items (added {newly_processed_count} new), saved 
def process_image_to_code(data_list, client, model_name, existing_results=None, output_path=None):
    """Turn each screenshot record into HTML via the chat-completions API.

    Records whose ``Id`` is already present in *existing_results* are skipped.
    Every other record is sent as one user message (prompt text plus the
    base64 image as a data URL); the HTML extracted from the reply is stored
    together with the record's ``Label_html``, ``Category`` and ``Png_id``.
    A failing record is stored with ``"Processing error: ..."`` as its
    ``Response``. When *output_path* is set, results are saved after every
    record and once more at the end if anything new was processed.

    Returns the combined list of existing and new result dicts.
    """
    collected = []
    seen_ids = set()
    if existing_results:
        seen_ids = {entry.get('Id') for entry in existing_results if entry.get('Id')}
        collected = existing_results.copy()
        print(f"Loaded {len(seen_ids)} already processed items (from {output_path}), will skip these items.")

    item_total = len(data_list)
    fresh_count = 0

    for position, record in enumerate(data_list, start=1):
        record_id = record.get('Id')
        if record_id in seen_ids:
            continue
        fresh_count += 1

        try:
            encoded_image = record['Image']
            if not encoded_image:
                raise ValueError("Image data is empty")

            text_part = {"type": "text", "text": record['Prompt']}
            image_part = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
            }

            raw_reply = generate_openai_response(
                client,
                model_name,
                messages=[{"role": "user", "content": [text_part, image_part]}],
            )

            html_reply = extract_html_code(raw_reply)
            if html_reply != raw_reply:
                print(f"ID {record_id}: Successfully extracted HTML code")

            collected.append({
                "Id": record_id,
                "Response": html_reply,
                "Label_html": record.get('Label_html'),
                "Category": record.get('Category'),
                "Png_id": record.get('Png_id'),
            })

            # The file is rewritten after every successfully handled record.
            if output_path:
                save_results_to_json(collected, output_path)
                print(f"Processed {position}/{item_total} items (added {fresh_count} new), saved results to {output_path}.")

        except Exception as err:
            print(f"Error processing Id {record_id}: {err}")
            collected.append({
                "Id": record_id,
                "Response": f"Processing error: {err}",
                "Label_html": record.get('Label_html'),
                "Category": record.get('Category'),
                "Png_id": record.get('Png_id'),
            })

            if output_path:
                save_results_to_json(collected, output_path)
                print(f"Error processing Id {record_id}, saved current results to {output_path}.")

    if fresh_count > 0 and output_path:
        save_results_to_json(collected, output_path)
        print(f"Completed processing file, final results saved to: {output_path}")

    return collected
processed items (from {output_path}), will skip these items.") 288 | 289 | total_items = len(data_list) 290 | newly_processed_count = 0 291 | 292 | for i, item in enumerate(data_list): 293 | item_id = item.get('Id') 294 | 295 | if item_id in processed_ids: 296 | continue 297 | 298 | newly_processed_count += 1 299 | 300 | try: 301 | base64_image = item['Image'] 302 | if not base64_image: 303 | raise ValueError("Image data is empty") 304 | 305 | prompt = item['Prompt'] 306 | input_html = item['Input_html'] 307 | 308 | prompt_text = f"{prompt}\nCode:\n{input_html}" 309 | 310 | messages = [ 311 | { 312 | "role": "user", 313 | "content": [ 314 | { 315 | "type": "text", 316 | "text": prompt_text 317 | }, 318 | { 319 | "type": "image_url", 320 | "image_url": { 321 | "url": f"data:image/jpeg;base64,{base64_image}" 322 | } 323 | } 324 | ] 325 | } 326 | ] 327 | 328 | original_response = generate_openai_response( 329 | client, 330 | model_name, 331 | messages=messages, 332 | ) 333 | 334 | html_response = extract_html_code(original_response) 335 | 336 | if html_response != original_response: 337 | print(f"ID {item_id}: Successfully extracted HTML code") 338 | 339 | result = { 340 | "Id": item_id, 341 | "Response": html_response, 342 | "Label_html": item.get('Label_html'), 343 | "Category": item.get('Category'), 344 | "Png_id": item.get('Png_id') 345 | } 346 | 347 | results.append(result) 348 | 349 | if newly_processed_count > 0 and (newly_processed_count % 1 == 0 or i == total_items - 1): 350 | if output_path: 351 | save_results_to_json(results, output_path) 352 | print(f"Processed {i + 1}/{total_items} items (added {newly_processed_count} new), saved results to {output_path}.") 353 | 354 | except Exception as e: 355 | print(f"Error processing Id {item_id}: {e}") 356 | result = { 357 | "Id": item_id, 358 | "Response": f"Processing error: {e}", 359 | "Label_html": item.get('Label_html'), 360 | "Category": item.get('Category'), 361 | "Png_id": item.get('Png_id') 362 | } 363 | 
results.append(result) 364 | 365 | if output_path: 366 | save_results_to_json(results, output_path) 367 | print(f"Error processing Id {item_id}, saved current results to {output_path}.") 368 | 369 | if newly_processed_count > 0 and output_path: 370 | save_results_to_json(results, output_path) 371 | print(f"Completed processing file, final results saved to: {output_path}") 372 | 373 | return results 374 | 375 | def process_interaction_to_code(data_list, client, model_name, existing_results=None, output_path=None): 376 | """Process interaction-to-code conversion data.""" 377 | results = [] 378 | 379 | processed_ids = set() 380 | if existing_results: 381 | processed_ids = {item.get('Id') for item in existing_results if item.get('Id')} 382 | results = existing_results.copy() 383 | print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.") 384 | 385 | total_items = len(data_list) 386 | newly_processed_count = 0 387 | 388 | for i, item in enumerate(data_list): 389 | item_id = item.get('Id') 390 | 391 | if item_id in processed_ids: 392 | continue 393 | 394 | newly_processed_count += 1 395 | 396 | try: 397 | before_image = item['Before_image'] 398 | after_image = item['After_image'] 399 | 400 | if not before_image or not after_image: 401 | raise ValueError("Image data is empty") 402 | 403 | prompt = item['Prompt'] 404 | 405 | messages = [ 406 | { 407 | "role": "user", 408 | "content": [ 409 | { 410 | "type": "text", 411 | "text": prompt 412 | }, 413 | { 414 | "type": "image_url", 415 | "image_url": { 416 | "url": f"data:image/jpeg;base64,{before_image}" 417 | } 418 | }, 419 | { 420 | "type": "image_url", 421 | "image_url": { 422 | "url": f"data:image/jpeg;base64,{after_image}" 423 | } 424 | } 425 | ] 426 | } 427 | ] 428 | 429 | original_response = generate_openai_response( 430 | client, 431 | model_name, 432 | messages=messages, 433 | ) 434 | 435 | html_response = extract_html_code(original_response) 436 | 437 | if 
def main(data_folder, api_key, model_name, output_base_dir, categories=None):
    """Run the selected task categories against the OpenAI API.

    Each category name maps to a processing function and to a parquet file
    ``<data_folder>/<category>.parquet``. Results for a category are written
    to ``<output_base_dir>/o1_<category>.json``; records already present in
    that file are handed to the processing function as ``existing_results``.

    Args:
        data_folder: Directory containing the per-category parquet files.
        api_key: Key used to construct the ``OpenAI`` client.
        model_name: Model identifier forwarded to the processing functions.
        output_base_dir: Directory for result files (created if missing).
        categories: Category names to process; ``None`` means every
            supported category. Unsupported names are reported and dropped.

    Returns:
        None. Returns early when no valid category remains or when the client
        cannot be created.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    # Category name -> function that processes that category's records.
    task_handlers = {
        "Text_to_code_mini": process_text_to_code,
        "Image_to_code_mini": process_image_to_code,
        "Code_Refinement_mini": process_refinement_to_code,
        "Interaction_Authoring_mini": process_interaction_to_code
    }

    if categories is None:
        categories = [*task_handlers]

    unsupported_categories = [name for name in categories if name not in task_handlers]
    if unsupported_categories:
        print(f"Warning: The following categories are not supported: {unsupported_categories}")
        categories = [name for name in categories if name in task_handlers]
        if len(categories) == 0:
            print("No valid categories to process.")
            return

    print(f"Will process the following categories: {categories}")

    print("Initializing OpenAI API client...")
    try:
        client = OpenAI(api_key=api_key)
    except Exception as init_error:
        print(f"Error initializing OpenAI API client: {init_error}")
        return
    else:
        print("OpenAI API client initialized successfully.")

    separator = "=" * 50

    for category in categories:
        print(separator)
        print(f"Starting to process category: {category}")

        parquet_path = os.path.join(data_folder, f"{category}.parquet")
        if not os.path.exists(parquet_path):
            print(f"Could not find file for category {category}: {parquet_path}")
            continue
        print(f"Found category file: {parquet_path}")

        results_path = os.path.join(output_base_dir, f"o1_{category}.json")
        print(f"Results will be saved to: {results_path}")

        # Earlier results are loaded first so the handler can skip them.
        previous_results = load_existing_results(results_path)

        records = load_parquet_data(parquet_path)
        if not records:
            print(f"No data in file {parquet_path} or failed to load.")
            continue

        handler = task_handlers[category]
        handler(
            records,
            client,
            model_name,
            existing_results=previous_results,
            output_path=results_path
        )

        print(f"Category {category} processing complete.")

    print(separator)
    print("All specified category processing workflows completed.")
def save_results_to_json(results, output_path="output.json"):
    """Save results to a JSON file without ever corrupting an existing file.

    The data is first serialized to ``<output_path>.tmp`` and only moved over
    ``output_path`` once serialization succeeded. Opening ``output_path``
    directly in ``'w'`` mode would truncate it immediately, so a failure
    inside ``json.dump`` (for example a non-serializable value) would destroy
    the previously saved checkpoint that resume logic depends on.

    Args:
        results: JSON-serializable object (a list of result dicts in this file).
        output_path: Destination file path.

    Returns:
        None. Failures are reported on stdout rather than raised, matching the
        best-effort behaviour callers rely on.
    """
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        # os.replace swaps the file in one step, so readers see either the
        # old or the new content, never a partial write.
        os.replace(tmp_path, output_path)
        print(f"Results saved to: {output_path}")
    except Exception as e:
        print(f"Failed to save results to {output_path}: {e}")
        # Best-effort cleanup of the partial temp file; the destination file
        # has not been touched at this point.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {tmp_path}: {cleanup_error}")
def generate_gemini_response(client, model_name, contents, temperature, max_tokens, retries=2):
    """Call the Gemini API and return the stripped response text, with retries.

    Args:
        client: Object exposing ``models.generate_content`` (a ``genai.Client``
            in this file).
        model_name: Model identifier passed as ``model``.
        contents: Prompt contents passed through unchanged.
        temperature: Sampling temperature for the generation config.
        max_tokens: Value used for ``max_output_tokens``.
        retries: Number of additional attempts after the first failure.

    Returns:
        The response text with surrounding whitespace removed.

    Raises:
        Exception: When every attempt failed; the last error is chained.
    """
    config_kwargs = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
    }
    # ``thinking_config`` is only assigned in the ``__main__`` block. Looking
    # it up defensively keeps this function usable when the module is
    # imported, where the bare name would raise NameError on every call.
    module_thinking_config = globals().get("thinking_config")
    if module_thinking_config is not None:
        config_kwargs["thinking_config"] = module_thinking_config

    for attempt in range(retries + 1):
        try:
            response = client.models.generate_content(
                model=model_name,
                contents=contents,
                config=types.GenerateContentConfig(**config_kwargs)
            )
            print(response)
            text = response.text
            if text is None:
                # Fail with a clear message instead of an AttributeError
                # from calling .strip() on None.
                raise ValueError("Gemini response contained no text")
            return text.strip()
        except Exception as e:
            if attempt < retries:
                wait_time = (attempt + 1) * 2  # Linear backoff: 2s, 4s, ...
                print(f"API call failed, retrying in {wait_time} seconds: {e}")
                time.sleep(wait_time)
            else:
                raise Exception(f"API call failed after {retries+1} attempts: {e}") from e
{len(processed_ids)} already processed items (from {output_path}), will skip these items.") 116 | 117 | total_items = len(data_list) 118 | newly_processed_count = 0 119 | 120 | for i, item in enumerate(data_list): 121 | item_id = item.get('Id') 122 | 123 | # Skip already processed items 124 | if item_id in processed_ids: 125 | continue 126 | 127 | newly_processed_count += 1 128 | 129 | try: 130 | prompt = item['Prompt'] 131 | input_text = item['Input_text'] 132 | 133 | # Build prompt text 134 | prompt_text = f"{prompt}\nDescription:{input_text}" 135 | 136 | # Generate response 137 | original_response = generate_gemini_response( 138 | client, 139 | model_name, 140 | contents=prompt_text, 141 | temperature=temperature, 142 | max_tokens=max_tokens 143 | ) 144 | 145 | # Extract HTML code 146 | html_response = extract_html_code(original_response) 147 | 148 | # Log if HTML code was successfully extracted 149 | if html_response != original_response: 150 | print(f"ID {item_id}: Successfully extracted HTML code") 151 | 152 | result = { 153 | "Id": item_id, 154 | "Response": html_response, 155 | "Label_html": item.get('Label_html'), 156 | "Category": item.get('Category'), 157 | "Png_id": item.get('Png_id') 158 | } 159 | 160 | results.append(result) 161 | 162 | # Save results every 5 new items or at the last item 163 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1): 164 | if output_path: 165 | save_results_to_json(results, output_path) 166 | print(f"Processed {i + 1}/{total_items} items (new: {newly_processed_count}), saved results to {output_path}.") 167 | 168 | except Exception as e: 169 | print(f"Error processing Id {item_id}: {e}") 170 | result = { 171 | "Id": item_id, 172 | "Response": f"Processing error: {e}", 173 | "Label_html": item.get('Label_html'), 174 | "Category": item.get('Category'), 175 | "Png_id": item.get('Png_id') 176 | } 177 | results.append(result) 178 | 179 | # Save results on error 180 | if output_path: 181 | 
def process_image_to_code(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
    """Process image-to-code conversion data.

    For every item whose ``Id`` is not already present in
    ``existing_results``, the base64 ``Image`` is decoded, sent with the
    item's ``Prompt`` through ``generate_gemini_response``, and the HTML
    extracted from the reply is recorded.

    Args:
        data_list: Iterable of dict-like records with at least ``Id``,
            ``Image`` and ``Prompt``.
        client: API client handed through to ``generate_gemini_response``.
        model_name: Model identifier handed through to the API call.
        temperature: Sampling temperature handed through to the API call.
        max_tokens: Output token limit handed through to the API call.
        existing_results: Previously saved result records; items whose ``Id``
            appears here are skipped.
        output_path: JSON file used for checkpointing. When falsy, nothing is
            written to disk.

    Returns:
        A new list holding the existing results followed by one record per
        newly processed item. Items that fail get a record whose ``Response``
        starts with ``"Processing error:"``.
    """
    results = []

    # Create a set of processed item IDs to skip
    processed_ids = set()
    if existing_results:
        # Compare against None instead of relying on truthiness: an Id of 0
        # (or "") is a legitimate identifier and must still count as processed,
        # otherwise that item is re-run and duplicated on every resume.
        processed_ids = {
            entry.get('Id')
            for entry in existing_results
            if entry.get('Id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        item_id = item.get('Id')

        # Skip already processed items
        if item_id in processed_ids:
            continue

        newly_processed_count += 1

        try:
            # Decode base64 image
            image_bytes = decode_base64_image(item['Image'])
            if image_bytes is None:
                raise ValueError("Image decoding failed")

            prompt = item['Prompt']

            # Build API request content: prompt text followed by the image
            contents = [
                prompt,
                types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
            ]

            # Generate response
            original_response = generate_gemini_response(
                client,
                model_name,
                contents=contents,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Extract HTML code
            html_response = extract_html_code(original_response)

            # Log if HTML code was successfully extracted
            if html_response != original_response:
                print(f"ID {item_id}: Successfully extracted HTML code")

            results.append({
                "Id": item_id,
                "Response": html_response,
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            # Save results every 5 new items or at the last item
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} items (new: {newly_processed_count}), saved results to {output_path}.")

        except Exception as e:
            # Record the failure instead of aborting the whole run.
            print(f"Error processing Id {item_id}: {e}")
            results.append({
                "Id": item_id,
                "Response": f"Processing error: {e}",
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            # Save results on error
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Id {item_id}, saved current results to {output_path}.")

    # Ensure final results are saved if any new items were processed
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"Completed processing file, final results saved to: {output_path}")

    return results
items.") 292 | 293 | total_items = len(data_list) 294 | newly_processed_count = 0 295 | 296 | for i, item in enumerate(data_list): 297 | item_id = item.get('Id') 298 | 299 | # Skip already processed items 300 | if item_id in processed_ids: 301 | continue 302 | 303 | newly_processed_count += 1 304 | 305 | try: 306 | # Decode base64 image 307 | image_bytes = decode_base64_image(item['Image']) 308 | if image_bytes is None: 309 | raise ValueError("Image decoding failed") 310 | 311 | prompt = item['Prompt'] 312 | input_html = item['Input_html'] 313 | 314 | # Build prompt text 315 | prompt_text = f"{prompt}\nCode:\n{input_html}" 316 | 317 | # Build API request content 318 | contents = [ 319 | prompt_text, 320 | types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg") 321 | ] 322 | 323 | # Generate response 324 | original_response = generate_gemini_response( 325 | client, 326 | model_name, 327 | contents=contents, 328 | temperature=temperature, 329 | max_tokens=max_tokens 330 | ) 331 | 332 | # Extract HTML code 333 | html_response = extract_html_code(original_response) 334 | 335 | # Log if HTML code was successfully extracted 336 | if html_response != original_response: 337 | print(f"ID {item_id}: Successfully extracted HTML code") 338 | 339 | result = { 340 | "Id": item_id, 341 | "Response": html_response, 342 | "Label_html": item.get('Label_html'), 343 | "Category": item.get('Category'), 344 | "Png_id": item.get('Png_id') 345 | } 346 | 347 | results.append(result) 348 | 349 | # Save results every 5 new items or at the last item 350 | if newly_processed_count > 0 and (newly_processed_count % 5 == 0 or i == total_items - 1): 351 | if output_path: 352 | save_results_to_json(results, output_path) 353 | print(f"Processed {i + 1}/{total_items} items (new: {newly_processed_count}), saved results to {output_path}.") 354 | 355 | except Exception as e: 356 | print(f"Error processing Id {item_id}: {e}") 357 | result = { 358 | "Id": item_id, 359 | "Response": f"Processing 
def process_interaction_to_code(data_list, client, model_name, temperature, max_tokens, existing_results=None, output_path=None):
    """Process interaction-to-code conversion data.

    For every item whose ``Id`` is not already present in
    ``existing_results``, the base64 ``Before_image`` and ``After_image`` are
    decoded, sent with the item's ``Prompt`` through
    ``generate_gemini_response``, and the HTML extracted from the reply is
    recorded together with the item's ``Interaction_type``.

    Args:
        data_list: Iterable of dict-like records with at least ``Id``,
            ``Before_image``, ``After_image`` and ``Prompt``.
        client: API client handed through to ``generate_gemini_response``.
        model_name: Model identifier handed through to the API call.
        temperature: Sampling temperature handed through to the API call.
        max_tokens: Output token limit handed through to the API call.
        existing_results: Previously saved result records; items whose ``Id``
            appears here are skipped.
        output_path: JSON file used for checkpointing. When falsy, nothing is
            written to disk.

    Returns:
        A new list holding the existing results followed by one record per
        newly processed item. Items that fail get a record whose ``Response``
        starts with ``"Processing error:"``.
    """
    results = []

    # Create a set of processed item IDs to skip
    processed_ids = set()
    if existing_results:
        # Compare against None instead of relying on truthiness: an Id of 0
        # (or "") is a legitimate identifier and must still count as processed,
        # otherwise that item is re-run and duplicated on every resume.
        processed_ids = {
            entry.get('Id')
            for entry in existing_results
            if entry.get('Id') is not None
        }
        results = existing_results.copy()
        print(f"Loaded {len(processed_ids)} already processed items (from {output_path}), will skip these items.")

    total_items = len(data_list)
    newly_processed_count = 0

    for i, item in enumerate(data_list):
        item_id = item.get('Id')

        # Skip already processed items
        if item_id in processed_ids:
            continue

        newly_processed_count += 1

        try:
            # Decode two base64 images
            before_image_bytes = decode_base64_image(item['Before_image'])
            after_image_bytes = decode_base64_image(item['After_image'])

            # Check if images were decoded correctly
            if before_image_bytes is None or after_image_bytes is None:
                raise ValueError("Failed to decode images")

            prompt = item['Prompt']

            # Build API request content - prompt, then the before and after images
            contents = [
                prompt,
                types.Part.from_bytes(data=before_image_bytes, mime_type="image/jpeg"),
                types.Part.from_bytes(data=after_image_bytes, mime_type="image/jpeg")
            ]

            # Generate response
            original_response = generate_gemini_response(
                client,
                model_name,
                contents=contents,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Extract HTML code
            html_response = extract_html_code(original_response)

            # Log if HTML code was successfully extracted
            if html_response != original_response:
                print(f"ID {item_id}: Successfully extracted HTML code")

            results.append({
                "Id": item_id,
                "Interaction_type": item.get('Interaction_type'),
                "Response": html_response,
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            # Save results every 5 new items or at the last item
            if newly_processed_count % 5 == 0 or i == total_items - 1:
                if output_path:
                    save_results_to_json(results, output_path)
                    print(f"Processed {i + 1}/{total_items} items (new: {newly_processed_count}), saved results to {output_path}.")

        except Exception as e:
            # Record the failure instead of aborting the whole run.
            print(f"Error processing Id {item_id}: {e}")
            results.append({
                "Id": item_id,
                "Interaction_type": item.get('Interaction_type'),
                "Response": f"Processing error: {e}",
                "Label_html": item.get('Label_html'),
                "Category": item.get('Category'),
                "Png_id": item.get('Png_id')
            })

            # Save results on error
            if output_path:
                save_results_to_json(results, output_path)
                print(f"Error processing Id {item_id}, saved current results to {output_path}.")

    # Ensure final results are saved if any new items were processed
    if newly_processed_count > 0 and output_path:
        save_results_to_json(results, output_path)
        print(f"Completed processing file, final results saved to: {output_path}")

    return results
# Script entry point: configure one run and hand the settings to main().
if __name__ == "__main__":
    data_folder = "./datasets"  # Folder holding the <category>.parquet files
    api_key = "your_api_key"  # Gemini API key (placeholder; replace before running)
    model_name = "gemini-2.5-pro-preview-05-06"  # Gemini model name
    output_dir = "./results/gemini"  # Directory to save results

    # Categories to process; each must be a key of main()'s task_handlers
    categories = ["Text_to_code", "Image_to_code", "Code_Refinement", "Interaction_Authoring"]

    max_tokens = 20000  # Maximum tokens to generate (overrides main()'s default of 8000)
    temperature = 0  # Temperature parameter
    # NOTE: thinking_config is not passed to main(). generate_gemini_response
    # reads this name as a module-level global, so it must be assigned here
    # before main() runs.
    # NOTE(review): thinking_budget=0 presumably disables the model's thinking
    # phase - confirm against the google-genai documentation.
    thinking_config = types.ThinkingConfig(thinking_budget=0)

    # Call main function
    main(data_folder, api_key, model_name, output_dir, categories, max_tokens, temperature)