├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── create_data.py
│   │   ├── images_dalle.py
│   │   ├── mrt_blender.py
│   │   ├── folding_pil.py
│   │   ├── create_prompts.py
│   │   ├── mrt.py
│   │   └── create_images.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── vlm_wrapper.py
│   ├── eval_openai.py
│   ├── eval
│   │   └── acc.py
│   └── eval.py
├── .gitignore
├── scripts
│   ├── run_clean_eval.sh
│   └── run.sh
├── LICENSE
├── requirements.txt
└── README.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .vlm_wrapper import VLMWrapper
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache and compiled files
2 | __pycache__/
3 | *.py[cod]
4 | 
5 | # Virtual environments
6 | venv/
7 | env/
8 | .venv/
9 | .env
10 | 
11 | # Distribution/build
12 | build/
13 | dist/
14 | *.egg-info/
15 | 
16 | # Jupyter notebooks
17 | .ipynb_checkpoints/
18 | 
19 | # Machine learning artifacts
20 | *.h5
21 | *.ckpt
22 | logs/
23 | models/
24 | 
25 | # Logging files
26 | *.log
27 | .neptune/
28 | wandb
29 | 
30 | # Hidden files
31 | .DS_Store
32 | *.swp
33 | ._*
34 | 
35 | # Images
36 | output/
37 | 
38 | # Files
39 | *.csv
40 | *.json
41 | *.yaml
42 | *.yml
43 | *.jsonl
44 | *.html
45 | 
46 | # Legacy files
47 | example/
48 | legacy/
49 | 
50 | Dockerfile
--------------------------------------------------------------------------------
/scripts/run_clean_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Simple bash script to run src/eval/acc.py
4 | # It aggregates accuracy from the CSV files produced by the evaluation scripts
5 | 
6 | # Glob pattern matching the CSV files to score
7 | INPUT_PATTERN="cot_test/*.csv"
8 | RESPONSE_COLUMN="response"
9 | CORRECT_COLUMN="gold answer"
10 | SPLIT_COLUMN="split"
11 | OUTPUT="cleaned_results/accuracy_results.csv"
12 | 
13 | echo "Running src/eval/acc.py..."
14 | echo "Input pattern: $INPUT_PATTERN"
15 | echo "Response column: $RESPONSE_COLUMN"
16 | echo "Output file: $OUTPUT"
17 | 
18 | # Ensure the output directory exists, then run the accuracy script
19 | mkdir -p "$(dirname "$OUTPUT")"
20 | python src/eval/acc.py --input-pattern "$INPUT_PATTERN" \
21 |     --response-col "$RESPONSE_COLUMN" \
22 |     --correct-col "$CORRECT_COLUMN" \
23 |     --split-col "$SPLIT_COLUMN" \
24 |     --output "$OUTPUT"
25 | 
26 | echo "Clean evaluation completed!"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Ilias M. 
Stogiannidis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/data/create_data.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset, Image, Features, Value 2 | import json 3 | from PIL import Image as PILImage 4 | import os 5 | import glob 6 | 7 | 8 | def generate_examples(): 9 | for data in annotations: 10 | image_path = os.path.join("Spatial-MM/data/spatial_mm", data["image_name"]) 11 | try: 12 | image = PILImage.open(image_path).convert("RGB") # Load image as RGB 13 | except Exception as e: 14 | print("Error loading image: ", image_path) 15 | continue 16 | yield { 17 | "image": image, 18 | "question": data["question"], 19 | "answer": data["answer"], 20 | } 21 | 22 | if __name__ == "__main__": 23 | # Load JSON annotations from multiple files 24 | annotations = [] 25 | json_files = glob.glob("Spatial-MM/data/*.json") 26 | for json_file in json_files: 27 | with open(json_file, "r") as f: 28 | annotations.extend(json.load(f)) 29 | 30 | # Define dataset features (adjust based on your JSON structure) 31 | features = Features( 32 | { 33 | "image": Image(), 34 | "question": Value("string"), 35 | "answer": Value("string"), 36 | } 37 | ) 38 | 39 | 40 | # Create the dataset 41 | dataset = Dataset.from_generator( 42 | generate_examples, 43 | features=features, 44 | ) 45 | 46 | dataset.push_to_hub("spatial_mm", private= True) # Push to the Hub 47 | 48 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | DATASET="stogian/srbench2" 2 | ONESHOT_PATH="/netdisk/users/stogian/srbench/example/oneshot.json" 3 | 4 | models=( 5 | "OpenGVLab/InternVL3_5-8B-HF" 6 | "OpenGVLab/InternVL3_5-14B-HF" 7 | "OpenGVLab/InternVL3_5-38B-HF" 8 | "OpenGVLab/InternVL3_5-30B-A3B-HF" 9 | "OpenGVLab/InternVL3_5-241B-A28B-HF" 10 | "google/gemma-3-12b-it" 11 | "google/gemma-3-27b-it" 12 | "Qwen/Qwen3-VL-8B-Thinking" 13 | "Qwen/Qwen3-VL-30B-A3B-Thinking" 14 | "Qwen/Qwen3-VL-235B-A22B-Thinking" 15 | "Qwen/Qwen3-VL-8B-Instruct" 16 | "Qwen/Qwen3-VL-30B-A3B-Instruct" 17 | "Qwen/Qwen3-VL-235B-A22B-Instruct" 18 | "meta-llama/Llama-3.2-11B-Vision-Instruct" 19 | "meta-llama/Llama-3.2-90B-Vision-Instruct" 20 | "HuggingFaceM4/Idefics3-8B-Llama3" 21 | "llava-hf/llava-1.5-7b-hf" 22 | "llava-hf/llava-v1.6-mistral-7b-hf" 23 | 
"openbmb/MiniCPM-V-2_6" 24 | "HuggingFaceTB/SmolVLM-500M-Instruct" 25 | "HuggingFaceTB/SmolVLM-Instruct" 26 | "moonshotai/Kimi-VL-A3B-Thinking-2506" 27 | "moonshotai/Kimi-VL-A3B-Instruct" 28 | "zai-org/GLM-4.1V-9B-Thinking" 29 | "meta-llama/Llama-4-Maverick-17B-128E-Instruct" 30 | "meta-llama/Llama-4-Scout-17B-128E-Instruct" 31 | ) 32 | 33 | for MODEL in "${models[@]}"; do 34 | echo "Running inference for model: $MODEL" 35 | 36 | # Determine optimal batch size based on model size 37 | if [[ $MODEL == *"11B"* ]]; then 38 | BATCH_SIZE=16 39 | elif [[ $MODEL == *"3B"* ]]; then 40 | BATCH_SIZE=32 41 | elif [[ $MODEL == *"8B"* ]]; then 42 | BATCH_SIZE=16 43 | else 44 | BATCH_SIZE=8 45 | fi 46 | 47 | python src/eval.py --model $MODEL \ 48 | --dataset $DATASET \ 49 | --batch_size $BATCH_SIZE \ 50 | --seed 123 \ 51 | --num_workers 8 52 | 53 | python src/eval.py --model $MODEL \ 54 | --dataset $DATASET \ 55 | --cot \ 56 | --batch_size $BATCH_SIZE \ 57 | --seed 123 \ 58 | --num_workers 8 59 | 60 | 61 | 62 | 63 | done 64 | 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.10.0 2 | aiohappyeyeballs==2.5.0 3 | aiohttp==3.12.14 4 | aiosignal==1.4.0 5 | av==14.1.0 6 | babel==2.16.0 7 | beautifulsoup4==4.12.3 8 | boto3==1.36.12 9 | botocore==1.36.12 10 | bravado==11.0.3 11 | bravado-core==6.1.1 12 | cachetools==5.5.1 13 | certifi==2024.12.14 14 | click==8.1.8 15 | contourpy==1.3.1 16 | cycler==0.12.1 17 | datasets==3.3.1 18 | diffusers==0.32.2 19 | dill==0.3.8 20 | distro==1.9.0 21 | docker-pycreds==0.4.0 22 | einops==0.8.0 23 | et-xmlfile==2.0.0 24 | filelock==3.17.0 25 | flash-attn==2.7.3 26 | fonttools==4.55.8 27 | frozenlist==1.5.0 28 | future==1.0.0 29 | google-api-core==2.24.1 30 | google-auth==2.38.0 31 | google-auth-oauthlib==1.2.1 32 | google-cloud-core==2.4.1 33 | google-cloud-storage==3.0.0 34 | google-crc32c==1.6.0 35 | google-resumable-media==2.7.2 36 | googleapis-common-protos==1.67.0 37 | hf-transfer==0.1.9 38 | huggingface-hub==0.34.2 39 | iniconfig==2.0.0 40 | inquirerpy==0.3.4 41 | jiter==0.8.2 42 | jmespath==1.0.1 43 | joblib==1.4.2 44 | jsonref==1.1.0 45 | jupyter-events==0.11.0 46 | jupyterlab==4.3.4 47 | kiwisolver==1.4.8 48 | matplotlib==3.10.0 49 | monotonic==1.6 50 | mpmath==1.3.0 51 | multidict==6.1.0 52 | multiprocess==0.70.16 53 | neptune==1.13.0 54 | networkx==3.4.2 55 | numpy==2.2.2 56 | nvitop==1.4.2 57 | oauthlib==3.2.2 58 | openai==1.63.2 59 | openpyxl==3.1.5 60 | pandas==2.2.3 61 | pandocfilters==1.5.1 62 | pfzy==0.3.4 63 | pillow==11.1.0 64 | pluggy==1.5.0 65 | propcache==0.2.1 66 | proto-plus==1.26.0 67 | protobuf==5.29.5 68 | pyarrow==19.0.0 69 | pyasn1==0.6.1 70 | pyasn1-modules==0.4.1 71 | pydantic==2.10.6 72 | pydantic-core==2.27.2 73 | pyjwt==2.10.1 74 | pyparsing==3.2.1 75 | pytest==8.3.4 76 | python-dotenv==1.0.1 77 | python-json-logger==3.2.1 78 | pytz==2024.2 79 | qwen-vl-utils==0.0.10 80 | regex==2024.11.6 81 | requests-oauthlib==2.0.0 82 | rsa==4.9 83 | s3transfer==0.11.2 84 | safetensors==0.5.2 85 | scikit-learn==1.6.1 86 | scipy==1.15.1 87 | seaborn==0.13.2 88 | sentencepiece==0.2.0 89 | sentry-sdk==2.20.0 90 | setproctitle==1.3.4 91 | setuptools==75.8.0 92 | simplejson==3.19.3 93 | soupsieve==2.6 94 | swagger-spec-validator==3.0.4 95 | sympy==1.13.1 96 | threadpoolctl==3.5.0 97 | timm==1.0.14 98 | tokenizers==0.21.0 99 | torchvision==0.21.0 100 | tqdm==4.67.1 101 | transformers==4.55.0 102 | triton==3.2.0 
103 | tzdata==2025.1 104 | xxhash==3.5.0 105 | yarl==1.18.3 106 | hf-xet==1.1.5 107 | tiktoken==0.9.0 108 | num2words==0.5.14 109 | blobfile==3.0.0 -------------------------------------------------------------------------------- /src/data/images_dalle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate images using DALL-E 3. 3 | 4 | Note: DALL-E 3 requires version 1.0.0 of the openai-python library or later. 5 | """ 6 | 7 | import os 8 | import argparse 9 | import json 10 | import requests 11 | from tqdm import tqdm 12 | from openai import AzureOpenAI 13 | from dotenv import load_dotenv 14 | 15 | load_dotenv(override=True) 16 | 17 | OUTPUT_DIR = "output/images/dalle3" 18 | os.makedirs(OUTPUT_DIR, exist_ok=True) 19 | 20 | 21 | def parse_arguments(): 22 | """Parse command-line arguments.""" 23 | parser = argparse.ArgumentParser( 24 | description="Generate images using DALL-E 3." 25 | ) 26 | parser.add_argument( 27 | "-f", 28 | "--metadata_file", 29 | type=str, 30 | required=True, 31 | help="Path to the JSON file containing the metadata for image generation.", 32 | ) 33 | return parser.parse_args() 34 | 35 | 36 | def load_metadata(metadata_file: str): 37 | """Load metadata from a JSON Lines file.""" 38 | with open(metadata_file, "r") as file: 39 | lines = file.readlines() 40 | return [json.loads(line.strip()) for line in lines if line.strip()] 41 | 42 | 43 | def main(): 44 | """Main function for generating images.""" 45 | args = parse_arguments() 46 | metadata = load_metadata(args.metadata_file) 47 | 48 | client = AzureOpenAI( 49 | api_version="2024-02-01", 50 | azure_endpoint=os.environ["DALLE_ENDPOINT"], 51 | api_key=os.environ["AZURE_OPENAI_API_KEY"], 52 | ) 53 | 54 | for idx, item in enumerate( 55 | tqdm(metadata, desc="Generating images", unit="prompt", colour="#000080") 56 | ): 57 | prompt = item["generated_scene_description"] 58 | print(f"Generating image for prompt: {prompt}") 59 | output_filename = f"image_{idx}.png" 60 | result = client.images.generate( 61 | model="dall-e-3", # the name of your DALL-E 3 deployment 62 | prompt="A photo-realistic image of " + prompt, 63 | n=1, 64 | ) 65 | image_data = json.loads(result.model_dump_json())["data"][0] 66 | image_url = image_data["url"] 67 | 68 | # Placeholder for image download functionality. 69 | print(f"Generated image '{output_filename}' from URL: {image_url}") 70 | 71 | # Save the image to the output directory 72 | image_path = os.path.join(OUTPUT_DIR, output_filename) 73 | with open(image_path, "wb") as file: 74 | file.write(requests.get(image_url).content) 75 | print(f"Saved image to '{image_path}'") 76 | 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SRBench 2 | 3 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) 4 | [![Python](https://img.shields.io/badge/Python-3.12-blue.svg)](https://www.python.org/downloads/release/python-3120/) 5 | [![Hugging Face](https://img.shields.io/badge/Hugging%20Face-stogian%2Fsrbench-blue.svg)](https://huggingface.co/datasets/stogian/srbench) 6 | [![ArXiv](https://img.shields.io/badge/ArXiv-2503.19707-brown.svg)](https://arxiv.org/abs/2503.19707) 7 | 8 | Welcome to our project! 
This repository contains the source code and documentation required to build the benchmark data and run the evaluations. Below is an overview of the repository structure, installation and usage instructions, and contribution guidelines.

## Overview

The repository is organized around three stages of the benchmark:
- **Data generation**: scripts under `src/data/` that create the benchmark items (mental-rotation scenes, paper-folding puzzles, and DALL-E 3 scene prompts and images).
- **Evaluation**: scripts that query vision-language models on the benchmark, with `src/eval.py` for open-weight models and `src/eval_openai.py` for Azure OpenAI models.
- **Scoring**: answer extraction and accuracy aggregation in `src/eval/acc.py`.

## Repository Structure

The project is organized as follows:
```
SRBench/
├── bin/                       # Storage for model binaries
├── scripts/                   # Bash scripts for running the project
│   ├── run.sh                 # Main evaluation loop over models
│   └── run_clean_eval.sh      # Accuracy aggregation wrapper
├── src/                       # Source code of the project
│   ├── data/                  # Scripts for data creation
│   │   ├── create_data.py     # Build a Hugging Face dataset from annotations
│   │   ├── create_images.py   # Script to create images
│   │   ├── create_prompts.py  # Script to generate prompts
│   │   ├── images_dalle.py    # Generate images with DALL-E 3
│   │   ├── mrt.py             # Mental-rotation shape definitions
│   │   ├── mrt_blender.py     # Render mental-rotation items in Blender
│   │   └── folding_pil.py     # Generate paper-folding items with PIL
│   ├── utils/                 # Utility functions
│   │   └── vlm_wrapper.py     # Wrapper around the VLMs
│   ├── eval/
│   │   └── acc.py             # Answer extraction and accuracy computation
│   ├── eval.py                # Evaluation script for open-weight VLMs
│   └── eval_openai.py         # Evaluation script for OpenAI models
├── .gitignore                 # Files and directories to ignore
├── requirements.txt           # Required packages
├── LICENSE                    # MIT License file
└── README.md                  # Project documentation
```

## Installation

1. Clone the repository:
   ```bash
   git clone https://github.com/stogiannidis/srbench.git
   cd srbench
   ```
2. Create a virtual environment:
   ```bash
   python3 -m venv venv
   source venv/bin/activate
   ```
   or using `conda`:
   ```bash
   conda create -n srbench python=3.12
   conda activate srbench
   ```
3. Install the required packages:
   ```bash
   pip install -r requirements.txt
   ```

## Usage

To run the project, follow these steps:
1. Fetch the dataset from `Hugging Face`:
   ```bash
   huggingface-cli login
   huggingface-cli download stogiannidis/srbench
   ```
2. Run the main evaluation script:
   ```bash
   bash scripts/run.sh
   ```
3. Aggregate accuracies from the generated CSVs with `scripts/run_clean_eval.sh` (a thin wrapper around `src/eval/acc.py`), or adapt the scoring sketch at the end of this README.

## Citation
```
@misc{stogiannidis2025mindgapbenchmarkingspatial,
    title={Mind the Gap: Benchmarking Spatial Reasoning in Vision-Language Models}, 
    author={Ilias Stogiannidis and Steven McDonagh and Sotirios A. Tsaftaris},
    year={2025},
    eprint={2503.19707},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2503.19707}, 
}
```

## Contributing

Contributions are welcome! Please follow these steps:
- Fork the repository.
- Create a new branch (`git checkout -b feature/your_feature`).
- Commit your changes (`git commit -am 'Add new feature'`).
- Push to the branch (`git push origin feature/your_feature`).
- Open a Pull Request.

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more information.

## Contact

For questions or feedback, please open an issue or contact me via email.
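## Scoring Sketch

For quick checks outside the bash wrappers, the snippet below scores a single results CSV with the helpers from `src/eval/acc.py`. It is a minimal sketch rather than part of the pipeline: the CSV path is illustrative, it assumes the repository root is on `PYTHONPATH`, and the column names follow the output format of `src/eval_openai.py` (`response`, `answer`, `split`); adjust them if your CSVs use different headers (e.g., `gold answer`).

```python
import pandas as pd

from src.eval.acc import exact_match, extract_answer

# Illustrative path; point this at any per-model CSV produced by the evaluation scripts.
df = pd.read_csv("output/evaluations/srbench2/gpt-4o.csv")

# Pull the final answer (option letter or keyword) out of each free-form response.
df["pred"] = df["response"].fillna("").astype(str).apply(extract_answer)

# Normalized exact match against the gold answer.
df["correct"] = [exact_match(p, g) for p, g in zip(df["pred"], df["answer"])]

print(f"Overall accuracy: {df['correct'].mean():.4f}")
print(df.groupby("split")["correct"].mean())  # accuracy per benchmark split
```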
107 | -------------------------------------------------------------------------------- /src/eval_openai.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import logging 4 | import os 5 | import time 6 | from functools import wraps 7 | from io import BytesIO 8 | from typing import List 9 | 10 | import dotenv 11 | import pandas as pd 12 | from datasets import load_dataset 13 | from openai import AzureOpenAI 14 | from PIL import Image 15 | from tqdm import tqdm 16 | 17 | # Load the environment variables 18 | dotenv.load_dotenv() 19 | 20 | # Set up logging 21 | logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO) 22 | 23 | 24 | def retry_with_exponential_backoff(max_retries=5, initial_delay=5): 25 | """ 26 | Retry decorator that retry a function call with exponential backoff. 27 | 28 | Args: 29 | max_retries (int): Maximum number of retries. 30 | initial_delay (int): Initial delay in seconds. 31 | 32 | Returns: 33 | Callable: Decorated function with retry logic. 34 | """ 35 | 36 | def decorator(func): 37 | @wraps(func) 38 | def wrapper(*args, **kwargs): 39 | delay = initial_delay 40 | for _ in range(max_retries): 41 | try: 42 | return func(*args, **kwargs) 43 | except Exception as e: 44 | if "429" in str(e): 45 | print(f"Rate limit exceeded. Retrying in {delay} seconds...") 46 | time.sleep(delay) 47 | delay *= 2 # exponential backoff 48 | else: 49 | raise e 50 | raise Exception("Max retries exceeded") 51 | 52 | return wrapper 53 | 54 | return decorator 55 | 56 | 57 | def image_to_base64(image: Image.Image) -> str: 58 | """ 59 | Convert a PIL Image to a base64 encoded string. 60 | 61 | Args: 62 | image (Image.Image): The image to convert. 63 | 64 | Returns: 65 | str: A base64 encoded string of the image. 66 | """ 67 | buffered = BytesIO() 68 | image.save(buffered, format="PNG") 69 | img_str = base64.b64encode(buffered.getvalue()).decode("ascii") 70 | return img_str 71 | 72 | 73 | @retry_with_exponential_backoff() 74 | def infer(prompts: List[str], images: List[Image.Image], model: str) -> List[str]: 75 | """ 76 | Infer responses using Azure OpenAI from provided prompts and images. 77 | 78 | Args: 79 | prompts (List[str]): A list of prompts/questions. 80 | images (List[Image.Image]): A list of images corresponding to the prompts. 81 | model (str): The model identifier to use for inference. 82 | 83 | Returns: 84 | List[str]: A list of responses from the model. 85 | """ 86 | if not os.getenv("AZURE_OPENAI_API_KEY") or not os.getenv("AZURE_OPENAI_ENDPOINT"): 87 | raise EnvironmentError( 88 | "Azure OpenAI API key or endpoint not set in environment variables." 
89 | ) 90 | 91 | endpoint = ( 92 | os.getenv("AZURE_OPENAI_ENDPOINT") 93 | if model == "gpt-4o" 94 | else os.getenv("O1_ENDPOINT") 95 | ) 96 | 97 | client = AzureOpenAI( 98 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 99 | api_version="2024-07-01-preview", 100 | azure_endpoint=endpoint, 101 | ) 102 | 103 | contents = [] 104 | for prompt, image in zip(prompts, images): 105 | if isinstance(image, Image.Image): 106 | image_content = image_to_base64(image) 107 | image_content = f"data:image/png;base64,{image_content}" 108 | else: 109 | image_content = str(image) 110 | 111 | 112 | # Merge the image and text into one message 113 | messages = [ 114 | {"role": "system", "content": "You are a helpful AI assistant."}, 115 | { 116 | "role": "user", 117 | "content": [ 118 | {"type": "image_url", "image_url": {"url": image_content}}, 119 | {"type": "text", "text": prompt}, 120 | ], 121 | }, 122 | ] 123 | 124 | response = client.chat.completions.create( 125 | model=model, 126 | messages=messages, 127 | seed=69, 128 | temperature=0.5, 129 | max_tokens=64, 130 | ) 131 | content = response.choices[0].message.content.strip() 132 | contents.append(content) 133 | 134 | return contents 135 | 136 | 137 | def main(): 138 | """ 139 | Main function to process the dataset and perform inference based on 140 | provided command-line arguments for dataset name and model. 141 | """ 142 | parser = argparse.ArgumentParser( 143 | description="Process dataset and perform inference using Azure OpenAI." 144 | ) 145 | parser.add_argument( 146 | "--dataset", 147 | type=str, 148 | required=True, 149 | help="Dataset name in Hugging Face format (e.g., stogian/mrt_pf_mix)", 150 | ) 151 | parser.add_argument( 152 | "--model", type=str, required=True, help="Model name (e.g., gpt-4o)" 153 | ) 154 | parser.add_argument( 155 | "--batch_size", type=int, default=16, help="Batch size for processing abstracts" 156 | ) 157 | args = parser.parse_args() 158 | 159 | dataset_name = args.dataset 160 | short_name = dataset_name.split("/")[-1] 161 | 162 | model = args.model 163 | model_name = model.split("/")[-1] 164 | 165 | # Load the specified dataset 166 | dataset = load_dataset(dataset_name, split="train") 167 | 168 | all_responses = [] 169 | 170 | for i in tqdm( 171 | range(0, len(dataset), args.batch_size), 172 | desc="Processing batches", 173 | unit="batch", 174 | leave=False, 175 | colour="magenta", 176 | ): 177 | batch = dataset[i : i + args.batch_size] 178 | prompts = batch["question"] 179 | images = batch["image"] 180 | 181 | responses = infer(prompts, images, model) 182 | all_responses.extend(responses) 183 | 184 | results_df = pd.DataFrame( 185 | { 186 | "question": dataset["question"], 187 | "response": all_responses, 188 | "answer": dataset["answer"], 189 | "split": dataset["split"], 190 | } 191 | ) 192 | results_dir = f"output/evaluations/{short_name}/" 193 | os.makedirs(results_dir, exist_ok=True) 194 | 195 | results_df.to_csv(os.path.join(results_dir, f"{model_name}.csv"), index=False) 196 | print(f"Results saved to {os.path.join(results_dir, 'results.csv')}") 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | -------------------------------------------------------------------------------- /src/eval/acc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import glob 3 | import os 4 | import pandas as pd 5 | import argparse 6 | import sys 7 | 8 | 9 | def normalize_answer(ans): 10 | try: 11 | if isinstance(ans, str): 12 | ans = ans.strip().upper() 13 | ans = 
ans.split(".")[0] # Take text before the first period 14 | except Exception as e: 15 | pass 16 | return ans 17 | 18 | 19 | def extract_answer(text): 20 | """ 21 | Extracts the final answer from an LLM's output, supporting single letters, 22 | specific words, Markdown bolding, and mixed formats (e.g., "C. Left"). 23 | 24 | Args: 25 | text (str): The LLM-generated text. 26 | 27 | Returns: 28 | str: The extracted answer in uppercase (e.g., "A", "YES") or "None" if not found. 29 | """ 30 | if pd.isna(text) or text == "": 31 | return "None" 32 | 33 | text = str(text) 34 | 35 | # --- Step 1: Clean up the input text --- 36 | prefixes_to_remove = ["Assistant", "ASSISTANT", "[INST]", "assistant"] 37 | first_prefix_pos = len(text) 38 | for prefix in prefixes_to_remove: 39 | pos = text.find(prefix) 40 | if pos != -1: 41 | first_prefix_pos = min(first_prefix_pos, pos) 42 | 43 | if first_prefix_pos != len(text): 44 | text = text[first_prefix_pos:] 45 | 46 | # --- Step 2: Define answer patterns --- 47 | word_answers = ["yes", "no", "left", "right", "back", "front", "center"] 48 | core_pattern = r"\b(" + "|".join(word_answers) + r"|[A-Z])\b" 49 | answer_pattern = r"(?:\*\*)?" + core_pattern + r"(?:\*\*)?" 50 | 51 | # --- Step 3: Check for answers in different formats, from most to least specific --- 52 | 53 | # A. Check for answers in curly brackets, e.g., {**A**} 54 | curly_pattern = ( 55 | r"\{" + r"(?:\*\*)?(" + "|".join(word_answers) + r"|[A-Z])(?:\*\*)?" + r"\}" 56 | ) 57 | curly_match = re.search(curly_pattern, text, re.IGNORECASE) 58 | if curly_match: 59 | return curly_match.group(1).upper() 60 | 61 | # B. Check for mixed format like "C. Left" and prioritize the letter. 62 | mixed_pattern = r"\b([A-Z])(?:\.|:|\))\s*(?:" + "|".join(word_answers) + r")\b" 63 | mixed_match = re.search(mixed_pattern, text, re.IGNORECASE) 64 | if mixed_match: 65 | return mixed_match.group(1).upper() 66 | 67 | # C. Check for phrases that typically precede or follow the answer 68 | before_phrases = [ 69 | "the answer is", 70 | "i think it's", 71 | "i choose", 72 | "i'll go with", 73 | "it's", 74 | "the correct choice is", 75 | "my answer is", 76 | "i believe it's", 77 | "i select", 78 | "the best answer is", 79 | ] 80 | after_phrases = [ 81 | "is the answer", 82 | "is correct", 83 | "is the correct choice", 84 | "is right", 85 | "is the best answer", 86 | "is the right choice", 87 | ] 88 | 89 | before_pattern = ( 90 | r"(?:" 91 | + "|".join(re.escape(p) for p in before_phrases) 92 | + r")\s*" 93 | + answer_pattern 94 | ) 95 | after_pattern = ( 96 | answer_pattern 97 | + r"\s*(?:" 98 | + "|".join(re.escape(p) for p in after_phrases) 99 | + r")" 100 | ) 101 | 102 | before_regex = re.compile(before_pattern, re.IGNORECASE) 103 | after_regex = re.compile(after_pattern, re.IGNORECASE) 104 | 105 | matches = list(before_regex.finditer(text)) + list(after_regex.finditer(text)) 106 | if matches: 107 | matches.sort(key=lambda m: m.start()) 108 | return matches[-1].group(1).upper() 109 | 110 | # D. Check for a direct answer format: "**A**.", "**Yes**:", etc. 111 | direct_match = re.search( 112 | answer_pattern + r"(?:\.|:|\))?(?:\s|$)", text, re.IGNORECASE 113 | ) 114 | if direct_match: 115 | return direct_match.group(1).upper() 116 | 117 | # E. 
Fallback: find the last standalone answer word/letter in the text 118 | fallback_matches = re.findall(answer_pattern, text, re.IGNORECASE) 119 | if fallback_matches: 120 | return fallback_matches[-1].upper() 121 | 122 | return "None" 123 | 124 | 125 | def exact_match(pred, gold): 126 | """Compare normalized predicted and gold answers""" 127 | return normalize_answer(pred) == normalize_answer(gold) 128 | 129 | 130 | def process_csv_files( 131 | pattern, response_col, correct_col, split_col=None, output_file=None 132 | ): 133 | """ 134 | Process multiple CSV files and calculate accuracy per split. 135 | 136 | Args: 137 | pattern (str): Glob pattern for CSV files 138 | response_col (str): Name of column containing model responses 139 | correct_col (str): Name of column containing correct answers 140 | split_col (str, optional): Name of column containing splits 141 | output_file (str, optional): Path to save results CSV 142 | 143 | Returns: 144 | pd.DataFrame: DataFrame with results per model and split 145 | """ 146 | files = glob.glob(pattern) 147 | 148 | if not files: 149 | print(f"No files found matching pattern: {pattern}") 150 | return None 151 | 152 | print(f"Found {len(files)} files to process") 153 | 154 | # Dictionary to store results for all models 155 | results = {} 156 | detailed_results = [] 157 | 158 | for file in files: 159 | print(f"Processing: {file}") 160 | 161 | try: 162 | df = pd.read_csv(file) 163 | print(f" Loaded {len(df)} rows") 164 | except Exception as e: 165 | print(f" Error reading {file}: {e}") 166 | continue 167 | 168 | # Extract model name from filename 169 | model_name = os.path.basename(file).replace(".csv", "") 170 | 171 | # Check if required columns exist 172 | if response_col not in df.columns: 173 | print(f" Warning: Column '{response_col}' not found in {file}") 174 | continue 175 | if correct_col not in df.columns: 176 | print(f" Warning: Column '{correct_col}' not found in {file}") 177 | continue 178 | if split_col and split_col not in df.columns: 179 | print(f" Warning: Column '{split_col}' not found in {file}") 180 | continue 181 | 182 | # Extract answers from responses 183 | df["extracted_answer"] = ( 184 | df[response_col].fillna("").astype(str).apply(extract_answer) 185 | ) 186 | 187 | # Calculate correctness using exact match with extracted answers 188 | try: 189 | df["correct"] = df.apply( 190 | lambda row: exact_match(row["extracted_answer"], row[correct_col]), 191 | axis=1, 192 | ) 193 | except Exception as e: 194 | print(f" Error calculating correctness for {file}: {e}") 195 | continue 196 | 197 | # Calculate overall accuracy 198 | overall_accuracy = df["correct"].mean() 199 | 200 | # Initialize results for this model 201 | results[model_name] = {"overall": overall_accuracy} 202 | 203 | # Calculate accuracy per split if split column exists 204 | if split_col and split_col in df.columns: 205 | try: 206 | accuracy_per_split = df.groupby(split_col)["correct"].mean() 207 | results[model_name].update(accuracy_per_split.to_dict()) 208 | 209 | print(f" Accuracy per split:") 210 | for split, acc in accuracy_per_split.items(): 211 | print(f" {split}: {acc:.4f} ({acc * 100:.2f}%)") 212 | 213 | except Exception as e: 214 | print(f" Error calculating accuracy per split for {file}: {e}") 215 | 216 | print( 217 | f" Overall accuracy: {overall_accuracy:.4f} ({overall_accuracy * 100:.2f}%)" 218 | ) 219 | 220 | # Store detailed results for this file 221 | file_details = { 222 | "model": model_name, 223 | "total_questions": len(df), 224 | "correct_answers": 
int(df["correct"].sum()),
225 |             "extraction_failed": int((df["extracted_answer"] == "None").sum()),
226 |             "overall_accuracy": overall_accuracy,
227 |         }
228 |         detailed_results.append(file_details)
229 | 
230 |         print()
231 | 
232 |     if not results:
233 |         print("No valid results found")
234 |         return None
235 | 
236 |     # Create DataFrame with models as rows and splits as columns
237 |     results_df = pd.DataFrame(results).T
238 |     results_df = results_df.sort_index()
239 |     results_df.index.name = "model"
240 | 
241 |     # Round all values to 4 decimal places
242 |     results_df = results_df.round(4)
243 | 
244 |     # Save results
245 |     if output_file is None:
246 |         output_file = "model_accuracy_by_split.csv"
247 | 
248 |     results_df.to_csv(output_file)
249 |     print(f"Results saved to {output_file}")
250 | 
251 |     # Save detailed results
252 |     detailed_df = pd.DataFrame(detailed_results)
253 |     detailed_output = output_file.replace(".csv", "_detailed.csv")
254 |     detailed_df.to_csv(detailed_output, index=False)
255 |     print(f"Detailed results saved to {detailed_output}")
256 | 
257 |     # Print summary
258 |     print("\n" + "=" * 60)
259 |     print("SUMMARY RESULTS")
260 |     print("=" * 60)
261 |     print(results_df)
262 |     print("=" * 60)
263 | 
264 |     return results_df
265 | 
266 | 
267 | def main():
268 |     parser = argparse.ArgumentParser(
269 |         description="Calculate accuracy from multiple CSV files"
270 |     )
271 |     parser.add_argument(
272 |         "--input-pattern",
273 |         "-i",
274 |         required=True,
275 |         help='Glob pattern for CSV files (e.g., "data/*.csv")'
276 |     )
277 |     parser.add_argument(
278 |         "--response-col",
279 |         "-r",
280 |         required=True,
281 |         help="Name of column containing model responses",
282 |     )
283 |     parser.add_argument(
284 |         "--correct-col",
285 |         "-c",
286 |         required=True,
287 |         help="Name of column containing correct answers",
288 |     )
289 |     parser.add_argument(
290 |         "--split-col", "-s", help="Name of column containing splits (optional)"
291 |     )
292 |     parser.add_argument(
293 |         "--output",
294 |         "-o",
295 |         help="Path to save results CSV (default: model_accuracy_by_split.csv)",
296 |     )
297 | 
298 |     args = parser.parse_args()
299 | 
300 |     # Process files (argparse stores "--input-pattern" as args.input_pattern)
301 |     results_df = process_csv_files(
302 |         args.input_pattern, args.response_col, args.correct_col, args.split_col, args.output
303 |     )
304 | 
305 |     if results_df is None:
306 |         sys.exit(1)
307 | 
308 | 
309 | if __name__ == "__main__":
310 |     main()
311 | 
--------------------------------------------------------------------------------
/src/data/mrt_blender.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import bpy
4 | import bmesh
5 | import os
6 | import sys
7 | import random
8 | import argparse
9 | import json
10 | import mathutils
11 | from mathutils import Vector, Matrix, Euler
12 | import math
13 | 
14 | # Import shape definitions from mrt.py
15 | sys.path.append(os.path.dirname(__file__))
16 | from mrt import SHAPES, EASY_SHAPES, COMPLEX_SHAPES, SIMILAR_MAPPING
17 | 
18 | def clear_scene():
19 |     """Clear all objects from the scene."""
20 |     bpy.ops.object.select_all(action='SELECT')
21 |     bpy.ops.object.delete(use_global=False)
22 | 
23 | def create_cube_at_position(position, size=1.0, name="Cube"):
24 |     """Create a cube at the specified position."""
25 |     bpy.ops.mesh.primitive_cube_add(size=size, location=position)
26 |     cube = bpy.context.active_object
27 |     cube.name = name
28 |     return cube
29 | 
30 | def create_polycube(shape_coords, cube_size=1.0, material=None):
31 |     """Create a polycube from coordinate 
list.""" 32 | cubes = [] 33 | for i, coord in enumerate(shape_coords): 34 | pos = Vector((coord[0] * cube_size, coord[1] * cube_size, coord[2] * cube_size)) 35 | cube = create_cube_at_position(pos, cube_size, f"Cube_{i}") 36 | if material: 37 | cube.data.materials.append(material) 38 | cubes.append(cube) 39 | 40 | # Join all cubes into one object 41 | bpy.ops.object.select_all(action='DESELECT') 42 | for cube in cubes: 43 | cube.select_set(True) 44 | bpy.context.view_layer.objects.active = cubes[0] 45 | bpy.ops.object.join() 46 | 47 | polycube = bpy.context.active_object 48 | return polycube 49 | 50 | def create_material(color=(1, 1, 1, 1), name="Material"): 51 | """Create a material with specified color.""" 52 | mat = bpy.data.materials.new(name=name) 53 | mat.use_nodes = True 54 | bsdf = mat.node_tree.nodes["Principled BSDF"] 55 | bsdf.inputs[0].default_value = color # Base Color 56 | bsdf.inputs[7].default_value = 0.2 # Roughness 57 | bsdf.inputs[15].default_value = 1.0 # Specular 58 | return mat 59 | 60 | def setup_lighting(): 61 | """Set up clean lighting for the scene.""" 62 | # Add sun light 63 | bpy.ops.object.light_add(type='SUN', location=(5, 5, 10)) 64 | sun = bpy.context.active_object 65 | sun.data.energy = 3.0 66 | sun.rotation_euler = (0.785, 0, 0.785) # 45 degrees 67 | 68 | # Add fill light 69 | bpy.ops.object.light_add(type='AREA', location=(-5, -5, 8)) 70 | fill = bpy.context.active_object 71 | fill.data.energy = 1.0 72 | fill.data.size = 5.0 73 | 74 | def setup_camera(target_location, distance=8.0): 75 | """Set up camera to look at target location.""" 76 | # Add camera 77 | bpy.ops.object.camera_add(location=(distance, -distance, distance)) 78 | camera = bpy.context.active_object 79 | 80 | # Point camera at target 81 | direction = target_location - camera.location 82 | rot_quat = direction.to_track_quat('-Z', 'Y') 83 | camera.rotation_euler = rot_quat.to_euler() 84 | 85 | # Set as active camera 86 | bpy.context.scene.camera = camera 87 | return camera 88 | 89 | def get_object_bounds(obj): 90 | """Get the bounding box of an object.""" 91 | bbox_corners = [obj.matrix_world @ Vector(corner) for corner in obj.bound_box] 92 | min_coord = Vector((min(v.x for v in bbox_corners), 93 | min(v.y for v in bbox_corners), 94 | min(v.z for v in bbox_corners))) 95 | max_coord = Vector((max(v.x for v in bbox_corners), 96 | max(v.y for v in bbox_corners), 97 | max(v.z for v in bbox_corners))) 98 | center = (min_coord + max_coord) / 2 99 | size = max_coord - min_coord 100 | return center, size 101 | 102 | def transform_rotate_blender(obj, difficulty="easy"): 103 | """Apply rotation transformation to object.""" 104 | if difficulty == "easy": 105 | # Single axis rotation with simple angles 106 | axis = random.choice(['X', 'Y', 'Z']) 107 | angle = math.radians(random.choice([-90, 90, 180])) 108 | 109 | if axis == 'X': 110 | obj.rotation_euler = (angle, 0, 0) 111 | elif axis == 'Y': 112 | obj.rotation_euler = (0, angle, 0) 113 | else: # Z 114 | obj.rotation_euler = (0, 0, angle) 115 | else: # complex 116 | # Multi-axis rotation 117 | angle_x = math.radians(random.choice([0, 60, 90, 120])) 118 | angle_y = math.radians(random.choice([0, 60, 90, 120])) 119 | angle_z = math.radians(random.choice([0, 60, 90, 120])) 120 | obj.rotation_euler = (angle_x, angle_y, angle_z) 121 | 122 | def transform_mirror_blender(obj, difficulty="easy"): 123 | """Apply mirror transformation to object.""" 124 | if difficulty == "easy": 125 | # Mirror across Z axis 126 | obj.scale[2] = -1 127 | else: # complex 128 | # 
Mirror across random axis 129 | axis = random.choice([0, 1, 2]) 130 | scale = [1, 1, 1] 131 | scale[axis] = -1 132 | obj.scale = scale 133 | 134 | def render_image(output_path, resolution=(512, 512)): 135 | """Render the current scene to an image.""" 136 | scene = bpy.context.scene 137 | scene.render.resolution_x = resolution[0] 138 | scene.render.resolution_y = resolution[1] 139 | scene.render.filepath = output_path 140 | scene.render.image_settings.file_format = 'PNG' 141 | 142 | # Set render engine to Cycles for better quality 143 | scene.render.engine = 'CYCLES' 144 | scene.cycles.samples = 64 145 | 146 | bpy.ops.render.render(write_still=True) 147 | 148 | def create_composite_image(original_obj, candidates, output_path, difficulty="easy"): 149 | """Create composite image with original and candidates.""" 150 | # Position objects for composite layout 151 | if difficulty == "easy": 152 | positions = [ 153 | Vector((0, 0, 4)), # Original (top center) 154 | Vector((-4, 0, 0)), # Candidate A 155 | Vector((0, 0, 0)), # Candidate B 156 | Vector((4, 0, 0)), # Candidate C 157 | ] 158 | else: # complex 159 | positions = [ 160 | Vector((0, 0, 6)), # Original (top center) 161 | Vector((-6, 0, 0)), # Candidate A 162 | Vector((-2, 0, 0)), # Candidate B 163 | Vector((2, 0, 0)), # Candidate C 164 | Vector((6, 0, 0)), # Candidate D 165 | ] 166 | 167 | # Position original 168 | original_obj.location = positions[0] 169 | 170 | # Position candidates 171 | for i, (_, candidate_obj) in enumerate(candidates): 172 | candidate_obj.location = positions[i + 1] 173 | 174 | # Set up camera to capture all objects 175 | all_objects = [original_obj] + [obj for _, obj in candidates] 176 | 177 | # Calculate scene bounds 178 | min_coords = Vector((float('inf'), float('inf'), float('inf'))) 179 | max_coords = Vector((float('-inf'), float('-inf'), float('-inf'))) 180 | 181 | for obj in all_objects: 182 | for corner in obj.bound_box: 183 | world_corner = obj.matrix_world @ Vector(corner) 184 | min_coords.x = min(min_coords.x, world_corner.x) 185 | min_coords.y = min(min_coords.y, world_corner.y) 186 | min_coords.z = min(min_coords.z, world_corner.z) 187 | max_coords.x = max(max_coords.x, world_corner.x) 188 | max_coords.y = max(max_coords.y, world_corner.y) 189 | max_coords.z = max(max_coords.z, world_corner.z) 190 | 191 | scene_center = (min_coords + max_coords) / 2 192 | scene_size = max(max_coords - min_coords) 193 | 194 | # Position camera 195 | camera_distance = scene_size * 1.5 196 | setup_camera(scene_center, camera_distance) 197 | 198 | # Render 199 | render_image(output_path) 200 | 201 | def generate_one_image_blender(index, difficulty="easy", facecolor="white", outdir="data/mrt"): 202 | """Generate a single MRT image using Blender.""" 203 | clear_scene() 204 | 205 | # Create material 206 | if facecolor == "white": 207 | color = (0.9, 0.9, 0.9, 1.0) 208 | else: 209 | # Simple color mapping - extend as needed 210 | color_map = { 211 | "red": (0.8, 0.2, 0.2, 1.0), 212 | "blue": (0.2, 0.2, 0.8, 1.0), 213 | "green": (0.2, 0.8, 0.2, 1.0), 214 | } 215 | color = color_map.get(facecolor, (0.9, 0.9, 0.9, 1.0)) 216 | 217 | material = create_material(color, "PolycubeMaterial") 218 | 219 | # Select shapes based on difficulty 220 | shapes_list = EASY_SHAPES if difficulty == "easy" else COMPLEX_SHAPES 221 | shape_name = random.choice(shapes_list) 222 | shape_coords = SHAPES[shape_name] 223 | 224 | # Create original polycube 225 | original_obj = create_polycube(shape_coords, material=material) 226 | original_obj.name = 
"Original" 227 | 228 | # Generate candidates 229 | candidates = [] 230 | 231 | # Correct candidate (rotation) 232 | correct_obj = create_polycube(shape_coords, material=material) 233 | correct_obj.name = "Rotate" 234 | transform_rotate_blender(correct_obj, difficulty) 235 | candidates.append(("rotate", correct_obj)) 236 | 237 | # Mirror candidate 238 | mirror_obj = create_polycube(shape_coords, material=material) 239 | mirror_obj.name = "Mirror" 240 | transform_mirror_blender(mirror_obj, difficulty) 241 | transform_rotate_blender(mirror_obj, difficulty) 242 | candidates.append(("mirror", mirror_obj)) 243 | 244 | # Similar shape candidate 245 | if shape_name in SIMILAR_MAPPING: 246 | similar_candidates = SIMILAR_MAPPING[shape_name][:] 247 | random.shuffle(similar_candidates) 248 | similar_shape_name = similar_candidates[0] if similar_candidates else shape_name 249 | else: 250 | similar_shape_name = shape_name 251 | 252 | similar_coords = SHAPES[similar_shape_name] 253 | similar_obj = create_polycube(similar_coords, material=material) 254 | similar_obj.name = "Similar" 255 | transform_rotate_blender(similar_obj, difficulty) 256 | candidates.append(("similar", similar_obj)) 257 | 258 | # Add second mirror for complex mode 259 | if difficulty == "complex": 260 | mirror2_obj = create_polycube(shape_coords, material=material) 261 | mirror2_obj.name = "Mirror2" 262 | transform_mirror_blender(mirror2_obj, difficulty) 263 | transform_rotate_blender(mirror2_obj, difficulty) 264 | candidates.append(("mirror2", mirror2_obj)) 265 | 266 | # Shuffle candidates 267 | random.shuffle(candidates) 268 | correct_candidate_index = [ 269 | i for i, cand in enumerate(candidates) if cand[0] == "rotate" 270 | ][0] 271 | 272 | # Set up lighting 273 | setup_lighting() 274 | 275 | # Create composite image 276 | filename = f"{shape_name}_{index}.png" 277 | output_path = os.path.join(outdir, filename) 278 | create_composite_image(original_obj, candidates, output_path, difficulty) 279 | 280 | # Save metadata 281 | metadata = { 282 | "filename": filename, 283 | "difficulty": difficulty, 284 | "shape": shape_name, 285 | "candidate_order": [tag for tag, _ in candidates], 286 | "answer": chr(65 + correct_candidate_index), 287 | } 288 | 289 | metadata_path = os.path.join(outdir, "metadata.jsonl") 290 | with open(metadata_path, "a") as f: 291 | f.write(json.dumps(metadata) + "\n") 292 | 293 | def main(): 294 | parser = argparse.ArgumentParser( 295 | description="Generate mental rotation test images using Blender." 
296 | ) 297 | parser.add_argument( 298 | "--difficulty", "-d", type=str, choices=["easy", "complex"], default="easy", 299 | help="Difficulty level: 'easy' or 'complex'" 300 | ) 301 | parser.add_argument( 302 | "--num_images", "-n", type=int, default=1, 303 | help="Number of images to generate" 304 | ) 305 | parser.add_argument( 306 | "--color", "-c", type=str, default="white", 307 | help="Color for the polycubes" 308 | ) 309 | parser.add_argument( 310 | "--seed", "-s", type=int, default=69, 311 | help="Seed for reproducible results" 312 | ) 313 | parser.add_argument( 314 | "--outdir", "-o", type=str, default=None, 315 | help="Output directory (defaults to data/mrt_blender/{difficulty})" 316 | ) 317 | 318 | args = parser.parse_args() 319 | 320 | # Set default output directory based on difficulty 321 | if args.outdir is None: 322 | args.outdir = f"data/mrt_blender/{args.difficulty}" 323 | 324 | os.makedirs(args.outdir, exist_ok=True) 325 | 326 | # Set seed for reproducibility 327 | if args.seed is not None: 328 | random.seed(args.seed) 329 | 330 | # Generate images 331 | for i in range(args.num_images): 332 | generate_one_image_blender(i, difficulty=args.difficulty, 333 | facecolor=args.color, outdir=args.outdir) 334 | 335 | print(f"Generated {args.num_images} {args.difficulty} MRT images in {args.outdir}") 336 | 337 | # Run directly in Blender 338 | if __name__ == "__main__": 339 | # When running in Blender, parse sys.argv differently 340 | if "--" in sys.argv: 341 | argv = sys.argv[sys.argv.index("--") + 1:] 342 | sys.argv = [sys.argv[0]] + argv 343 | main() 344 | -------------------------------------------------------------------------------- /src/data/folding_pil.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont 2 | import random 3 | import argparse 4 | import os 5 | import json 6 | import math 7 | from collections import defaultdict 8 | 9 | OUTPUT_DIR = "data/pf" 10 | os.makedirs(OUTPUT_DIR, exist_ok=True) 11 | 12 | # Cache font loading 13 | _font_cache = {} 14 | 15 | 16 | def get_font(size=18): 17 | """Get cached font instance.""" 18 | if size not in _font_cache: 19 | try: 20 | _font_cache[size] = ImageFont.truetype("Arial", size=size) 21 | except OSError: 22 | _font_cache[size] = ImageFont.load_default() 23 | return _font_cache[size] 24 | 25 | 26 | def draw_paper(draw): 27 | """Draw a 100x100 white square paper with a black outline.""" 28 | draw.rectangle((10, 10, 110, 110), outline="black", fill="white") 29 | 30 | 31 | def draw_holes(draw, holes): 32 | """ 33 | Draw holes as small circles. If duplicate coordinates occur, 34 | offset them slightly in a circular pattern so that all are visible. 35 | """ 36 | if not holes: 37 | return 38 | 39 | groups = defaultdict(list) 40 | # Group holes by rounded coordinates. 41 | for h in holes: 42 | key = (round(h[0], 1), round(h[1], 1)) 43 | groups[key].append(h) 44 | 45 | for group in groups.values(): 46 | n = len(group) 47 | if n == 1: 48 | x, y = group[0] 49 | draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="black") 50 | else: 51 | # Distribute duplicates in a circle of small radius. 
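            # Coincident holes appear when unfolding reflects a punch onto an already
            # occupied position (e.g. a punch lying on a fold axis), so each duplicate
            # is nudged onto a tiny circle to keep every layer's hole visible.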
52 | radius_offset = 2 53 | angle_step = 2 * math.pi / n 54 | for i, h in enumerate(group): 55 | angle = angle_step * i 56 | offset_x = radius_offset * math.cos(angle) 57 | offset_y = radius_offset * math.sin(angle) 58 | x, y = h[0] + offset_x, h[1] + offset_y 59 | draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="black") 60 | 61 | 62 | # Reflection logic - use lambda functions for better performance 63 | FOLD_REFLECTIONS = { 64 | "V": lambda p: (120 - p[0], p[1]), 65 | "H": lambda p: (p[0], 120 - p[1]), 66 | "D": lambda p: (p[1], p[0]), 67 | "N": lambda p: (120 - p[1], 120 - p[0]), 68 | } 69 | 70 | 71 | def compute_all_layers(punched_holes, folds): 72 | """ 73 | Compute unfolded holes by doubling the layers for each fold. 74 | """ 75 | if not punched_holes: 76 | return [] 77 | 78 | layers = [punched_holes] 79 | for fold in folds: 80 | reflection_func = FOLD_REFLECTIONS[fold] 81 | new_layers = [] 82 | for layer in layers: 83 | new_layers.append(layer) 84 | new_layers.append([reflection_func(h) for h in layer]) 85 | layers = new_layers 86 | 87 | # Flatten layers 88 | return [h for layer in layers for h in layer] 89 | 90 | 91 | def point_in_poly(x, y, poly): 92 | """Optimized point-in-polygon test using ray casting.""" 93 | if not poly: 94 | return False 95 | 96 | inside = False 97 | j = len(poly) - 1 98 | 99 | for i in range(len(poly)): 100 | xi, yi = poly[i] 101 | xj, yj = poly[j] 102 | 103 | if ((yi > y) != (yj > y)) and (x < (xj - xi) * (y - yi) / (yj - yi) + xi): 104 | inside = not inside 105 | j = i 106 | 107 | return inside 108 | 109 | 110 | def generate_hole_in_poly(poly, margin=5): 111 | """Generate a random point within the polygon with margin.""" 112 | if not poly: 113 | return (60, 60) # Default center point 114 | 115 | xs = [p[0] for p in poly] 116 | ys = [p[1] for p in poly] 117 | bx_min, bx_max = min(xs) + margin, max(xs) - margin 118 | by_min, by_max = min(ys) + margin, max(ys) - margin 119 | 120 | # Ensure valid bounds 121 | if bx_min >= bx_max or by_min >= by_max: 122 | return (int((min(xs) + max(xs)) / 2), int((min(ys) + max(ys)) / 2)) 123 | 124 | for _ in range(100): # Limit attempts to avoid infinite loop 125 | x = random.uniform(bx_min, bx_max) 126 | y = random.uniform(by_min, by_max) 127 | if point_in_poly(x, y, poly): 128 | return (int(round(x)), int(round(y))) 129 | 130 | # Fallback to polygon centroid 131 | cx = sum(p[0] for p in poly) / len(poly) 132 | cy = sum(p[1] for p in poly) / len(poly) 133 | return (int(round(cx)), int(round(cy))) 134 | 135 | 136 | def clip_polygon(poly, fold): 137 | """Optimized polygon clipping.""" 138 | if not poly: 139 | return [] 140 | 141 | # Pre-compute bounding box 142 | xs = [p[0] for p in poly] 143 | ys = [p[1] for p in poly] 144 | bx_min, bx_max = min(xs), max(xs) 145 | by_min, by_max = min(ys), max(ys) 146 | 147 | if fold == "V": 148 | mid = (bx_min + bx_max) / 2 149 | inside = lambda p: p[0] >= mid 150 | intersect = lambda p1, p2: ( 151 | mid, 152 | p1[1] + (mid - p1[0]) * (p2[1] - p1[1]) / (p2[0] - p1[0]), 153 | ) 154 | elif fold == "H": 155 | mid = (by_min + by_max) / 2 156 | inside = lambda p: p[1] >= mid 157 | intersect = lambda p1, p2: ( 158 | p1[0] + (mid - p1[1]) * (p2[0] - p1[0]) / (p2[1] - p1[1]), 159 | mid, 160 | ) 161 | elif fold == "D": 162 | cx, cy = (bx_min + bx_max) / 2, (by_min + by_max) / 2 163 | b = cy - cx 164 | inside = lambda p: p[1] <= p[0] + b 165 | 166 | def intersect(p1, p2): 167 | denom = (p2[0] - p1[0]) - (p2[1] - p1[1]) 168 | if abs(denom) < 1e-10: 169 | return p1 170 | t = (p1[1] - p1[0] - b) / denom 
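            # t solves p1 + t * (p2 - p1) lying on the fold line y = x + b; the
            # zero-denominator branch above covers segments parallel to that line.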
171 | return (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1])) 172 | elif fold == "N": 173 | cx, cy = (bx_min + bx_max) / 2, (by_min + by_max) / 2 174 | inside = lambda p: (p[0] + p[1]) <= (cx + cy) 175 | 176 | def intersect(p1, p2): 177 | denom = (p2[0] - p1[0]) + (p2[1] - p1[1]) 178 | if abs(denom) < 1e-10: 179 | return p1 180 | t = ((cx + cy) - (p1[0] + p1[1])) / denom 181 | return (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1])) 182 | else: 183 | return poly 184 | 185 | new_poly = [] 186 | for i in range(len(poly)): 187 | curr = poly[i] 188 | prev = poly[i - 1] 189 | 190 | if inside(curr): 191 | if not inside(prev): 192 | new_poly.append(intersect(prev, curr)) 193 | new_poly.append(curr) 194 | elif inside(prev): 195 | new_poly.append(intersect(prev, curr)) 196 | 197 | return new_poly 198 | 199 | 200 | # Global list to collect final view images for each fold. 201 | process_images = [] 202 | 203 | 204 | def recursive_fold(current_folds, idx, poly): 205 | """Recursively generate the final folded view for each fold.""" 206 | global process_images 207 | 208 | if idx == 0: 209 | img_unfolded = Image.new("RGB", (120, 120), "white") 210 | draw_unfolded = ImageDraw.Draw(img_unfolded) 211 | draw_paper(draw_unfolded) 212 | process_images.append((img_unfolded, "Unfolded")) 213 | 214 | if idx >= len(current_folds): 215 | return poly 216 | 217 | fold = current_folds[idx] 218 | new_poly = clip_polygon(poly, fold) 219 | 220 | img_result = Image.new("RGB", (120, 120), "white") 221 | draw_result = ImageDraw.Draw(img_result) 222 | 223 | if poly: 224 | draw_result.polygon(poly, fill="lightgray", outline="black") 225 | if new_poly: 226 | draw_result.polygon(new_poly, fill="white", outline="black") 227 | 228 | process_images.append((img_result, f"Fold {idx + 1}")) 229 | return recursive_fold(current_folds, idx + 1, new_poly) 230 | 231 | 232 | def generate_wrong_options(holes): 233 | """Generate all wrong options at once for efficiency.""" 234 | if not holes: 235 | return [[], []] 236 | 237 | # Option 1: Remove one hole or shift if only one hole 238 | option1 = holes.copy() 239 | if len(option1) > 1: 240 | option1.pop(random.randrange(len(option1))) 241 | else: 242 | x, y = option1[0] 243 | dx, dy = random.randint(-3, 3), random.randint(-3, 3) 244 | option1[0] = (x + dx, y + dy) 245 | 246 | # Option 2: Mirror transformation 247 | option2 = [(120 - x, 120 - y) for x, y in holes] 248 | 249 | # If mirror is same as original, try rotation 250 | if set(option2) == set(holes): 251 | option2 = [ 252 | (ty + 60, -tx + 60) for x, y in holes for tx, ty in [(x - 60, y - 60)] 253 | ] 254 | # If rotation is also same, shift horizontally 255 | if set(option2) == set(holes): 256 | option2 = [(min(x + 5, 110), y) for x, y in holes] 257 | 258 | return [option1, option2] 259 | 260 | 261 | def generate_test_image(folds, test_number, num_folds, num_holes): 262 | """Generate a test image with optimized processing.""" 263 | global process_images 264 | process_images = [] 265 | 266 | font_bigger = get_font(18) 267 | 268 | initial_poly = [(10, 10), (110, 10), (110, 110), (10, 110)] 269 | final_poly = recursive_fold(folds, 0, initial_poly) 270 | 271 | # Generate punched holes 272 | punched_holes = [ 273 | generate_hole_in_poly(final_poly, margin=5) for _ in range(num_holes) 274 | ] 275 | unfolded_holes = compute_all_layers(punched_holes, folds) 276 | 277 | # Create final view image 278 | img_final = Image.new("RGB", (120, 120), "white") 279 | draw_final = ImageDraw.Draw(img_final) 280 | if final_poly: 281 | 
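        # Repeated clipping can leave an empty polygon, so only draw the folded
        # sheet when some vertices remain.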
draw_final.polygon(final_poly, fill="white", outline="black") 282 | draw_holes(draw_final, punched_holes) 283 | process_images.append((img_final, "Final view")) 284 | 285 | # Create candidate images 286 | img_correct = Image.new("RGB", (120, 120), "white") 287 | draw_correct = ImageDraw.Draw(img_correct) 288 | draw_paper(draw_correct) 289 | draw_holes(draw_correct, unfolded_holes) 290 | 291 | wrong_options = generate_wrong_options(unfolded_holes) 292 | candidate_images = [("correct", img_correct)] 293 | 294 | for wrong_holes in wrong_options: 295 | img_wrong = Image.new("RGB", (120, 120), "white") 296 | draw_wrong = ImageDraw.Draw(img_wrong) 297 | draw_paper(draw_wrong) 298 | draw_holes(draw_wrong, wrong_holes) 299 | candidate_images.append(("wrong", img_wrong)) 300 | 301 | # Shuffle and assign labels 302 | random.shuffle(candidate_images) 303 | option_labels = ["A", "B", "C"] 304 | correct_label = None 305 | 306 | for i, (kind, img) in enumerate(candidate_images): 307 | if kind == "correct": 308 | correct_label = option_labels[i] 309 | 310 | # Create final composite image 311 | small_size = 120 312 | top_width = len(process_images) * small_size + (len(process_images) - 1) * 10 313 | bottom_width = 3 * small_size + 2 * 10 314 | total_width = max(top_width, bottom_width) 315 | total_height = 320 316 | 317 | total_img = Image.new("RGB", (total_width, total_height), "white") 318 | draw_total = ImageDraw.Draw(total_img) 319 | 320 | # Draw top row 321 | start_x = (total_width - top_width) // 2 322 | for i, (img, label) in enumerate(process_images): 323 | x = start_x + i * (small_size + 10) 324 | total_img.paste(img, (x, 30)) 325 | draw_total.text((x + 15, 10), label, fill="black", font=font_bigger) 326 | 327 | # Draw bottom row 328 | start_x_bottom = (total_width - bottom_width) // 2 329 | for i, (_, img) in enumerate(candidate_images): 330 | x = start_x_bottom + i * (small_size + 10) 331 | total_img.paste(img, (x, 190)) 332 | draw_total.text((x + 45, 180), option_labels[i], fill="black", font=font_bigger) 333 | 334 | # Save image 335 | f_name = f"{test_number}_fold-{num_folds}_holes-{num_holes}.png" 336 | out_path = os.path.join(OUTPUT_DIR, f_name) 337 | total_img.save(out_path) 338 | 339 | return out_path, correct_label 340 | 341 | 342 | if __name__ == "__main__": 343 | parser = argparse.ArgumentParser( 344 | description="Generate folded paper images with holes." 
345 | ) 346 | parser.add_argument( 347 | "-n", 348 | "--num-images", 349 | type=int, 350 | default=10, 351 | help="Number of images to generate (default: 10)", 352 | ) 353 | parser.add_argument( 354 | "-s", 355 | "--seed", 356 | type=int, 357 | default=42, 358 | help="Seed for random number generator (default: 42)", 359 | ) 360 | parser.add_argument( 361 | "-f", 362 | "--num-folds", 363 | type=int, 364 | default=2, 365 | choices=range(1, 10), 366 | help="Number of folds (minimum 1, maximum 9) (default: 2)", 367 | ) 368 | parser.add_argument( 369 | "-H", "--num-holes", type=int, default=1, help="Number of holes (default: 1)" 370 | ) 371 | parser.add_argument( 372 | "-m", 373 | "--metadata-file", 374 | type=str, 375 | default="metadata.jsonl", 376 | help="Metadata JSONL file to append to (default: metadata.jsonl)", 377 | ) 378 | args = parser.parse_args() 379 | 380 | random.seed(args.seed) 381 | metadata_path = os.path.join(OUTPUT_DIR, args.metadata_file) 382 | with open(metadata_path, "a", encoding="utf-8") as metaf: 383 | for i in range(1, args.num_images + 1): 384 | fold_group = random.choice(["VH", "Diagonal"]) 385 | if fold_group == "VH": 386 | folds = [random.choice(["V", "H"]) for _ in range(args.num_folds)] 387 | else: 388 | folds = [random.choice(["D", "N"]) for _ in range(args.num_folds)] 389 | image_path, correct_option = generate_test_image( 390 | folds, i, args.num_folds, args.num_holes 391 | ) 392 | rel_image_path = os.path.basename(image_path) 393 | metadata_obj = { 394 | "filename": rel_image_path, 395 | "correct_option": correct_option, 396 | } 397 | metaf.write(json.dumps(metadata_obj) + "\n") 398 | -------------------------------------------------------------------------------- /src/data/create_prompts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | This script demonstrates how to: 4 | 1. Define a list of simple objects with attributes and a set of spatial test tasks. 5 | 2. For each spatial task and each model, generate multiple scene descriptions using an LLM. 6 | 3. Post-process the model's output to remove internal commentary. 7 | 4. Log each generated prompt and metadata to both WandB and Neptune.ai. 8 | 5. Save each generated scene description and metadata to a JSONL file. 9 | """ 10 | 11 | import re 12 | import os 13 | import random 14 | import json 15 | import logging 16 | import glob 17 | from typing import List, Dict, Tuple, Any 18 | 19 | import neptune.utils 20 | import torch 21 | import pandas as pd 22 | from transformers import pipeline 23 | from dotenv import load_dotenv 24 | import wandb 25 | import neptune 26 | 27 | # Load environment variables from the .env file. 28 | load_dotenv(override=True) 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig(filename="debug_prompts.log", encoding="utf-8", level=logging.DEBUG) 32 | 33 | # ------------------------------- 34 | # Type Aliases 35 | # ------------------------------- 36 | ObjectType = Dict[str, Any] 37 | SpatialTaskType = Dict[str, str] 38 | MessageType = Dict[str, str] 39 | 40 | # ------------------------------- 41 | # Constants 42 | # ------------------------------- 43 | SAVE_DIR = "output/prompts/" 44 | TRIAL_NAME = "fiveshotv6-prompts-beam" 45 | OUTPUT_FILENAME = f"{SAVE_DIR}{TRIAL_NAME}.jsonl" # All prompts will be appended here. 46 | PROMPTS_PER_TASK = 25 47 | # ------------------------------- 48 | # 7. 
List of Models to Use 49 | # ------------------------------- 50 | MODELS: List[str] = glob.glob("bin/models/llms/*") 51 | 52 | # ------------------------------- 53 | # Initialize WandB 54 | # ------------------------------- 55 | wandb.init( 56 | project="SpatialScenePrompts", 57 | name=TRIAL_NAME, 58 | config={ 59 | "num_prompts_per_task": PROMPTS_PER_TASK, 60 | "models": MODELS, 61 | }, 62 | ) 63 | 64 | # ------------------------------- 65 | # Initialize Neptune.ai Run 66 | # ------------------------------- 67 | neptune_run = neptune.init_run( 68 | project="stogiannidis/create-prompts", # Replace with your project name. 69 | api_token=os.getenv("NEPTUNE_API_TOKEN"), # Replace with your Neptune API token. 70 | name=TRIAL_NAME, 71 | ) 72 | neptune_run["config/num_prompts_per_task"] = PROMPTS_PER_TASK 73 | neptune_run["config/models"] = neptune.utils.stringify_unsupported(MODELS) 74 | 75 | # ------------------------------- 76 | # 1. List of Simple Objects with Attributes 77 | # ------------------------------- 78 | OBJECTS: List[str] = [ 79 | "bicycle", 80 | "motorcycle", 81 | "scooter", 82 | "car", 83 | "truck", 84 | "bus", 85 | "train", 86 | "airplane", 87 | "helicopter", 88 | "boat", 89 | "ship", 90 | "dog", 91 | "cat", 92 | "bird", 93 | "fish", 94 | "rabbit", 95 | ] 96 | 97 | # ------------------------------- 98 | # 2. List of Spatial Test Tasks 99 | # ------------------------------- 100 | # Each task now uses a concise definition and an example. 101 | SPATIAL_TASKS: List[SpatialTaskType] = [ 102 | { 103 | "task_type": "mental rotation", 104 | "examples": [ 105 | "A photo-realistic image of a car, viewed from the front, with the wheels turned to the left, parked in a driveway, with a clear blue sky in the background.", 106 | "A photo-realistic image of a mug, viewed from the side, with the handle on the right, filled with steaming coffee, on a wooden table, with a window in the background showing a sunny day.", 107 | "A photo-realistic image of a laptop, viewed from the back, with the screen open", 108 | "A photo-realistic image of a bicycle, viewed from the side, with the front wheel turned to the right, parked on a cobblestone street, with a row of colorful houses in the background.", 109 | "A photo-realistic image of a cat, viewed from the front, with the tail curled to the right, sitting on a windowsill, with a potted plant in the background.", 110 | ], 111 | }, 112 | # { 113 | # "task_type": "perspective taking", 114 | # "examples": [ 115 | # "A photograph of a family of four, two adults and two children, standing in a row, with the camera positioned at the height of the children, looking up at the adults, who are smiling down at them.", 116 | # "Photograph of street artist, vividly painting a vibrant mural, surrounded by captivated pedestrians, in a stencil-like graffiti style, with a gritty urban setting, drenched in chiaroscuro lighting for a dramatic and lively atmosphere.", 117 | # "Photograph of vibrant street vendors, laden with an array of ripe fruits, amidst the lively hustle of a farmers market - captured in the style of a vivid, Impressionist oil painting, with warm sunlight filtering through a cloud-speckled sky.", 118 | # "Photograph of scuba diver capturing a vibrant, up-close moment with a majestic sea turtle among intricately detailed, luminous coral reef, in the style of a high-definition underwater photograph blending vivid hues and soft shadows, with a serene, lively atmosphere.", 119 | # "Photograph of toddler and playful puppy in sunlit backyard, chasing 
iridescent bubbles in whimsical Impressionist style, vibrant colors, tender atmosphere, capturing the joy of childhood and canine companionship.", 120 | # ], 121 | # }, 122 | ] 123 | 124 | # ------------------------------- 125 | # 3. Enhanced System Prompt 126 | # ------------------------------- 127 | # This prompt instructs the LLM to generate a test prompt for spatial reasoning. 128 | SYSTEM_PROMPT: str = """**Role:** Text-Based Image Description Generation Assistant 129 | 130 | **Objective:** To generate high-quality text image descriptions for generative models. 131 | 132 | **Input (Textual):** 133 | * A list of objects (provided as text by the user). 134 | * Example image descriptions (provided as text by the user). 135 | 136 | **Output (Textual):** New image descriptions (as text). 137 | 138 | **Task:** Generate new text image descriptions that meet the following criteria: 139 | 1. **Style Mimicry:** Replicate the writing style, sentence structure, and vocabulary used in the example text descriptions. 140 | 2. **Object Novelty:** Feature the provided list of objects, ensuring they are different from objects explicitly mentioned in the example text descriptions. 141 | 3. **Setting Novelty:** Describe the objects in new and different settings or contexts compared to those presented in the example text descriptions. 142 | 4. **Logical Coherence & Realism:** Ensure all generated text descriptions are logically sound, realistic, and portray plausible scenarios in text. Avoid nonsensical or physically impossible descriptions in text. 143 | """ 144 | 145 | # ------------------------------- 146 | # 4. Function to Construct a Prompt for a Given Task 147 | # ------------------------------- 148 | 149 | 150 | def construct_prompt_for_task(task: SpatialTaskType) -> Tuple[str, str]: 151 | """ 152 | Constructs a test prompt from the task's example prompts together with the full 153 | OBJECTS list defined above. 154 | 155 | Args: 156 | task: A spatial test task containing "task_type" and "examples". 157 | 158 | Returns: 159 | A tuple with the task type and the constructed user prompt. 160 | """ 161 | task_type = task["task_type"] 162 | examples = task["examples"] 163 | user_prompt = (f"""Please generate a text-to-image prompt inspired by the following examples and using these objects: 164 | **Example Prompts:** 165 | {examples} 166 | 167 | **Objects to Include:** 168 | {OBJECTS} 169 | 170 | Respond with the brief yet concise generated scene description.""" 171 | ) 172 | return task_type, user_prompt 173 | 174 | 175 | # ------------------------------- 176 | # 5. Post-Processing Function 177 | # ------------------------------- 178 | def post_process_output(output_text: str) -> str: 179 | """ 180 | Removes any text between <think> and </think> tags and strips whitespace. 181 | 182 | Args: 183 | output_text: The raw output text from the model. 184 | 185 | Returns: 186 | The cleaned output text. 187 | """ 188 | cleaned_text: str = re.sub(r"<think>.*?</think>", "", output_text, flags=re.DOTALL) 189 | return cleaned_text.strip() 190 | 191 | 192 | # ------------------------------- 193 | # 6. Helper Functions to Adapt to Different Pipeline Variants 194 | # ------------------------------- 195 | def supports_system_prompt(model_name: str) -> bool: 196 | """ 197 | Heuristic: if the model id contains "gemma" (case-insensitive) we assume it does not support chat-style prompts.
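    Illustrative behaviour (the model ids below are arbitrary examples, not models referenced elsewhere in this script):
        supports_system_prompt("google/gemma-2-9b-it")              -> False
        supports_system_prompt("meta-llama/Llama-3.1-8B-Instruct")  -> True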
198 | """ 199 | return "gemma" not in model_name.lower() 200 | 201 | 202 | def prepare_pipeline_input(messages: List[MessageType], model_name: str) -> Any: 203 | """ 204 | Prepares the input for the pipeline call based on whether the model supports system prompts. 205 | """ 206 | if supports_system_prompt(model_name): 207 | return messages 208 | else: 209 | combined_prompt = "\n\n".join(msg["content"].strip() for msg in messages) 210 | return [{"role": "user", "content": combined_prompt.strip()}] 211 | 212 | 213 | def extract_generated_text(outputs: Any) -> str: 214 | """ 215 | Extracts the generated text from the output of the pipeline. 216 | 217 | Supports both nested (chat-style) and flat output. 218 | """ 219 | if isinstance(outputs, list) and outputs: 220 | output_item = outputs[0] 221 | if "generated_text" in output_item: 222 | gen = output_item["generated_text"] 223 | if isinstance(gen, list): 224 | last = gen[-1] 225 | if isinstance(last, dict) and "content" in last: 226 | return last["content"].strip() 227 | elif isinstance(last, str): 228 | return last.strip() 229 | elif isinstance(gen, str): 230 | return gen.strip() 231 | return "" 232 | 233 | 234 | # ------------------------------- 235 | # 8. Main Routine: Iterate Over All Model and Task Combinations 236 | # ------------------------------- 237 | def main() -> None: 238 | logger.info("\n\nStarting main routine with WandB and Neptune.ai logging.") 239 | 240 | for model_name in MODELS: 241 | logger.info("Processing model: %s", model_name) 242 | # Get the original name of the model without the path. 243 | sanitized_model: str = model_name.split("/")[-1] 244 | logger.info("Sanitized model name: %s", sanitized_model) 245 | 246 | # Initialize the text-generation pipeline for the current model. 247 | llm_pipe: Any = pipeline( 248 | "text-generation", 249 | model=model_name, 250 | device_map="auto", 251 | torch_dtype=torch.bfloat16, 252 | return_full_text=False, 253 | do_sample=True, 254 | temperature=1, 255 | top_k=10, 256 | top_p=0.9, 257 | num_beams=10, 258 | ) 259 | eos_token_id = llm_pipe.tokenizer.eos_token_id 260 | logger.info("Pipeline initialized for %s", sanitized_model) 261 | 262 | for task in SPATIAL_TASKS: 263 | # Generate a number of prompts for each task. 264 | for i in range(PROMPTS_PER_TASK): 265 | logger.info( 266 | "Generating prompt %d for task: %s", i + 1, task["task_type"] 267 | ) 268 | task_type, user_prompt = construct_prompt_for_task(task) 269 | logger.info("User prompt generated for task: %s", task_type) 270 | 271 | # Prepare the messages. 272 | messages: List[MessageType] = [ 273 | {"role": "system", "content": SYSTEM_PROMPT}, 274 | {"role": "user", "content": user_prompt}, 275 | ] 276 | prompt_input = prepare_pipeline_input(messages, model_name) 277 | 278 | # Generate the scene description using the LLM. 279 | outputs: Any = llm_pipe( 280 | prompt_input, 281 | max_new_tokens=1024, 282 | pad_token_id=eos_token_id, 283 | ) 284 | generated_text: str = extract_generated_text(outputs) 285 | if generated_text: 286 | cleaned_text = post_process_output(generated_text) 287 | 288 | # Prepare output metadata. 289 | output_data: Dict[str, Any] = { 290 | "model": sanitized_model, 291 | "task_type": task_type, 292 | "user_prompt": user_prompt, 293 | "generated_scene_description": cleaned_text, 294 | "iteration": i + 1, 295 | } 296 | 297 | # Log the metadata to WandB. 298 | wandb.log(output_data) 299 | 300 | # Log the metadata to Neptune. 
301 | neptune_run[ 302 | f"prompt/{sanitized_model}/{task_type}/iteration_{i + 1}" 303 | ] = output_data 304 | 305 | # Save the generated description and metadata to a JSONL file. 306 | with open(OUTPUT_FILENAME, "a", encoding="utf-8") as f: 307 | f.write(json.dumps(output_data, ensure_ascii=False) + "\n") 308 | 309 | # Finish logging for both platforms. 310 | logger.info("Main routine completed.") 311 | 312 | # Create a Table of Contents for the generated prompts. 313 | results_df: pd.DataFrame = pd.read_json(OUTPUT_FILENAME, lines=True) 314 | 315 | # Log the Table of Contents to both platforms. 316 | wandb.log({"results": wandb.Table(data=results_df)}) 317 | neptune_run["output"].upload(neptune.types.File.as_html(results_df)) 318 | 319 | # Finish the run for both platforms. 320 | wandb.finish() 321 | neptune_run.stop() 322 | 323 | 324 | if __name__ == "__main__": 325 | main() 326 | -------------------------------------------------------------------------------- /src/data/mrt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import random 4 | import argparse 5 | import json 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 9 | from matplotlib.gridspec import GridSpec 10 | 11 | # Global dictionary of polycube shapes 12 | SHAPES = { 13 | "Snake": [ 14 | (0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1), 15 | (1, 2, 1), (2, 2, 1), (2, 2, 2), (2, 3, 2), 16 | ], 17 | "Zigzag": [ 18 | (0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0), 19 | (2, 1, 1), (2, 2, 1), (3, 2, 1), (3, 2, 2), 20 | ], 21 | "SnakeComplex1": [ 22 | (0, 0, 0), (1, 0, 0), (2, 0, 0), (2, 1, 0), 23 | (2, 1, 1), (2, 2, 1), (1, 2, 1), (1, 3, 1), (1, 3, 2), 24 | ], 25 | "HookedCorner": [ 26 | (0, 0, 0), (1, 0, 0), (2, 0, 0), (0, 1, 0), 27 | (0, 2, 0), (0, 2, 1), (0, 2, 2), 28 | ], 29 | "TopPlate": [ 30 | (0, 0, 0), (0, 1, 0), (0, 2, 0), 31 | (0, 2, 1), (1, 2, 1), (2, 2, 1), 32 | ], 33 | "CornerStaircase": [ 34 | (0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), 35 | (1, 1, 0), (2, 1, 0), (3, 1, 0), (3, 2, 0), (3, 3, 0), 36 | ], 37 | "TripleArm": [ 38 | (3, -1, 0), (3, -1, 1), (3, -1, 2), (0, 0, 0), 39 | (1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 2, 0), 40 | ], 41 | } 42 | 43 | # Shapes available for each difficulty 44 | EASY_SHAPES = ["Snake", "HookedCorner", "TopPlate", "CornerStaircase", "TripleArm"] 45 | COMPLEX_SHAPES = list(SHAPES.keys()) 46 | 47 | # Dynamic similar-object mapping 48 | all_shape_keys = list(SHAPES.keys()) 49 | SIMILAR_MAPPING = { 50 | key: [s for s in all_shape_keys if s != key] for key in all_shape_keys 51 | } 52 | 53 | 54 | def set_axes_equal(ax, all_vertices): 55 | """Make the aspect ratio equal and remove visual distractions.""" 56 | all_vertices = np.array(all_vertices) 57 | x_limits = [np.min(all_vertices[:, 0]), np.max(all_vertices[:, 0])] 58 | y_limits = [np.min(all_vertices[:, 1]), np.max(all_vertices[:, 1])] 59 | z_limits = [np.min(all_vertices[:, 2]), np.max(all_vertices[:, 2])] 60 | 61 | x_range = x_limits[1] - x_limits[0] 62 | y_range = y_limits[1] - y_limits[0] 63 | z_range = z_limits[1] - z_limits[0] 64 | max_range = max(x_range, y_range, z_range) 65 | 66 | x_mid = np.mean(x_limits) 67 | y_mid = np.mean(y_limits) 68 | z_mid = np.mean(z_limits) 69 | 70 | ax.set_xlim(x_mid - max_range / 2, x_mid + max_range / 2) 71 | ax.set_ylim(y_mid - max_range / 2, y_mid + max_range / 2) 72 | ax.set_zlim(z_mid - max_range / 2, z_mid + max_range / 2) 73 | ax.set_box_aspect([1, 1, 
1]) 74 | 75 | # Set ticks and remove labels 76 | ax.set_xticks(np.linspace(x_limits[0], x_limits[1], 5)) 77 | ax.set_yticks(np.linspace(y_limits[0], y_limits[1], 5)) 78 | ax.set_zticks(np.linspace(z_limits[0], z_limits[1], 5)) 79 | ax.set_xticklabels([]) 80 | ax.set_yticklabels([]) 81 | ax.set_zticklabels([]) 82 | 83 | # Remove 3D visual elements for complex mode 84 | if hasattr(ax, '_remove_3d_elements'): 85 | ax.xaxis.pane.set_visible(False) 86 | ax.yaxis.pane.set_visible(False) 87 | ax.zaxis.pane.set_visible(False) 88 | ax.grid(False) 89 | ax._axis3don = False 90 | 91 | 92 | def cube_vertices(origin, size=1.0): 93 | """Return the 8 corner vertices of a cube.""" 94 | x, y, z = origin 95 | return np.array([ 96 | [x, y, z], [x + size, y, z], [x + size, y + size, z], [x, y + size, z], 97 | [x, y, z + size], [x + size, y, z + size], 98 | [x + size, y + size, z + size], [x, y + size, z + size], 99 | ]) 100 | 101 | 102 | def plot_cubes(ax, vertices, title="", facecolor="white", hide_3d_elements=False): 103 | """Plot cubes using their vertices.""" 104 | if hide_3d_elements: 105 | ax._remove_3d_elements = True 106 | 107 | n_cubes = len(vertices) // 8 108 | vertices_reshaped = vertices.reshape((n_cubes, 8, 3)) 109 | 110 | for cube_verts in vertices_reshaped: 111 | faces = [ 112 | [0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4], 113 | [2, 3, 7, 6], [1, 2, 6, 5], [0, 3, 7, 4], 114 | ] 115 | for face in faces: 116 | polygon = Poly3DCollection( 117 | [cube_verts[face]], facecolors=facecolor, 118 | edgecolors="black", alpha=1.0, 119 | ) 120 | ax.add_collection3d(polygon) 121 | 122 | ax.set_title(title, fontsize=12) 123 | set_axes_equal(ax, vertices) 124 | 125 | 126 | def generate_shape_vertices(shape_name, cube_size=1.0): 127 | """Generate vertices for a given shape.""" 128 | if shape_name not in SHAPES: 129 | raise ValueError(f"Unknown shape {shape_name}") 130 | 131 | cube_origins = SHAPES[shape_name] 132 | all_vertices = [] 133 | for origin in cube_origins: 134 | corners = cube_vertices(origin, size=cube_size) 135 | all_vertices.append(corners) 136 | return np.vstack(all_vertices) 137 | 138 | 139 | def get_transformed_candidate(transformation_func, original, max_attempts=10): 140 | """Apply transformation until result differs from original.""" 141 | for _ in range(max_attempts): 142 | candidate = transformation_func(original) 143 | if not np.allclose(candidate, original, atol=1e-6): 144 | return candidate 145 | return candidate 146 | 147 | 148 | def transform_rotate(vertices, difficulty="easy"): 149 | """Rotate shape based on difficulty level.""" 150 | center = vertices.mean(axis=0) 151 | shifted = vertices - center 152 | 153 | if difficulty == "easy": 154 | # Single axis rotation with simple angles 155 | axis = random.choice(["x", "y", "z"]) 156 | angle = np.deg2rad(random.choice([-90, 90, 180])) 157 | 158 | if axis == "x": 159 | R = np.array([ 160 | [1, 0, 0], 161 | [0, np.cos(angle), -np.sin(angle)], 162 | [0, np.sin(angle), np.cos(angle)], 163 | ]) 164 | elif axis == "y": 165 | R = np.array([ 166 | [np.cos(angle), 0, np.sin(angle)], 167 | [0, 1, 0], 168 | [-np.sin(angle), 0, np.cos(angle)], 169 | ]) 170 | else: # z axis 171 | R = np.array([ 172 | [np.cos(angle), -np.sin(angle), 0], 173 | [np.sin(angle), np.cos(angle), 0], 174 | [0, 0, 1], 175 | ]) 176 | else: # complex 177 | # Multi-axis rotation with varied angles 178 | angle_x = np.deg2rad(np.random.choice([0, 60, 90, 120])) 179 | angle_y = np.deg2rad(np.random.choice([0, 60, 90, 120])) 180 | angle_z = np.deg2rad(np.random.choice([0, 60, 90, 120])) 
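        # The three single-axis rotations below are composed as R = Rz @ Ry @ Rx and applied
        # about the shape's centroid. Because every angle may be drawn as 0, the composition
        # can be the identity; get_transformed_candidate() retries (up to max_attempts)
        # whenever the transformed shape still matches the original.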
181 | 182 | Rx = np.array([ 183 | [1, 0, 0], 184 | [0, np.cos(angle_x), -np.sin(angle_x)], 185 | [0, np.sin(angle_x), np.cos(angle_x)], 186 | ]) 187 | Ry = np.array([ 188 | [np.cos(angle_y), 0, np.sin(angle_y)], 189 | [0, 1, 0], 190 | [-np.sin(angle_y), 0, np.cos(angle_y)], 191 | ]) 192 | Rz = np.array([ 193 | [np.cos(angle_z), -np.sin(angle_z), 0], 194 | [np.sin(angle_z), np.cos(angle_z), 0], 195 | [0, 0, 1], 196 | ]) 197 | R = Rz @ Ry @ Rx 198 | 199 | rotated = (R @ shifted.T).T 200 | return rotated + center 201 | 202 | 203 | def transform_mirror(vertices, difficulty="easy"): 204 | """Mirror shape based on difficulty level.""" 205 | center = vertices.mean(axis=0) 206 | shifted = vertices - center 207 | mirrored = shifted.copy() 208 | 209 | if difficulty == "easy": 210 | # Mirror across XY plane (Z-axis) 211 | mirrored[:, 2] = -mirrored[:, 2] 212 | else: # complex 213 | # Mirror across random axis 214 | axis = random.choice([0, 1, 2]) 215 | mirrored[:, axis] = -mirrored[:, axis] 216 | 217 | return mirrored + center 218 | 219 | 220 | def get_visually_similar_candidate(chosen_shape_name, original_vertices, cube_size=1.0, difficulty="easy"): 221 | """Get a similar shape candidate.""" 222 | if chosen_shape_name in SIMILAR_MAPPING: 223 | similar_candidates = SIMILAR_MAPPING[chosen_shape_name][:] 224 | random.shuffle(similar_candidates) 225 | 226 | for similar_shape_name in similar_candidates: 227 | if similar_shape_name in SHAPES: 228 | candidate_vertices = generate_shape_vertices(similar_shape_name, cube_size) 229 | candidate_vertices = get_transformed_candidate( 230 | lambda v: transform_rotate(v, difficulty), candidate_vertices 231 | ) 232 | if (candidate_vertices.shape != original_vertices.shape or 233 | not np.allclose(candidate_vertices, original_vertices, atol=1e-6)): 234 | return candidate_vertices 235 | return None 236 | 237 | 238 | def generate_one_image(index, difficulty="easy", facecolor="white", outdir="data/mrt"): 239 | """Generate a single MRT image based on difficulty.""" 240 | cube_size = 1.0 241 | 242 | # Select shapes based on difficulty 243 | shapes_list = EASY_SHAPES if difficulty == "easy" else COMPLEX_SHAPES 244 | shape_name = random.choice(shapes_list) 245 | original_vertices = generate_shape_vertices(shape_name, cube_size=cube_size) 246 | 247 | # Generate correct candidate (rotation) 248 | correct_candidate = get_transformed_candidate( 249 | lambda v: transform_rotate(v, difficulty), original_vertices 250 | ) 251 | 252 | # Generate wrong candidates 253 | mirror_candidate = get_transformed_candidate( 254 | lambda v: transform_rotate(transform_mirror(v, difficulty), difficulty), 255 | original_vertices 256 | ) 257 | 258 | similar_candidate = get_visually_similar_candidate( 259 | shape_name, original_vertices, cube_size, difficulty 260 | ) 261 | if similar_candidate is None: 262 | similar_candidate = mirror_candidate 263 | 264 | # Set up candidates based on difficulty 265 | if difficulty == "easy": 266 | candidates = [ 267 | ("rotate", correct_candidate), 268 | ("mirror", mirror_candidate), 269 | ("similar", similar_candidate), 270 | ] 271 | num_candidates = 3 272 | figure_size = (6, 6) 273 | else: # complex 274 | mirror_candidate2 = get_transformed_candidate( 275 | lambda v: transform_rotate(transform_mirror(v, difficulty), difficulty), 276 | original_vertices 277 | ) 278 | candidates = [ 279 | ("rotate", correct_candidate), 280 | ("mirror", mirror_candidate), 281 | ("similar", similar_candidate), 282 | ("mirror2", mirror_candidate2), 283 | ] 284 | num_candidates = 4 285 
| figure_size = (12, 8) 286 | 287 | random.shuffle(candidates) 288 | correct_candidate_index = [ 289 | i for i, cand in enumerate(candidates) if cand[0] == "rotate" 290 | ][0] 291 | 292 | # Create figure 293 | fig = plt.figure(figsize=figure_size) 294 | gs = GridSpec(2, num_candidates, height_ratios=[0.5, 1], wspace=0.1, hspace=0.1) 295 | 296 | # Plot original shape 297 | ax_orig = fig.add_subplot(gs[0, :], projection="3d") 298 | plot_cubes(ax_orig, original_vertices, title="Original Shape", 299 | facecolor=facecolor, hide_3d_elements=(difficulty == "complex")) 300 | 301 | # Plot candidates 302 | for i in range(num_candidates): 303 | ax = fig.add_subplot(gs[1, i], projection="3d") 304 | _, candidate_vertices = candidates[i] 305 | plot_cubes(ax, candidate_vertices, title=f"Option {chr(65 + i)}", 306 | facecolor=facecolor, hide_3d_elements=(difficulty == "complex")) 307 | 308 | # Save image 309 | filename = f"{shape_name}_{index}.png" 310 | output_path = os.path.join(outdir, filename) 311 | plt.savefig(output_path, dpi=60, bbox_inches="tight", pad_inches=0) 312 | plt.close(fig) 313 | 314 | # Save metadata 315 | metadata = { 316 | "filename": filename, 317 | "difficulty": difficulty, 318 | "shape": shape_name, 319 | "candidate_order": [tag for tag, _ in candidates], 320 | "answer": chr(65 + correct_candidate_index), 321 | } 322 | with open(os.path.join(outdir, "metadata.jsonl"), "a") as f: 323 | f.write(json.dumps(metadata) + "\n") 324 | 325 | 326 | def main(): 327 | parser = argparse.ArgumentParser( 328 | description="Generate mental rotation test images with variable difficulty." 329 | ) 330 | parser.add_argument( 331 | "--difficulty", "-d", type=str, choices=["easy", "complex"], default="easy", 332 | help="Difficulty level: 'easy' (3 candidates, simple rotations) or 'complex' (4 candidates, complex transformations)" 333 | ) 334 | parser.add_argument( 335 | "--num_images", "-n", type=int, default=1, 336 | help="Number of images to generate" 337 | ) 338 | parser.add_argument( 339 | "--color", "-c", type=str, default="white", 340 | help="Color for the polycubes" 341 | ) 342 | parser.add_argument( 343 | "--seed", "-s", type=int, default=69, 344 | help="Seed for reproducible results" 345 | ) 346 | parser.add_argument( 347 | "--outdir", "-o", type=str, default=None, 348 | help="Output directory (defaults to data/mrt/{difficulty})" 349 | ) 350 | 351 | args = parser.parse_args() 352 | 353 | # Set default output directory based on difficulty 354 | if args.outdir is None: 355 | args.outdir = f"data/mrt/{args.difficulty}" 356 | 357 | os.makedirs(args.outdir, exist_ok=True) 358 | 359 | # Set seed for reproducibility 360 | if args.seed is not None: 361 | random.seed(args.seed) 362 | np.random.seed(args.seed) 363 | 364 | # Generate images 365 | for i in range(args.num_images): 366 | generate_one_image(i, difficulty=args.difficulty, 367 | facecolor=args.color, outdir=args.outdir) 368 | 369 | print(f"Generated {args.num_images} {args.difficulty} MRT images in {args.outdir}") 370 | 371 | 372 | if __name__ == "__main__": 373 | main() 374 | -------------------------------------------------------------------------------- /src/data/create_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import logging 5 | import random 6 | import numpy as np 7 | from diffusers import FluxPipeline, StableDiffusion3Pipeline, DiffusionPipeline 8 | import gc 9 | from tqdm import tqdm 10 | 11 | """ 12 | Image generation module for creating 
synthetic images using diffusion models. 13 | 14 | This module provides a wrapper for diffusion models and utilities for batch processing 15 | image generation tasks. It supports Stable Diffusion 3.5 and FLUX.1-dev models with 16 | optimized memory management and error handling. 17 | """ 18 | 19 | # Configure logging 20 | LOG_DIR = "logs" 21 | os.makedirs(LOG_DIR, exist_ok=True) 22 | LOG_FILE = os.path.join(LOG_DIR, "image_creation.log") 23 | logging.basicConfig( 24 | level=logging.INFO, 25 | format="%(asctime)s - %(levelname)s - %(message)s", 26 | handlers=[logging.FileHandler(LOG_FILE)], 27 | ) 28 | 29 | def set_seed(seed: int = 42): 30 | """ 31 | Set random seeds for reproducibility across all libraries. 32 | 33 | Args: 34 | seed (int): The random seed to use. Defaults to 42. 35 | """ 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | # Make CuDNN deterministic 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = False 44 | logging.info(f"Random seed set to {seed} for reproducibility") 45 | 46 | 47 | class DiffusionPipelineWrapper: 48 | """ 49 | A wrapper class for diffusion model pipelines with optimized memory management. 50 | 51 | This class provides a unified interface for different diffusion models, handles 52 | lazy loading/unloading of models to optimize GPU memory usage, and provides 53 | consistent image generation capabilities. 54 | 55 | Attributes: 56 | model_id (str): Identifier for the diffusion model to use 57 | steps (int): Number of inference steps for image generation 58 | scale (float): Guidance scale for controlling adherence to prompt 59 | pipeline: The loaded diffusion pipeline (None when unloaded) 60 | 61 | Supported Models: 62 | - stabilityai/stable-diffusion-3.5-large 63 | - black-forest-labs/FLUX.1-dev 64 | """ 65 | 66 | def __init__(self, model_id: str, steps: int = 40, scale: float = 4.5): 67 | """ 68 | Initialize the diffusion pipeline wrapper. 69 | 70 | Args: 71 | model_id (str): The model identifier for the diffusion model 72 | steps (int, optional): Number of inference steps. Defaults to 40. 73 | scale (float, optional): Guidance scale value. Defaults to 4.5. 74 | 75 | Raises: 76 | ValueError: If an unsupported model_id is provided 77 | """ 78 | self.model_id = model_id 79 | self.steps = steps 80 | self.scale = scale 81 | self.pipeline = None 82 | logging.info(f"Initializing wrapper for {model_id}") 83 | 84 | def _create_pipeline(self, model_id: str): 85 | """ 86 | Create and configure the appropriate diffusion pipeline based on model ID. 87 | 88 | Args: 89 | model_id (str): The model identifier 90 | 91 | Returns: 92 | Union[StableDiffusion3Pipeline, FluxPipeline]: Configured pipeline instance 93 | 94 | Raises: 95 | ValueError: If the model_id is not supported 96 | """ 97 | if model_id == "stabilityai/stable-diffusion-3.5-large": 98 | pipe = StableDiffusion3Pipeline.from_pretrained( 99 | "bin/models/diffusers/stable-diffusion-3.5-large", 100 | device_map="balanced", 101 | ) 102 | elif model_id == "black-forest-labs/FLUX.1-dev": 103 | pipe = FluxPipeline.from_pretrained( 104 | "bin/models/diffusers/FLUX.1-dev", 105 | torch_dtype=torch.bfloat16, 106 | device_map="balanced", 107 | ) 108 | else: 109 | raise ValueError( 110 | "Unsupported model id. Use 'stabilityai/stable-diffusion-3.5-large' or 'black-forest-labs/FLUX.1-dev'." 
111 | ) 112 | return pipe 113 | 114 | def load_pipeline(self): 115 | """ 116 | Lazy load the diffusion pipeline to optimize memory usage. 117 | 118 | This method implements lazy loading, only creating the pipeline when needed. 119 | Subsequent calls return the already loaded pipeline without recreating it. 120 | 121 | Returns: 122 | Union[StableDiffusion3Pipeline, FluxPipeline]: The loaded pipeline 123 | """ 124 | if self.pipeline is None: 125 | self.pipeline = self._create_pipeline(self.model_id) 126 | return self.pipeline 127 | 128 | def unload_pipeline(self): 129 | """ 130 | Unload the pipeline and free GPU memory. 131 | 132 | This method properly deallocates the pipeline from memory, clears CUDA cache, 133 | and forces garbage collection to maximize memory availability for other models. 134 | Essential for processing multiple models sequentially on limited GPU memory. 135 | """ 136 | if self.pipeline is not None: 137 | del self.pipeline 138 | self.pipeline = None 139 | torch.cuda.empty_cache() 140 | gc.collect() 141 | logging.info(f"Pipeline for {self.model_id} unloaded from memory") 142 | 143 | def _call_pipeline(self, prompt: str): 144 | """ 145 | Execute the pipeline with appropriate parameters based on model type. 146 | 147 | Args: 148 | prompt (str): The text prompt for image generation 149 | 150 | Returns: 151 | GenerationOutput: The pipeline output containing generated images 152 | 153 | Raises: 154 | ValueError: If the pipeline type is not supported 155 | 156 | Note: 157 | Clips the prompt to 300 characters for CLIP compatibility and uses 158 | the full prompt for secondary prompt parameters when available. 159 | """ 160 | clip_prompt = prompt[:300] 161 | if isinstance(self.pipeline, StableDiffusion3Pipeline): 162 | return self.pipeline( 163 | prompt=clip_prompt, 164 | prompt_3=prompt, 165 | negative_prompt="bad anatomy, poorly drawn face, low resolution, blurry, artifacts, bad lighting, bad composition, cartoonish", 166 | num_inference_steps=self.steps, 167 | guidance_scale=self.scale, 168 | ) 169 | elif isinstance(self.pipeline, FluxPipeline): 170 | return self.pipeline( 171 | prompt=clip_prompt, 172 | prompt_2=prompt, 173 | negative_prompt="bad anatomy, poorly drawn face, low resolution, blurry, artifacts, bad lighting, bad composition, cartoonish", 174 | num_inference_steps=self.steps, 175 | guidance_scale=self.scale, 176 | ) 177 | else: 178 | raise ValueError("Unsupported pipeline type.") 179 | 180 | def generate_image(self, prompt: str, output_filename: str, output_dir: str) -> str: 181 | """ 182 | Generate an image from a text prompt and save it to disk. 183 | 184 | Args: 185 | prompt (str): Text description for image generation 186 | output_filename (str): Filename for the saved image 187 | output_dir (str): Directory path where the image will be saved 188 | 189 | Returns: 190 | str: Full path to the saved image file 191 | 192 | Raises: 193 | Exception: Re-raises any exception that occurs during generation or saving 194 | 195 | Note: 196 | Creates the output directory if it doesn't exist and uses torch.no_grad() 197 | context for memory efficiency during inference. 
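
        Example (illustrative only; the prompt, filename, and directory are placeholders):
            wrapper = DiffusionPipelineWrapper("black-forest-labs/FLUX.1-dev", steps=30, scale=3.5)
            wrapper.generate_image(
                "A photo-realistic image of a scooter parked outside a bakery",
                "scooter_000.png",
                "output/images_demo",
            )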
198 | """ 199 | try: 200 | self.load_pipeline() 201 | with torch.no_grad(): 202 | result = self._call_pipeline(prompt) 203 | image = result.images[0] 204 | os.makedirs(output_dir, exist_ok=True) 205 | image_path = os.path.join(output_dir, output_filename) 206 | image.save(image_path) 207 | logging.info(f"Image saved to {image_path}") 208 | return image_path 209 | except Exception as e: 210 | logging.error(f"Error generating image {output_filename}: {e}") 211 | raise 212 | 213 | def __call__(self, prompt: str, output_filename: str, output_dir: str) -> str: 214 | """ 215 | Make the wrapper callable, delegating to generate_image method. 216 | 217 | Args: 218 | prompt (str): Text description for image generation 219 | output_filename (str): Filename for the saved image 220 | output_dir (str): Directory path where the image will be saved 221 | 222 | Returns: 223 | str: Full path to the saved image file 224 | """ 225 | return self.generate_image(prompt, output_filename, output_dir) 226 | 227 | @staticmethod 228 | def load_metadata_from_json(json_path: str): 229 | """ 230 | Load metadata from a JSONL file containing prompt information. 231 | 232 | Args: 233 | json_path (str): Path to the JSONL file containing metadata 234 | 235 | Returns: 236 | list: List of dictionaries containing prompt metadata 237 | 238 | Note: 239 | Expects JSONL format where each line is a valid JSON object. 240 | Empty lines are automatically skipped. 241 | """ 242 | with open(json_path, "r") as f: 243 | data = [json.loads(line.strip()) for line in f if line.strip()] 244 | return data 245 | 246 | 247 | def process_model_batch(model_config, metadata_list, base_output_dir, output_metadata): 248 | """ 249 | Process all prompts for a single diffusion model in batch. 250 | 251 | This function handles the complete workflow for generating images with one model: 252 | initializing the wrapper, processing all prompts with progress tracking, 253 | collecting metadata, and proper cleanup. 254 | 255 | Args: 256 | model_config (dict): Configuration dictionary containing: 257 | - model_id (str): The model identifier 258 | - steps (int, optional): Number of inference steps 259 | - scale (float, optional): Guidance scale 260 | metadata_list (list): List of prompt metadata dictionaries 261 | base_output_dir (str): Base directory for saving images 262 | output_metadata (list): List to append generation metadata to 263 | 264 | Note: 265 | Automatically handles errors for individual prompts and continues processing. 266 | Ensures proper cleanup of GPU memory regardless of success or failure. 
267 | """ 268 | model_id = model_config["model_id"] 269 | steps = model_config.get("steps", 40) 270 | scale = model_config.get("scale", 4.5) 271 | safe_model_id = model_id.replace("/", "_") 272 | output_dir = os.path.join(base_output_dir, safe_model_id) 273 | 274 | logging.info(f"Processing {len(metadata_list)} prompts for model '{model_id}'") 275 | 276 | wrapper = None 277 | try: 278 | wrapper = DiffusionPipelineWrapper(model_id, steps, scale) 279 | 280 | # Process prompts with progress bar 281 | for idx, item in enumerate(tqdm(metadata_list, desc=f"Generating with {safe_model_id}")): 282 | prompt = item.get("generated_scene_description") 283 | if not prompt: 284 | logging.warning(f"Skipping prompt index {idx} due to missing description.") 285 | continue 286 | 287 | output_filename = f"image_{idx:03d}.png" 288 | try: 289 | image_path = wrapper.generate_image(prompt, output_filename, output_dir) 290 | record = { 291 | "model_id": model_id, 292 | "prompt": prompt, 293 | "image_path": image_path, 294 | "steps": steps, 295 | "scale": scale, 296 | "prompt_index": idx 297 | } 298 | output_metadata.append(record) 299 | 300 | except Exception as e: 301 | logging.error(f"Error generating image for prompt index {idx}: {e}") 302 | continue 303 | 304 | except Exception as e: 305 | logging.error(f"Failed to initialize wrapper for model {model_id}: {e}") 306 | return 307 | finally: 308 | if wrapper: 309 | wrapper.unload_pipeline() 310 | 311 | 312 | def save_metadata_batch(output_metadata, output_file): 313 | """ 314 | Save generation metadata to a JSONL file in batch. 315 | 316 | Args: 317 | output_metadata (list): List of metadata dictionaries to save 318 | output_file (str): Path to the output JSONL file 319 | 320 | Note: 321 | Overwrites the output file if it exists. Each metadata record is written 322 | as a separate JSON line for easy parsing and streaming. 323 | """ 324 | with open(output_file, "w") as f: 325 | for record in output_metadata: 326 | f.write(json.dumps(record) + "\n") 327 | logging.info(f"Saved {len(output_metadata)} records to '{output_file}'") 328 | 329 | 330 | def main(): 331 | """ 332 | Main execution function for the image generation pipeline. 333 | 334 | This function orchestrates the complete image generation workflow: 335 | 1. Loads prompt metadata from JSONL file 336 | 2. Processes each configured diffusion model sequentially 337 | 3. Generates images for all prompts with each model 338 | 4. Saves comprehensive metadata about the generation process 339 | 5. Handles memory optimization between models 340 | 341 | The function is designed to be memory-efficient for multi-GPU setups and 342 | provides comprehensive logging for monitoring long-running generation tasks. 
343 | 344 | Configuration: 345 | - Supports multiple diffusion models processed sequentially 346 | - Automatic memory cleanup between models 347 | - Progress tracking with tqdm 348 | - Comprehensive error handling and logging 349 | 350 | Raises: 351 | Exception: Logs and handles any unexpected errors during execution 352 | """ 353 | 354 | set_seed(42) # Set a fixed seed for reproducibility 355 | 356 | logging.info("\n\nStarting image generation process using DiffusionPipelineWrapper.") 357 | try: 358 | diffusion_models = [ 359 | {"model_id": "black-forest-labs/FLUX.1-dev", "steps": 50, "scale": 7.5}, 360 | {"model_id": "stabilityai/stable-diffusion-3.5-large", "steps": 50, "scale": 7.5}, 361 | ] 362 | json_file = "output/prompts/claude3.7-prompt.jsonl" 363 | base_output_dir = "output/images_" + json_file.split("/")[-1].split("-")[0] 364 | 365 | metadata_list = DiffusionPipelineWrapper.load_metadata_from_json(json_file) 366 | if not metadata_list: 367 | logging.error("No metadata found in the JSON file.") 368 | return 369 | 370 | output_metadata = [] 371 | 372 | # Process each model sequentially to optimize memory usage 373 | for model_config in diffusion_models: 374 | process_model_batch(model_config, metadata_list, base_output_dir, output_metadata) 375 | 376 | # Force garbage collection between models 377 | torch.cuda.empty_cache() 378 | gc.collect() 379 | 380 | # Save all metadata at once 381 | output_metadata_file = f"metadata_{json_file.split('/')[-1].split('-')[0]}.jsonl" 382 | save_metadata_batch(output_metadata, output_metadata_file) 383 | 384 | except Exception as err: 385 | logging.exception(f"An unexpected error occurred: {err}") 386 | 387 | 388 | if __name__ == "__main__": 389 | main() 390 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import gc 4 | import json 5 | import random 6 | import hashlib 7 | import pandas as pd 8 | import logging 9 | from typing import List, Dict, Optional 10 | from tqdm import tqdm 11 | from datasets import load_dataset, Dataset 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from pathlib import Path 15 | from datetime import datetime 16 | import numpy as np 17 | from PIL import Image 18 | 19 | from utils.vlm_wrapper import VLMWrapper 20 | 21 | # Configure logging 22 | logging.basicConfig( 23 | level=logging.INFO, 24 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | handlers=[ 27 | logging.StreamHandler() 28 | ] 29 | ) 30 | logger = logging.getLogger(__name__) 31 | 32 | class EvaluationEngine: 33 | """Evaluation engine for all VLM types with enhanced features.""" 34 | 35 | def __init__(self, model_id: str, device_map: str = "auto", seed: int = 42, 36 | use_cot: bool = False, one_shot_example: Optional[Dict] = None): 37 | """ 38 | Initialize the evaluation engine. 
39 | 40 | Args: 41 | model_id: HuggingFace model identifier 42 | device_map: Device mapping strategy 43 | seed: Random seed for reproducibility 44 | use_cot: Enable Chain-of-Thought prompting 45 | one_shot_example: One-shot example dictionary 46 | """ 47 | self.model_id = model_id 48 | self.short_name = self._extract_model_name(model_id) 49 | self.device_map = device_map 50 | self.seed = seed 51 | self.use_cot = use_cot 52 | self.one_shot_example = one_shot_example 53 | 54 | # Initialize reproducibility 55 | self._set_seed(seed) 56 | 57 | # Lazy initialization 58 | self.vlm = None 59 | 60 | # Evaluation metadata 61 | self.eval_metadata = { 62 | "model_id": model_id, 63 | "seed": seed, 64 | "use_cot": use_cot, 65 | "has_one_shot": one_shot_example is not None, 66 | "timestamp": datetime.now().isoformat(), 67 | "torch_version": torch.__version__, 68 | } 69 | 70 | logger.info(f"Initialized EvaluationEngine for model: {model_id}") 71 | logger.info(f"Seed: {seed}, CoT: {use_cot}, One-shot: {one_shot_example is not None}") 72 | 73 | def _set_seed(self, seed: int): 74 | """Set all random seeds for reproducibility.""" 75 | random.seed(seed) 76 | np.random.seed(seed) 77 | torch.manual_seed(seed) 78 | if torch.cuda.is_available(): 79 | torch.cuda.manual_seed(seed) 80 | torch.cuda.manual_seed_all(seed) 81 | # Make cudnn deterministic 82 | torch.backends.cudnn.deterministic = True 83 | torch.backends.cudnn.benchmark = False 84 | 85 | # Set environment variables for reproducibility 86 | os.environ['PYTHONHASHSEED'] = str(seed) 87 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8' # For deterministic CUDA operations 88 | 89 | logger.info(f"Set all seeds to {seed} for reproducibility") 90 | 91 | def _extract_model_name(self, model_id: str) -> str: 92 | """Extract a clean model name for file naming.""" 93 | return model_id.split("/")[-1].replace("-", "_") 94 | 95 | def _compute_dataset_hash(self, dataset_id: str, max_samples: Optional[int] = None) -> str: 96 | """Compute a hash of the dataset configuration for reproducibility tracking.""" 97 | config_str = f"{dataset_id}_{max_samples}_{self.seed}" 98 | return hashlib.md5(config_str.encode()).hexdigest()[:8] 99 | 100 | def _load_model(self): 101 | """Load the VLM model lazily to save memory.""" 102 | if self.vlm is None: 103 | logger.info(f"Loading model: {self.model_id}") 104 | self.vlm = VLMWrapper(self.model_id, self.device_map, dtype=torch.bfloat16) 105 | logger.info("Model loaded successfully") 106 | 107 | # Add model-specific metadata 108 | self.eval_metadata.update({ 109 | "model_type": self.vlm.model_type, 110 | "inference_type": self.vlm.config.inference_type, 111 | "device_map": self.device_map, 112 | "dtype": str(self.vlm.dtype), 113 | }) 114 | 115 | def _prepare_messages(self, questions: List[str], images: List[Image.Image]) -> List[List[Dict]]: 116 | """Prepare messages in the required format with optional CoT and one-shot examples.""" 117 | messages = [] 118 | 119 | for question, image in zip(questions, images): 120 | conversation = [] 121 | 122 | # Add system message with model and strategy info 123 | system_message = { 124 | "role": "system", 125 | "content": [{ 126 | "type": "text", 127 | "text": ("You are a spatial reasoning AI assistant specialized in analyzing, " 128 | "understanding, and solving problems involving spatial relationships, " 129 | "geometric transformations, and visual-spatial concepts." 
130 | ) 131 | }] 132 | } 133 | conversation.append(system_message) 134 | 135 | # Add one-shot example if provided 136 | if self.one_shot_example: 137 | # Add the example question with its image 138 | conversation.append({ 139 | "role": "user", 140 | "content": [ 141 | {"type": "image", "image": self.one_shot_example["image"]}, 142 | {"type": "text", "text": self._format_question_with_cot(self.one_shot_example["question"])}, 143 | ], 144 | }) 145 | 146 | # Add the example response (with CoT reasoning if available) 147 | example_response = self.one_shot_example.get("reasoning", "") 148 | if example_response and self.use_cot: 149 | example_response += f"\n\nTherefore, the answer is: {self.one_shot_example['answer']}" 150 | else: 151 | example_response = self.one_shot_example["answer"] 152 | 153 | conversation.append({ 154 | "role": "assistant", 155 | "content": example_response 156 | }) 157 | 158 | # Add the actual question 159 | conversation.append({ 160 | "role": "user", 161 | "content": [ 162 | {"type": "image", "image": image}, 163 | {"type": "text", "text": self._format_question_with_cot(question)}, 164 | ], 165 | }) 166 | 167 | messages.append(conversation) 168 | 169 | return messages 170 | 171 | def _format_question_with_cot(self, question: str) -> str: 172 | """Format question with Chain-of-Thought prompting if enabled.""" 173 | if not self.use_cot: 174 | return question.strip() 175 | 176 | cot_prompt = ( 177 | "Please think step by step and explain your reasoning before answering the question.\n" 178 | "Provide the final answer at the end of your reasoning in curly brackets, e.g., {A} or {yes}.\n\n" 179 | f"Question: {question.strip()}\n\n" 180 | ) 181 | 182 | return cot_prompt 183 | 184 | def _process_batch(self, batch: Dict[str, List], batch_idx: int, total_batches: int) -> List[Dict[str, str]]: 185 | """ 186 | Process a single batch of examples with unified interface.
187 | 188 | Args: 189 | batch: Dictionary containing batch data 190 | batch_idx: Current batch index 191 | total_batches: Total number of batches 192 | 193 | Returns: 194 | List of results for this batch 195 | """ 196 | try: 197 | batch_size = len(batch["question"]) 198 | logger.debug(f"Processing batch {batch_idx + 1}/{total_batches} with {batch_size} examples") 199 | 200 | # Prepare inputs 201 | messages = self._prepare_messages(batch["question"], batch["image"]) 202 | 203 | # Unified batch processing leveraging unified preprocessing in VLMWrapper 204 | results = self._process_batch_standard(messages, batch, batch_idx) 205 | 206 | return results 207 | 208 | except Exception as e: 209 | logger.error(f"Error processing batch {batch_idx + 1}: {e}", exc_info=True) 210 | # Return error placeholder results for the failed batch (keyed like successful rows) 211 | return [{ 212 | "response": f"ERROR: {str(e)}", 213 | "gold answer": answer, 214 | "question": question, 215 | "raw_response": f"ERROR: {str(e)}", 216 | "batch_idx": batch_idx, 217 | "example_idx": idx, 218 | } for idx, (question, answer) in enumerate(zip(batch["question"], batch["answer"]))] 219 | 220 | def _process_batch_standard(self, messages: List[List[Dict]], batch: Dict[str, List], batch_idx: int) -> List[Dict[str, str]]: 221 | """Process batch for standard VLM models.""" 222 | results = [] 223 | 224 | with self.vlm.memory_efficient_mode(): 225 | # Preprocess the entire batch 226 | inputs = self.vlm.preprocess(conversation=messages, image_input=batch["image"]) 227 | 228 | 229 | generated_ids = self.vlm.generate( 230 | inputs, 231 | max_new_tokens=1024, 232 | do_sample=False 233 | ) 234 | 235 | output_texts = self.vlm.decode(generated_ids) 236 | 237 | del inputs, generated_ids # Free memory 238 | 239 | # Process results for each example in the batch 240 | for idx, (raw_prediction, question, ground_truth) in enumerate(zip(output_texts, batch["question"], batch["answer"])): 241 | results.append({ 242 | "question": question, 243 | "response": raw_prediction, 244 | "gold answer": ground_truth, 245 | "batch_idx": batch_idx, 246 | "example_idx": idx, 247 | }) 248 | 249 | return results 250 | 251 | def evaluate(self, dataset_id: str, batch_size: int = 16, max_samples: Optional[int] = None, 252 | sample_strategy: str = "first", num_workers: int = 4) -> str: 253 | """ 254 | Evaluate the model on the specified dataset with reproducible sampling.
255 | 256 | Args: 257 | dataset_id: HuggingFace dataset identifier 258 | batch_size: Number of examples per batch 259 | max_samples: Maximum number of samples to process 260 | sample_strategy: Sampling strategy ("first", "random", "stratified") 261 | num_workers: Number of parallel data loading workers 262 | 263 | Returns: 264 | Path to the saved results CSV file 265 | """ 266 | # Load model lazily 267 | self._load_model() 268 | 269 | # Load and sample dataset 270 | try: 271 | data = load_dataset(dataset_id, split="train") 272 | original_size = len(data) 273 | 274 | # Apply sampling strategy 275 | if max_samples and max_samples < original_size: 276 | data = self._sample_dataset(data, max_samples, sample_strategy) 277 | 278 | logger.info(f"Loaded dataset: {original_size} total, {len(data)} selected samples") 279 | 280 | # Add dataset hash to metadata 281 | self.eval_metadata.update({ 282 | "dataset_id": dataset_id, 283 | "dataset_hash": self._compute_dataset_hash(dataset_id, max_samples), 284 | "original_size": original_size, 285 | "sampled_size": len(data), 286 | "sample_strategy": sample_strategy, 287 | "max_samples": max_samples, 288 | }) 289 | 290 | except Exception as e: 291 | logger.error(f"Failed to load dataset {dataset_id}: {e}") 292 | raise 293 | 294 | # Validate dataset format 295 | required_columns = ["question", "answer", "image"] 296 | missing_columns = [col for col in required_columns if col not in data.column_names] 297 | if missing_columns: 298 | raise ValueError(f"Dataset missing required columns: {missing_columns}") 299 | 300 | # Convert to iterable format for efficient batching 301 | # Use HuggingFace's built-in batching which is more memory efficient 302 | data = data.with_format("python") # Ensure consistent format 303 | 304 | # Process dataset in batches using DataLoader for prefetching 305 | all_results = [] 306 | total_batches = (len(data) + batch_size - 1) // batch_size 307 | 308 | # Create DataLoader with prefetching for parallel data loading 309 | # Note: prefetch_factor only valid when num_workers > 0 310 | dataloader_kwargs = { 311 | "batch_size": batch_size, 312 | "shuffle": False, 313 | "num_workers": num_workers, 314 | "pin_memory": torch.cuda.is_available(), 315 | "collate_fn": self._collate_batch, 316 | } 317 | if num_workers > 0: 318 | dataloader_kwargs["prefetch_factor"] = 2 319 | 320 | dataloader = DataLoader(data, **dataloader_kwargs) 321 | with tqdm(total=total_batches, desc="Processing batches", colour="green") as pbar: 322 | for batch_idx, batch in enumerate(dataloader): 323 | batch_results = self._process_batch(batch, batch_idx, total_batches) 324 | all_results.extend(batch_results) 325 | 326 | pbar.update(1) 327 | pbar.set_postfix({ 328 | "processed": len(all_results), 329 | "memory": f"{torch.cuda.memory_allocated() / 1e9:.1f}GB" if torch.cuda.is_available() else "N/A", 330 | "success_rate": f"{sum(1 for r in all_results if not r['response'].startswith('ERROR')) / len(all_results) * 100:.1f}%" 331 | }) 332 | 333 | gc.collect() 334 | if torch.cuda.is_available(): 335 | torch.cuda.empty_cache() 336 | 337 | # Save results with metadata 338 | output_path = self._save_results_with_metadata(all_results, dataset_id) 339 | logger.info(f"Evaluation completed. Results saved to: {output_path}") 340 | 341 | return output_path 342 | 343 | def _collate_batch(self, batch: List[Dict]) -> Dict[str, List]: 344 | """ 345 | Custom collate function to convert list of dicts to dict of lists. 346 | Handles PIL images properly without tensor conversion. 
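
        Example (illustrative): a batch of two items
            [{"question": "Q1", "answer": "A", "image": img1},
             {"question": "Q2", "answer": "B", "image": img2}]
        collates to
            {"question": ["Q1", "Q2"], "answer": ["A", "B"], "image": [img1, img2]}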
347 | """ 348 | collated = { 349 | "question": [item["question"] for item in batch], 350 | "answer": [item["answer"] for item in batch], 351 | "image": [item["image"] for item in batch], 352 | } 353 | return collated 354 | 355 | def _sample_dataset(self, dataset: Dataset, max_samples: int, strategy: str) -> Dataset: 356 | """Sample dataset using specified strategy for reproducibility.""" 357 | if strategy == "first": 358 | return dataset.select(range(max_samples)) 359 | elif strategy == "random": 360 | indices = list(range(len(dataset))) 361 | random.shuffle(indices) # Uses the set seed 362 | selected_indices = indices[:max_samples] 363 | return dataset.select(selected_indices) 364 | elif strategy == "stratified": 365 | # Try to stratify by answer if possible, otherwise fall back to random 366 | try: 367 | answers = dataset["answer"] 368 | unique_answers = list(set(answers)) 369 | samples_per_answer = max_samples // len(unique_answers) 370 | 371 | selected_indices = [] 372 | for answer in unique_answers: 373 | answer_indices = [i for i, a in enumerate(answers) if a == answer] 374 | random.shuffle(answer_indices) 375 | selected_indices.extend(answer_indices[:samples_per_answer]) 376 | 377 | # Fill remaining slots randomly 378 | remaining = max_samples - len(selected_indices) 379 | if remaining > 0: 380 | all_indices = set(range(len(dataset))) 381 | available_indices = list(all_indices - set(selected_indices)) 382 | random.shuffle(available_indices) 383 | selected_indices.extend(available_indices[:remaining]) 384 | 385 | return dataset.select(selected_indices[:max_samples]) 386 | except Exception as e: 387 | logger.warning(f"Stratified sampling failed, falling back to random: {e}") 388 | return self._sample_dataset(dataset, max_samples, "random") 389 | else: 390 | raise ValueError(f"Unknown sampling strategy: {strategy}") 391 | 392 | def _save_results_with_metadata(self, results: List[Dict[str, str]], dataset_id: str) -> str: 393 | """Save results with comprehensive metadata for reproducibility.""" 394 | dataset_name = dataset_id.split("/")[-1] 395 | 396 | # Create directory structure with prompting strategy 397 | strategy_dir = self._get_strategy_directory_name() 398 | output_dir = Path("output/evaluations") / dataset_name / strategy_dir 399 | output_dir.mkdir(parents=True, exist_ok=True) 400 | 401 | # Save main results 402 | results_df = pd.DataFrame(results) 403 | output_path = output_dir / f"{self.short_name}.csv" 404 | 405 | try: 406 | results_df.to_csv(output_path, index=False) 407 | logger.info(f"Successfully saved {len(results)} results to {output_path}") 408 | except Exception as e: 409 | logger.error(f"Failed to save results: {e}") 410 | raise 411 | 412 | # Save detailed metadata for reproducibility 413 | metadata_path = output_dir / f"{self.short_name}_metadata.json" 414 | 415 | # Add evaluation statistics to metadata 416 | successful_responses = sum(1 for r in results if not r['response'].startswith('ERROR')) 417 | self.eval_metadata.update({ 418 | "total_samples": len(results), 419 | "successful_responses": successful_responses, 420 | "success_rate": successful_responses / len(results) * 100, 421 | "average_response_length": sum(len(r['response']) for r in results) / len(results), 422 | "output_path": str(output_path), 423 | "evaluation_completed_at": datetime.now().isoformat(), 424 | }) 425 | 426 | with open(metadata_path, 'w') as f: 427 | json.dump(self.eval_metadata, f, indent=2, default=str) 428 | 429 | logger.info(f"Saved metadata to {metadata_path}") 430 | 431 | return 
str(output_path) 432 | 433 | def _get_strategy_directory_name(self) -> str: 434 | """Generate directory name based on prompting strategy.""" 435 | if self.use_cot and self.one_shot_example: 436 | return "cot_oneshot" 437 | elif self.use_cot: 438 | return "cot_test" 439 | elif self.one_shot_example: 440 | return "oneshot" 441 | else: 442 | return "_normal" 443 | 444 | 445 | def load_one_shot_example(json_path: str) -> Optional[Dict]: 446 | """Load one-shot example from JSON file with image loading.""" 447 | if not json_path: 448 | return None 449 | 450 | # Handle relative paths for JSON file 451 | if not os.path.isabs(json_path): 452 | json_path = os.path.join('/users/stogian/srbench', json_path) 453 | 454 | if not os.path.exists(json_path): 455 | logger.warning(f"One-shot JSON file not found: {json_path}") 456 | return None 457 | 458 | try: 459 | with open(json_path, 'r') as f: 460 | data = json.load(f) 461 | 462 | # Handle image path with multiple fallbacks 463 | image_path = data.get('image_path', '') 464 | if not image_path: 465 | logger.warning("No image_path found in one-shot example") 466 | return None 467 | 468 | # Try multiple path variations 469 | possible_paths = [ 470 | image_path, 471 | os.path.join(os.path.dirname(json_path), os.path.basename(image_path)), 472 | os.path.join('/users/stogian/srbench', image_path.lstrip('./')), 473 | os.path.join('/users/stogian/srbench/example', os.path.basename(image_path)), 474 | ] 475 | 476 | image_loaded = False 477 | for path in possible_paths: 478 | if os.path.exists(path): 479 | try: 480 | data['image'] = Image.open(path).convert('RGB') 481 | logger.info(f"Successfully loaded one-shot image from: {path}") 482 | image_loaded = True 483 | break 484 | except Exception as e: 485 | logger.warning(f"Failed to load image from {path}: {e}") 486 | continue 487 | 488 | if not image_loaded: 489 | logger.error(f"Could not load image from any of these paths: {possible_paths}") 490 | return None 491 | 492 | # Validate required keys 493 | required_keys = ['question', 'answer'] 494 | if not all(key in data for key in required_keys): 495 | logger.warning(f"One-shot example missing required keys: {required_keys}") 496 | return None 497 | 498 | logger.info(f"Loaded one-shot example: {data['question'][:50]}...") 499 | return data 500 | 501 | except Exception as e: 502 | logger.error(f"Failed to load one-shot example: {e}") 503 | return None 504 | 505 | 506 | def parse_args(): 507 | """Parse command-line arguments with comprehensive options.""" 508 | parser = argparse.ArgumentParser( 509 | description="Reproducible evaluation of vision-language models", 510 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 511 | ) 512 | parser.add_argument( 513 | "-m", "--model", 514 | type=str, required=True, 515 | help="Model identifier (e.g., 'meta-llama/Llama-3.2-90B-Vision-Instruct')" 516 | ) 517 | parser.add_argument( 518 | "-d", "--dataset", 519 | type=str, required=True, 520 | help="Dataset identifier (e.g., 'stogian/sr_test')" 521 | ) 522 | parser.add_argument( 523 | "-b", "--batch_size", 524 | type=int, default=16, 525 | help="Batch size for processing" 526 | ) 527 | parser.add_argument( 528 | "--num_workers", 529 | type=int, default=4, 530 | help="Number of data loading workers (0 for main process only)" 531 | ) 532 | parser.add_argument( 533 | "--max_samples", 534 | type=int, default=None, 535 | help="Maximum samples to process (for testing)" 536 | ) 537 | parser.add_argument( 538 | "--sample_strategy", 539 | type=str, default="first", choices=["first", 
"random", "stratified"], 540 | help="Sampling strategy for subset selection" 541 | ) 542 | parser.add_argument( 543 | "--device_map", 544 | type=str, default="auto", 545 | help="Device mapping strategy" 546 | ) 547 | parser.add_argument( 548 | "--seed", 549 | type=int, default=42, 550 | help="Random seed for reproducibility" 551 | ) 552 | parser.add_argument( 553 | "--cot", "--chain-of-thought", 554 | action="store_true", 555 | help="Enable Chain-of-Thought prompting" 556 | ) 557 | parser.add_argument( 558 | "--one-shot", 559 | type=str, default=None, 560 | help="Path to one-shot example JSON file" 561 | ) 562 | parser.add_argument( 563 | "-v", "--verbose", 564 | action="store_true", 565 | help="Enable verbose logging" 566 | ) 567 | 568 | return parser.parse_args() 569 | 570 | 571 | def main(): 572 | """Main function with comprehensive error handling.""" 573 | 574 | # print gpu info 575 | if torch.cuda.is_available(): 576 | gpu_count = torch.cuda.device_count() 577 | print(f"Found {gpu_count} GPU(s):") 578 | for i in range(gpu_count): 579 | print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_memory / 1e9:.1f}GB)") 580 | print(f"CUDA version: {torch.version.cuda}") 581 | else: 582 | print("No GPU available, using CPU") 583 | 584 | args = parse_args() 585 | 586 | if args.verbose: 587 | logging.getLogger().setLevel(logging.DEBUG) 588 | 589 | # Load one-shot example if provided 590 | one_shot_example = load_one_shot_example(args.one_shot) if args.one_shot else None 591 | 592 | try: 593 | # Create reproducible evaluation engine 594 | eval_engine = EvaluationEngine( 595 | model_id=args.model, 596 | device_map=args.device_map, 597 | seed=args.seed, 598 | use_cot=args.cot, 599 | one_shot_example=one_shot_example 600 | ) 601 | 602 | output_path = eval_engine.evaluate( 603 | dataset_id=args.dataset, 604 | batch_size=args.batch_size, 605 | max_samples=args.max_samples, 606 | sample_strategy=args.sample_strategy, 607 | num_workers=args.num_workers 608 | ) 609 | 610 | print("\n✅ Reproducible evaluation completed successfully!") 611 | print(f"📁 Results saved to: {output_path}") 612 | print(f"🔄 Evaluation can be reproduced using seed: {args.seed}") 613 | 614 | except KeyboardInterrupt: 615 | logger.info("Evaluation interrupted by user") 616 | print("\n⚠️ Evaluation interrupted by user") 617 | except Exception as e: 618 | logger.error(f"Evaluation failed: {e}", exc_info=True) 619 | print(f"\n❌ Evaluation failed: {e}") 620 | raise 621 | 622 | 623 | if __name__ == "__main__": 624 | main() -------------------------------------------------------------------------------- /src/utils/vlm_wrapper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import requests 4 | import logging 5 | import gc 6 | import numpy as np 7 | from typing import List, Dict, Any, Optional, Union, Tuple 8 | from dataclasses import dataclass 9 | from functools import lru_cache 10 | from contextlib import contextmanager 11 | from PIL import Image 12 | import torchvision.transforms as T 13 | from torchvision.transforms.functional import InterpolationMode 14 | from transformers import ( 15 | Qwen2VLForConditionalGeneration, 16 | Qwen2_5_VLForConditionalGeneration, 17 | AutoProcessor, 18 | LlavaForConditionalGeneration, 19 | LlavaNextForConditionalGeneration, 20 | AutoModelForCausalLM, 21 | GenerationConfig, 22 | AutoModelForVision2Seq, 23 | AutoModelForImageTextToText, 24 | MllamaForConditionalGeneration, 25 | AutoModel, 26 | 
AutoTokenizer, 27 | Gemma3ForConditionalGeneration, 28 | Glm4vForConditionalGeneration 29 | ) 30 | from qwen_vl_utils import process_vision_info 31 | 32 | # Configure logging 33 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 34 | logger = logging.getLogger(__name__) 35 | 36 | # Constants for InternVL 37 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 38 | IMAGENET_STD = (0.229, 0.224, 0.225) 39 | 40 | @dataclass 41 | class ModelConfig: 42 | """Configuration for different VLM models.""" 43 | model_class: type 44 | processor_class: type 45 | requires_trust_remote_code: bool = False 46 | supports_flash_attention: bool = False 47 | padding_side: str = "left" 48 | special_args: Dict[str, Any] = None 49 | processor_args: Dict[str, Any] = None 50 | inference_type: str = "standard" # standard, internvl, minicpm 51 | 52 | def __post_init__(self): 53 | if self.special_args is None: 54 | self.special_args = {} 55 | if self.processor_args is None: 56 | self.processor_args = {} 57 | 58 | # Model configurations registry 59 | MODEL_CONFIGS = { 60 | "qwen": ModelConfig( 61 | model_class=Qwen2_5_VLForConditionalGeneration, 62 | processor_class=AutoProcessor, 63 | supports_flash_attention=True, 64 | processor_args={"use_fast": True}, 65 | ), 66 | "kimi": ModelConfig( 67 | model_class=AutoModelForCausalLM, 68 | processor_class=AutoProcessor, 69 | requires_trust_remote_code=True, 70 | supports_flash_attention=True, 71 | processor_args={"use_fast": True, "padding_side": "left"}, 72 | ), 73 | "llava": ModelConfig( 74 | model_class=LlavaForConditionalGeneration, 75 | processor_class=AutoProcessor, 76 | special_args={"low_cpu_mem_usage": True}, 77 | processor_args={"use_fast": True} 78 | ), 79 | "llava_next": ModelConfig( 80 | model_class=LlavaNextForConditionalGeneration, 81 | processor_class=AutoProcessor, 82 | special_args={"low_cpu_mem_usage": True}, 83 | processor_args={"use_fast": True} 84 | ), 85 | "idefics": ModelConfig( 86 | model_class=AutoModelForImageTextToText, 87 | processor_class=AutoProcessor, 88 | supports_flash_attention=True, 89 | processor_args={"use_fast": True} 90 | ), 91 | "smolvlm": ModelConfig( 92 | model_class=AutoModelForImageTextToText, 93 | processor_class=AutoProcessor, 94 | supports_flash_attention=True, 95 | special_args={}, 96 | processor_args={"use_fast": True, "padding_side": "left"} 97 | ), 98 | "mllama": ModelConfig( 99 | model_class=MllamaForConditionalGeneration, 100 | processor_class=AutoProcessor, 101 | processor_args={"use_fast": True, "padding_side": "left"} 102 | ), 103 | "minicpm": ModelConfig( 104 | model_class=AutoModel, 105 | processor_class=AutoTokenizer, 106 | requires_trust_remote_code=True, 107 | supports_flash_attention=True, 108 | inference_type="minicpm" 109 | ), 110 | "internvl": ModelConfig( 111 | model_class=AutoModel, 112 | processor_class=AutoTokenizer, 113 | requires_trust_remote_code=True, 114 | supports_flash_attention=True, 115 | inference_type="internvl" 116 | ), 117 | "internvl_hf": ModelConfig( 118 | model_class=AutoModelForImageTextToText, 119 | processor_class=AutoProcessor, 120 | requires_trust_remote_code=True, 121 | supports_flash_attention=True, 122 | processor_args={"use_fast": True}, 123 | inference_type="standard" # Use standard HF generation 124 | ), 125 | "gemma3": ModelConfig( 126 | model_class=Gemma3ForConditionalGeneration, 127 | processor_class=AutoProcessor, 128 | supports_flash_attention=True, 129 | processor_args={"use_fast": True}, 130 | ), 131 | "glm4v": ModelConfig( 132 | 
model_class=Glm4vForConditionalGeneration, 133 | processor_class=AutoProcessor, 134 | requires_trust_remote_code=True, 135 | supports_flash_attention=True, 136 | padding_side="left", 137 | processor_args={"use_fast": True} 138 | ), 139 | } 140 | 141 | class VLMWrapper: 142 | """Unified Vision-Language Model wrapper supporting all model types.""" 143 | 144 | def __init__(self, model_id: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16): 145 | """ 146 | Initialize unified VLM wrapper with automatic model type detection. 147 | 148 | Args: 149 | model_id: HuggingFace model identifier 150 | device_map: Device mapping strategy 151 | dtype: Model precision 152 | """ 153 | self.model_id = model_id 154 | self.model_type = self._detect_model_type(model_id) 155 | self.config = MODEL_CONFIGS[self.model_type] 156 | self.dtype = dtype 157 | self.device_map = self._optimize_device_map(device_map) 158 | 159 | # Lazy initialization 160 | self._model = None 161 | self._processor = None 162 | self._device = None 163 | self._transform = None # For InternVL 164 | 165 | logger.info(f"Initialized VLMWrapper: for {self.model_type} model: {model_id}") 166 | 167 | @property 168 | def model(self): 169 | """Lazy loading of model.""" 170 | if self._model is None: 171 | self._model = self._load_model() 172 | return self._model 173 | 174 | @property 175 | def processor(self): 176 | """Lazy loading of processor.""" 177 | if self._processor is None: 178 | self._processor = self._load_processor() 179 | return self._processor 180 | 181 | @property 182 | def device(self): 183 | """Get model device.""" 184 | if self._device is None: 185 | self._device = self.model.device 186 | return self._device 187 | 188 | @property 189 | def transform(self): 190 | """Lazy load image transform for InternVL.""" 191 | if self._transform is None and self.model_type == "internvl": 192 | self._transform = self._build_transform(input_size=448) 193 | return self._transform 194 | 195 | def _optimize_device_map(self, device_map: str) -> str: 196 | """Optimize device mapping based on available hardware.""" 197 | if device_map == "auto": 198 | if torch.cuda.is_available(): 199 | gpu_count = torch.cuda.device_count() 200 | if gpu_count > 1: 201 | logger.info(f"Multiple GPUs detected ({gpu_count}), using auto device mapping") 202 | return "auto" 203 | else: 204 | return "cuda:0" 205 | else: 206 | logger.warning("CUDA not available, using CPU") 207 | return "cpu" 208 | return device_map 209 | 210 | def _detect_model_type(self, model_id: str) -> str: 211 | """Detect model type from model_id.""" 212 | # Check for specific patterns first (more specific before generic) 213 | model_patterns = { 214 | "internvl_hf": r"OpenGVLab/InternVL.*-HF", # HF-native InternVL models (InternVL3-HF, etc.) 
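            # Patterns are checked in insertion order and the first re.search() hit wins,
            # so this HF-suffixed InternVL entry must stay above any generic "OpenGVLab/InternVL" pattern.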
215 | "qwen": r"Qwen/", 216 | "llava": r"llava-hf/llava-1\.5", 217 | "llava_next": r"llava-hf/llava-v1\.6", 218 | "idefics": r"HuggingFaceM4/Idefics", 219 | "smolvlm": r"HuggingFaceTB/SmolVLM", 220 | "mllama": r"meta-llama", 221 | "minicpm": r"openbmb/MiniCPM", 222 | # "internvl": r"OpenGVLab/InternVL", # Original InternVL with batch_chat 223 | "gemma3": r"google/gemma-3", 224 | "kimi": r"moonshotai/Kimi-VL", 225 | "glm4v": r"zai-org/GLM-4", 226 | } 227 | 228 | for model_type, pattern in model_patterns.items(): 229 | if re.search(pattern, model_id): 230 | return model_type 231 | 232 | raise ValueError(f"Unsupported model_id: {model_id}") 233 | 234 | def _load_model(self): 235 | """Load model with optimized settings.""" 236 | try: 237 | model_args = { 238 | "torch_dtype": self.dtype, 239 | "device_map": self.device_map, 240 | } 241 | 242 | if self.config.requires_trust_remote_code: 243 | model_args["trust_remote_code"] = True 244 | 245 | if self.config.supports_flash_attention and torch.cuda.is_available(): 246 | model_args["attn_implementation"] = "flash_attention_2" 247 | 248 | # Add special arguments 249 | model_args.update(self.config.special_args) 250 | 251 | model = self.config.model_class.from_pretrained(self.model_id, **model_args) 252 | # model = torch.compile(model, mode="default") # Compile model for performance 253 | return model.eval() 254 | 255 | except Exception as e: 256 | logger.error(f"Failed to load model {self.model_id}: {e}") 257 | raise 258 | 259 | def _load_processor(self): 260 | """Load processor with optimized settings.""" 261 | try: 262 | processor_args = {} 263 | 264 | if self.config.requires_trust_remote_code: 265 | processor_args["trust_remote_code"] = True 266 | 267 | # Add special arguments 268 | processor_args.update(self.config.processor_args) 269 | 270 | processor = self.config.processor_class.from_pretrained(self.model_id, **processor_args) 271 | 272 | # Configure padding 273 | if hasattr(processor, "tokenizer"): 274 | # Ensure padding_side 275 | processor.tokenizer.padding_side = self.config.padding_side 276 | # Ensure pad_token_id exists; fallback to eos if missing 277 | if getattr(processor.tokenizer, "pad_token_id", None) is None: 278 | eos_id = getattr(processor.tokenizer, "eos_token_id", None) 279 | if eos_id is not None: 280 | processor.tokenizer.pad_token_id = eos_id 281 | 282 | # Set pad token on model generation config if available 283 | if hasattr(processor, "tokenizer") and hasattr(self.model, "generation_config"): 284 | if getattr(self.model.generation_config, "pad_token_id", None) is None: 285 | self.model.generation_config.pad_token_id = processor.tokenizer.pad_token_id 286 | else: 287 | self.model.generation_config.pad_token_id = processor.tokenizer.pad_token_id 288 | 289 | return processor 290 | 291 | except Exception as e: 292 | logger.error(f"Failed to load processor for {self.model_id}: {e}") 293 | raise 294 | 295 | def _build_transform(self, input_size: int = 448) -> T.Compose: 296 | """Create optimized image transform pipeline for InternVL.""" 297 | return T.Compose([ 298 | T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), 299 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 300 | T.ToTensor(), 301 | T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), 302 | ]) 303 | 304 | @lru_cache(maxsize=128) 305 | def load_image_from_url(self, image_url: str) -> Image.Image: 306 | """Load image from URL with caching.""" 307 | try: 308 | response = requests.get(image_url, stream=True, timeout=10) 
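            # Fail fast on HTTP errors; only a successful response's raw stream is handed to PIL below.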
309 | response.raise_for_status() 310 | return Image.open(response.raw) 311 | except Exception as e: 312 | logger.error(f"Failed to load image from {image_url}: {e}") 313 | raise 314 | 315 | def preprocess(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Any: 316 | # Unified preprocessing for all inference types 317 | return self._preprocess_standard(conversation, image_input) 318 | 319 | 320 | def _preprocess_standard(self, batch_conversations: List, batch_images: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Any: 321 | """Unified preprocessing for all VLM models, including internvl and minicpm.""" 322 | if self.config.inference_type == "internvl": 323 | return self._preprocess_internvl(batch_conversations, batch_images) 324 | if self.config.inference_type == "minicpm": 325 | return self._preprocess_minicpm(batch_conversations, batch_images) 326 | # batch_conversations, batch_images = self._normalize_inputs(conversation, image_input) 327 | try: 328 | return self._preprocess(batch_conversations, batch_images) 329 | 330 | except Exception as e: 331 | logger.error(f"Preprocessing failed for {self.model_type}: {e}", exc_info=True) 332 | raise 333 | 334 | def _preprocess(self, batch_conversations: List[List], batch_images: Optional[Union[Image.Image, List[Image.Image]]]) -> Dict[str, torch.Tensor]: 335 | """Preprocess input data for the model. 336 | Args: 337 | batch_conversations: List of conversations, each a list of turns. 338 | batch_images: List of images corresponding to each conversation. 339 | Returns: 340 | Dictionary of preprocessed inputs ready for model inference. 341 | """ 342 | 343 | # Apply chat template to each conversation 344 | prompts = [ 345 | self.processor.apply_chat_template( 346 | conv, 347 | add_generation_prompt=True, 348 | tokenize=False 349 | ) 350 | for conv in batch_conversations 351 | ] 352 | 353 | if self.model_type in ["mllama", "smolvlm", "idefics", "gemma3", "glm4v", "internvl_hf"]: 354 | images_to_process = [[img] for img in batch_images] 355 | else: 356 | images_to_process = batch_images 357 | 358 | assert len(prompts) == len(images_to_process), "Number of prompts must match number of image inputs" 359 | 360 | # Preprocess inputs 361 | inputs = self.processor( 362 | text=prompts, 363 | images=images_to_process, 364 | return_tensors="pt", 365 | padding=True, 366 | ).to(self.device, dtype=self.dtype) 367 | 368 | 369 | return inputs 370 | 371 | 372 | def _preprocess_internvl(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Dict[str, Any]: 373 | """ 374 | Robust preprocessing for InternVL models. 
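        - Returns the questions, per-sample pixel tensors, and patch counts consumed by _generate_internvl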
375 | - Supports batch or single conversations 376 | - Extracts the last user turn to avoid mixing one-shot examples 377 | - Builds a per-sample list of pixel tensors (later concatenated for batch_chat) 378 | - Ensures input dtype matches model parameters 379 | """ 380 | # Normalize to list of conversations 381 | conversation_list = conversation if isinstance(conversation[0], list) else [conversation] 382 | 383 | questions: List[str] = [] 384 | pixel_values_list: List[torch.Tensor] = [] 385 | num_patches_list: List[int] = [] 386 | 387 | # Determine model param dtype to avoid dtype mismatch (e.g., bf16 vs fp16) 388 | try: 389 | model_param_dtype = next(self.model.parameters()).dtype 390 | except Exception: 391 | model_param_dtype = self.dtype 392 | 393 | for conv in conversation_list: 394 | # Find last user turn 395 | user_turns = [turn for turn in conv if isinstance(turn, dict) and turn.get("role") == "user"] 396 | if not user_turns: 397 | raise ValueError("No user turn found in conversation for InternVL") 398 | last_user = user_turns[-1] 399 | 400 | # Extract image and text 401 | image = None 402 | text_parts: List[str] = [] 403 | for item in last_user.get('content', []): 404 | if isinstance(item, dict) and 'image' in item and image is None: 405 | image = item['image'] 406 | elif isinstance(item, dict) and 'text' in item: 407 | text_parts.append(item['text']) 408 | 409 | if image is None and image_input is not None: 410 | # Fallback if image not embedded 411 | image = image_input[0] if isinstance(image_input, list) and image_input else image_input 412 | if image is None: 413 | raise ValueError("No image found in conversation content for InternVL") 414 | 415 | question = " ".join(text_parts).strip() 416 | questions.append(question) 417 | 418 | # Load and preprocess image into patches (N, 3, 448, 448) 419 | patches = self._load_image_internvl(image, max_num=12) 420 | # Match dtype/device to model parameters to avoid conv2d dtype mismatch 421 | patches = patches.to(device=self.device, dtype=model_param_dtype) 422 | pixel_values_list.append(patches) 423 | num_patches_list.append(patches.size(0)) 424 | 425 | return { 426 | "pixel_values": pixel_values_list, # keep per-sample; we will concat at generation 427 | "questions": questions, 428 | "num_patches_list": num_patches_list, 429 | } 430 | 431 | def _preprocess_minicpm(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Tuple: 432 | """Preprocessing for MiniCPM models. 
Extract the last user turn, then find image/text by keys.""" 433 | # Normalize to list of conversations 434 | conv_list = conversation if isinstance(conversation[0], list) else [conversation] 435 | msgs_batch: List[List[Dict[str, Any]]] = [] 436 | for conv in conv_list: 437 | # Find last user turn (to avoid one-shot example turns) 438 | user_turns = [turn for turn in conv if isinstance(turn, dict) and turn.get("role") == "user"] 439 | if not user_turns: 440 | raise ValueError("No user turn found in conversation") 441 | last_user = user_turns[-1] 442 | # Extract image and text from the content list, regardless of order 443 | img = None 444 | text = "" 445 | for item in last_user.get("content", []): 446 | if isinstance(item, dict) and "image" in item and img is None: 447 | img = item["image"] 448 | elif isinstance(item, dict) and "text" in item and not text: 449 | text = item["text"] 450 | # Fallback to image_input if not embedded in conversation 451 | if img is None and image_input is not None: 452 | if isinstance(image_input, list) and len(image_input) > 0: 453 | img = image_input[0] 454 | else: 455 | img = image_input 456 | if img is None: 457 | # Mirror the KeyError seen in logs for clarity 458 | raise KeyError("image") 459 | # Prepare image and build msgs 460 | np_img = self._prepare_image_minicpm(img) 461 | msgs = [{"role": "user", "content": [np_img, text]}] 462 | msgs_batch.append(msgs) 463 | # Return batch or single 464 | return msgs_batch if len(msgs_batch) > 1 else msgs_batch[0] 465 | 466 | def _load_image_internvl(self, image_input: Any, input_size: int = 448, max_num: int = 12) -> torch.Tensor: 467 | """Load and preprocess image for InternVL with dynamic preprocessing.""" 468 | try: 469 | # Handle different image input types 470 | if isinstance(image_input, str): 471 | image = Image.open(image_input).convert("RGB") 472 | elif isinstance(image_input, dict) and "path" in image_input: 473 | image = Image.open(image_input["path"]).convert("RGB") 474 | elif isinstance(image_input, Image.Image): 475 | image = image_input.convert("RGB") 476 | else: 477 | raise ValueError(f"Unsupported image input type: {type(image_input)}") 478 | 479 | # Dynamic preprocessing 480 | images = self._dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 481 | 482 | # Apply transforms 483 | pixel_values = [self.transform(img) for img in images] 484 | return torch.stack(pixel_values) 485 | 486 | except Exception as e: 487 | logger.error(f"Failed to load image for InternVL: {e}") 488 | raise 489 | 490 | def _prepare_image_minicpm(self, image_input: Union[str, Image.Image, Dict]) -> np.ndarray: 491 | """Prepare image for MiniCPM processing.""" 492 | try: 493 | # Handle different image input types 494 | if isinstance(image_input, str): 495 | image = Image.open(image_input).convert("RGB") 496 | elif isinstance(image_input, dict) and "path" in image_input: 497 | image = Image.open(image_input["path"]).convert("RGB") 498 | elif isinstance(image_input, Image.Image): 499 | image = image_input.convert("RGB") 500 | else: 501 | image = image_input 502 | 503 | # Convert to numpy array with channel-first format (C x H x W) 504 | np_img = np.array(image) 505 | if len(np_img.shape) == 3: 506 | np_img = np_img.transpose(2, 0, 1) 507 | 508 | return np_img 509 | 510 | except Exception as e: 511 | logger.error(f"Failed to prepare image for MiniCPM: {e}") 512 | raise 513 | 514 | def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12, 515 | image_size: int = 448, 
use_thumbnail: bool = False) -> List[Image.Image]: 516 | """Dynamic preprocessing for InternVL.""" 517 | orig_width, orig_height = image.size 518 | aspect_ratio = orig_width / orig_height 519 | 520 | # Generate target ratios 521 | target_ratios = [ 522 | (i, j) for n in range(min_num, max_num + 1) 523 | for i in range(1, n + 1) for j in range(1, n + 1) 524 | if min_num <= i * j <= max_num 525 | ] 526 | target_ratios.sort(key=lambda x: x[0] * x[1]) 527 | 528 | target_aspect_ratio = self._find_closest_aspect_ratio( 529 | aspect_ratio, target_ratios, orig_width, orig_height, image_size 530 | ) 531 | 532 | target_width = image_size * target_aspect_ratio[0] 533 | target_height = image_size * target_aspect_ratio[1] 534 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 535 | 536 | # Resize image 537 | resized_img = image.resize((target_width, target_height)) 538 | 539 | # Extract blocks 540 | processed_images = [] 541 | cols = target_width // image_size 542 | 543 | for i in range(blocks): 544 | col = i % cols 545 | row = i // cols 546 | box = ( 547 | col * image_size, 548 | row * image_size, 549 | (col + 1) * image_size, 550 | (row + 1) * image_size, 551 | ) 552 | processed_images.append(resized_img.crop(box)) 553 | 554 | if use_thumbnail and len(processed_images) != 1: 555 | thumbnail_img = image.resize((image_size, image_size)) 556 | processed_images.append(thumbnail_img) 557 | 558 | return processed_images 559 | 560 | def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[Tuple[int, int]], 561 | width: int, height: int, image_size: int) -> Tuple[int, int]: 562 | """Find the best matching target aspect ratio for InternVL.""" 563 | best_ratio_diff = float("inf") 564 | best_ratio = (1, 1) 565 | area = width * height 566 | 567 | for ratio in target_ratios: 568 | target_aspect_ratio = ratio[0] / ratio[1] 569 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 570 | 571 | if ratio_diff < best_ratio_diff: 572 | best_ratio_diff = ratio_diff 573 | best_ratio = ratio 574 | elif (ratio_diff == best_ratio_diff and 575 | area > 0.5 * image_size * image_size * ratio[0] * ratio[1]): 576 | best_ratio = ratio 577 | 578 | return best_ratio 579 | 580 | 581 | 582 | def decode(self, generated_ids: Any) -> List[str]: 583 | """ 584 | Decode generated outputs. 585 | - For standard models: decode token IDs to text (slice per-sample new tokens using prompt lengths). 586 | - For internvl/minicpm: pass through strings/lists. 587 | """ 588 | try: 589 | # Pass-through for models that already return strings 590 | if self.config.inference_type in ("internvl", "minicpm"): 591 | if isinstance(generated_ids, list): 592 | return generated_ids 593 | if isinstance(generated_ids, str): 594 | return [generated_ids] 595 | return [str(generated_ids)] 596 | 597 | return self.processor.batch_decode( 598 | generated_ids, 599 | skip_special_tokens=True, 600 | clean_up_tokenization_spaces=True 601 | ) 602 | 603 | except Exception as e: 604 | logger.error(f"Decoding failed for {self.model_type}: {e}") 605 | raise 606 | 607 | 608 | @torch.inference_mode() 609 | def generate(self, inputs: Any, **generation_kwargs) -> Any: 610 | """ 611 | Generate text using the model with unified interface. 
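        Dispatches to the InternVL / MiniCPM chat interfaces when the model requires them;
        otherwise runs the standard HF generate() call and slices the prompt tokens off the returned sequences.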
612 | 613 | Args: 614 | inputs: Preprocessed inputs 615 | **generation_kwargs: Generation parameters 616 | 617 | Returns: 618 | Generated outputs (format depends on model type) 619 | """ 620 | try: 621 | if self.config.inference_type == "internvl": 622 | return self._generate_internvl(inputs, **generation_kwargs) 623 | elif self.config.inference_type == "minicpm": 624 | return self._generate_minicpm(inputs, **generation_kwargs) 625 | else: 626 | sequences = self.model.generate( 627 | **inputs, 628 | **generation_kwargs 629 | ) 630 | 631 | # Slice prompt tokens from generated sequences 632 | prompt_length = inputs["input_ids"].shape[-1] 633 | if sequences.dim() == 1: 634 | generated_ids = sequences[prompt_length:] 635 | else: 636 | generated_ids = sequences[:, prompt_length:] 637 | 638 | return generated_ids 639 | 640 | except Exception as e: 641 | logger.error(f"Generation failed: {e}") 642 | raise 643 | 644 | def _generate_internvl(self, inputs: Dict[str, Any], **generation_kwargs) -> List[str]: 645 | """Generate using InternVL's chat API per-sample to avoid batch misalignment issues.""" 646 | pixel_values_in = inputs["pixel_values"] 647 | questions = inputs["questions"] 648 | num_patches_list = inputs["num_patches_list"] 649 | 650 | # Build per-sample pixel tensors list 651 | if isinstance(pixel_values_in, list): 652 | pv_list = pixel_values_in 653 | else: 654 | # Split concatenated tensor by patch counts 655 | pv_list = list(torch.split(pixel_values_in, num_patches_list, dim=0)) 656 | 657 | if len(pv_list) != len(questions): 658 | raise ValueError( 659 | f"InternVL mismatch: {len(pv_list)} pixel groups vs {len(questions)} questions" 660 | ) 661 | 662 | responses: List[str] = [] 663 | for pv, q, n in zip(pv_list, questions, num_patches_list): 664 | # Sanity checks 665 | if pv.dim() != 4 or pv.size(0) != int(n): 666 | raise ValueError( 667 | f"InternVL per-sample mismatch: pixel batch={pv.size(0)} vs n={n}, shape={tuple(pv.shape)}" 668 | ) 669 | out = self.model.batch_chat( 670 | self.processor, 671 | pv, 672 | num_patches_list=[int(n)], 673 | questions=[q], 674 | generation_config=generation_kwargs, 675 | ) 676 | # Normalize to string 677 | if isinstance(out, list): 678 | if len(out) == 1 and isinstance(out[0], dict) and 'response' in out[0]: 679 | responses.append(out[0]['response']) 680 | elif len(out) == 1 and isinstance(out[0], str): 681 | responses.append(out[0]) 682 | else: 683 | responses.append(str(out)) 684 | else: 685 | responses.append(str(out)) 686 | return responses 687 | 688 | def _generate_minicpm(self, msgs: Any, **generation_kwargs) -> List[str]: 689 | """Generate using MiniCPM's chat interface.""" 690 | # Default generation config 691 | generation_config = { 692 | "max_tokens": 128, 693 | "do_sample": False, 694 | } 695 | generation_config.update(generation_kwargs) 696 | 697 | if isinstance(msgs, list) and isinstance(msgs[0], list): 698 | # Batch processing 699 | responses = [] 700 | for msg in msgs: 701 | try: 702 | response = self.model.chat( 703 | image=None, 704 | msgs=msg, 705 | tokenizer=self.processor, 706 | **generation_config 707 | ) 708 | responses.append(response) 709 | except Exception as e: 710 | logger.warning(f"Individual MiniCPM inference failed: {e}") 711 | responses.append(f"ERROR: {str(e)}") 712 | return responses 713 | else: 714 | # Single processing 715 | return self.model.chat( 716 | image=None, 717 | msgs=msgs, 718 | tokenizer=self.processor, 719 | **generation_config 720 | ) 721 | 722 | @contextmanager 723 | def memory_efficient_mode(self): 724 | 
"""Context manager for memory-efficient inference.""" 725 | original_grad_state = torch.is_grad_enabled() 726 | try: 727 | torch.set_grad_enabled(False) 728 | if torch.cuda.is_available(): 729 | torch.cuda.empty_cache() 730 | yield 731 | finally: 732 | torch.set_grad_enabled(original_grad_state) 733 | if torch.cuda.is_available(): 734 | torch.cuda.empty_cache() 735 | 736 | def __del__(self): 737 | """Cleanup resources.""" 738 | try: 739 | if torch is not None and torch.cuda.is_available(): 740 | torch.cuda.empty_cache() 741 | except Exception: 742 | pass 743 | 744 | 745 | # Alias for backward compatibility 746 | VLM = VLMWrapper --------------------------------------------------------------------------------