├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── create_data.py
│   │   ├── images_dalle.py
│   │   ├── mrt_blender.py
│   │   ├── folding_pil.py
│   │   ├── create_prompts.py
│   │   ├── mrt.py
│   │   └── create_images.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── vlm_wrapper.py
│   ├── eval_openai.py
│   ├── eval
│   │   └── acc.py
│   └── eval.py
├── .gitignore
├── scripts
│   ├── run_clean_eval.sh
│   └── run.sh
├── LICENSE
├── requirements.txt
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .vlm_wrapper import VLMWrapper
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache and compiled files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # Virtual environments
6 | venv/
7 | env/
8 | .venv/
9 | .env
10 |
11 | # Distribution/build
12 | build/
13 | dist/
14 | *.egg-info/
15 |
16 | # Jupyter notebooks
17 | .ipynb_checkpoints/
18 |
19 | # Machine learning artifacts
20 | *.h5
21 | *.ckpt
22 | logs/
23 | models/
24 |
25 | # Logging files
26 | *.log
27 | .neptune/
28 | wandb
29 |
30 | # Hidden files
31 | .DS_Store
32 | *.swp
33 | ._*
34 |
35 | # Images
36 | output/
37 |
38 | # Files
39 | *.csv
40 | *.json
41 | *.yaml
42 | *.yml
43 | *.jsonl
44 | *.html
45 |
46 | # Legacy files
47 | example/
48 | legacy/
49 |
50 | Dockerfile
--------------------------------------------------------------------------------
/scripts/run_clean_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Simple bash script to run the accuracy script (src/eval/acc.py)
4 | # It scores the CSV files produced by the SRBench evaluation runs
5 |
6 | # Glob pattern for the CSV files to score
7 | INPUT_DIR="cot_test/*.csv"
8 | RESPONSE_COLUMN="response"
9 | CORRECT_COLUMN="gold answer"
10 | SPLIT_COLUMN="split"
11 | OUTPUT="cleaned_results/accuracy_results.csv"
12 |
13 | echo "Running clean_eval.py..."
14 | echo "Input directory: $INPUT_DIR"
15 | echo "Response column: $RESPONSE_COLUMN"
16 | echo "Output file: $OUTPUT"
17 |
18 | # Run the accuracy script with the flags it actually defines
19 | python src/eval/acc.py --input-pattern "$INPUT_DIR" \
20 |     --response-col "$RESPONSE_COLUMN" \
21 |     --correct-col "$CORRECT_COLUMN" \
22 |     --split-col "$SPLIT_COLUMN" \
23 |     --output "$OUTPUT"
24 |
25 | echo "Clean evaluation completed!"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Ilias M. Stogiannidis
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/data/create_data.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, Image, Features, Value
2 | import json
3 | from PIL import Image as PILImage
4 | import os
5 | import glob
6 |
7 |
8 | def generate_examples(annotations):
9 |     for data in annotations:
10 |         image_path = os.path.join("Spatial-MM/data/spatial_mm", data["image_name"])
11 |         try:
12 |             image = PILImage.open(image_path).convert("RGB")  # Load image as RGB
13 |         except Exception as e:
14 |             print(f"Error loading image {image_path}: {e}")
15 | continue
16 | yield {
17 | "image": image,
18 | "question": data["question"],
19 | "answer": data["answer"],
20 | }
21 |
22 | if __name__ == "__main__":
23 | # Load JSON annotations from multiple files
24 | annotations = []
25 | json_files = glob.glob("Spatial-MM/data/*.json")
26 | for json_file in json_files:
27 | with open(json_file, "r") as f:
28 | annotations.extend(json.load(f))
29 |
30 | # Define dataset features (adjust based on your JSON structure)
31 | features = Features(
32 | {
33 | "image": Image(),
34 | "question": Value("string"),
35 | "answer": Value("string"),
36 | }
37 | )
38 |
39 |
40 | # Create the dataset
41 |     dataset = Dataset.from_generator(
42 |         generate_examples,
43 |         gen_kwargs={"annotations": annotations}, features=features,
44 |     )
45 | 
46 |     dataset.push_to_hub("spatial_mm", private=True)  # Push to the Hub
47 |
48 |
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | DATASET="stogian/srbench2"
2 | ONESHOT_PATH="/netdisk/users/stogian/srbench/example/oneshot.json"
3 |
4 | models=(
5 | "OpenGVLab/InternVL3_5-8B-HF"
6 | "OpenGVLab/InternVL3_5-14B-HF"
7 | "OpenGVLab/InternVL3_5-38B-HF"
8 | "OpenGVLab/InternVL3_5-30B-A3B-HF"
9 | "OpenGVLab/InternVL3_5-241B-A28B-HF"
10 | "google/gemma-3-12b-it"
11 | "google/gemma-3-27b-it"
12 | "Qwen/Qwen3-VL-8B-Thinking"
13 | "Qwen/Qwen3-VL-30B-A3B-Thinking"
14 | "Qwen/Qwen3-VL-235B-A22B-Thinking"
15 | "Qwen/Qwen3-VL-8B-Instruct"
16 | "Qwen/Qwen3-VL-30B-A3B-Instruct"
17 | "Qwen/Qwen3-VL-235B-A22B-Instruct"
18 | "meta-llama/Llama-3.2-11B-Vision-Instruct"
19 | "meta-llama/Llama-3.2-90B-Vision-Instruct"
20 | "HuggingFaceM4/Idefics3-8B-Llama3"
21 | "llava-hf/llava-1.5-7b-hf"
22 | "llava-hf/llava-v1.6-mistral-7b-hf"
23 | "openbmb/MiniCPM-V-2_6"
24 | "HuggingFaceTB/SmolVLM-500M-Instruct"
25 | "HuggingFaceTB/SmolVLM-Instruct"
26 | "moonshotai/Kimi-VL-A3B-Thinking-2506"
27 | "moonshotai/Kimi-VL-A3B-Instruct"
28 | "zai-org/GLM-4.1V-9B-Thinking"
29 | "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
30 | "meta-llama/Llama-4-Scout-17B-128E-Instruct"
31 | )
32 |
33 | for MODEL in "${models[@]}"; do
34 | echo "Running inference for model: $MODEL"
35 |
36 | # Determine optimal batch size based on model size
37 | if [[ $MODEL == *"11B"* ]]; then
38 | BATCH_SIZE=16
39 | elif [[ $MODEL == *"3B"* ]]; then
40 | BATCH_SIZE=32
41 | elif [[ $MODEL == *"8B"* ]]; then
42 | BATCH_SIZE=16
43 | else
44 | BATCH_SIZE=8
45 | fi
46 |
47 | python src/eval.py --model $MODEL \
48 | --dataset $DATASET \
49 | --batch_size $BATCH_SIZE \
50 | --seed 123 \
51 | --num_workers 8
52 |
53 | python src/eval.py --model $MODEL \
54 | --dataset $DATASET \
55 | --cot \
56 | --batch_size $BATCH_SIZE \
57 | --seed 123 \
58 | --num_workers 8
59 |
60 |
61 |
62 |
63 | done
64 |
65 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.10.0
2 | aiohappyeyeballs==2.5.0
3 | aiohttp==3.12.14
4 | aiosignal==1.4.0
5 | av==14.1.0
6 | babel==2.16.0
7 | beautifulsoup4==4.12.3
8 | boto3==1.36.12
9 | botocore==1.36.12
10 | bravado==11.0.3
11 | bravado-core==6.1.1
12 | cachetools==5.5.1
13 | certifi==2024.12.14
14 | click==8.1.8
15 | contourpy==1.3.1
16 | cycler==0.12.1
17 | datasets==3.3.1
18 | diffusers==0.32.2
19 | dill==0.3.8
20 | distro==1.9.0
21 | docker-pycreds==0.4.0
22 | einops==0.8.0
23 | et-xmlfile==2.0.0
24 | filelock==3.17.0
25 | flash-attn==2.7.3
26 | fonttools==4.55.8
27 | frozenlist==1.5.0
28 | future==1.0.0
29 | google-api-core==2.24.1
30 | google-auth==2.38.0
31 | google-auth-oauthlib==1.2.1
32 | google-cloud-core==2.4.1
33 | google-cloud-storage==3.0.0
34 | google-crc32c==1.6.0
35 | google-resumable-media==2.7.2
36 | googleapis-common-protos==1.67.0
37 | hf-transfer==0.1.9
38 | huggingface-hub==0.34.2
39 | iniconfig==2.0.0
40 | inquirerpy==0.3.4
41 | jiter==0.8.2
42 | jmespath==1.0.1
43 | joblib==1.4.2
44 | jsonref==1.1.0
45 | jupyter-events==0.11.0
46 | jupyterlab==4.3.4
47 | kiwisolver==1.4.8
48 | matplotlib==3.10.0
49 | monotonic==1.6
50 | mpmath==1.3.0
51 | multidict==6.1.0
52 | multiprocess==0.70.16
53 | neptune==1.13.0
54 | networkx==3.4.2
55 | numpy==2.2.2
56 | nvitop==1.4.2
57 | oauthlib==3.2.2
58 | openai==1.63.2
59 | openpyxl==3.1.5
60 | pandas==2.2.3
61 | pandocfilters==1.5.1
62 | pfzy==0.3.4
63 | pillow==11.1.0
64 | pluggy==1.5.0
65 | propcache==0.2.1
66 | proto-plus==1.26.0
67 | protobuf==5.29.5
68 | pyarrow==19.0.0
69 | pyasn1==0.6.1
70 | pyasn1-modules==0.4.1
71 | pydantic==2.10.6
72 | pydantic-core==2.27.2
73 | pyjwt==2.10.1
74 | pyparsing==3.2.1
75 | pytest==8.3.4
76 | python-dotenv==1.0.1
77 | python-json-logger==3.2.1
78 | pytz==2024.2
79 | qwen-vl-utils==0.0.10
80 | regex==2024.11.6
81 | requests-oauthlib==2.0.0
82 | rsa==4.9
83 | s3transfer==0.11.2
84 | safetensors==0.5.2
85 | scikit-learn==1.6.1
86 | scipy==1.15.1
87 | seaborn==0.13.2
88 | sentencepiece==0.2.0
89 | sentry-sdk==2.20.0
90 | setproctitle==1.3.4
91 | setuptools==75.8.0
92 | simplejson==3.19.3
93 | soupsieve==2.6
94 | swagger-spec-validator==3.0.4
95 | sympy==1.13.1
96 | threadpoolctl==3.5.0
97 | timm==1.0.14
98 | tokenizers==0.21.0
99 | torchvision==0.21.0
100 | tqdm==4.67.1
101 | transformers==4.55.0
102 | triton==3.2.0
103 | tzdata==2025.1
104 | xxhash==3.5.0
105 | yarl==1.18.3
106 | hf-xet==1.1.5
107 | tiktoken==0.9.0
108 | num2words==0.5.14
109 | blobfile==3.0.0
--------------------------------------------------------------------------------
/src/data/images_dalle.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate images using DALL-E 3.
3 |
4 | Note: DALL-E 3 requires version 1.0.0 of the openai-python library or later.
5 | """
6 |
7 | import os
8 | import argparse
9 | import json
10 | import requests
11 | from tqdm import tqdm
12 | from openai import AzureOpenAI
13 | from dotenv import load_dotenv
14 |
15 | load_dotenv(override=True)
16 |
17 | OUTPUT_DIR = "output/images/dalle3"
18 | os.makedirs(OUTPUT_DIR, exist_ok=True)
19 |
20 |
21 | def parse_arguments():
22 | """Parse command-line arguments."""
23 | parser = argparse.ArgumentParser(
24 | description="Generate images using DALL-E 3."
25 | )
26 | parser.add_argument(
27 | "-f",
28 | "--metadata_file",
29 | type=str,
30 | required=True,
31 | help="Path to the JSON file containing the metadata for image generation.",
32 | )
33 | return parser.parse_args()
34 |
35 |
36 | def load_metadata(metadata_file: str):
37 | """Load metadata from a JSON Lines file."""
38 | with open(metadata_file, "r") as file:
39 | lines = file.readlines()
40 | return [json.loads(line.strip()) for line in lines if line.strip()]
41 |
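42 | # Each line of the metadata file is a standalone JSON object; the only field this
43 | # script reads is "generated_scene_description", e.g. (illustrative value):
44 | #   {"generated_scene_description": "a bicycle leaning against a red car"}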
42 |
43 | def main():
44 | """Main function for generating images."""
45 | args = parse_arguments()
46 | metadata = load_metadata(args.metadata_file)
47 |
48 | client = AzureOpenAI(
49 | api_version="2024-02-01",
50 | azure_endpoint=os.environ["DALLE_ENDPOINT"],
51 | api_key=os.environ["AZURE_OPENAI_API_KEY"],
52 | )
53 |
54 | for idx, item in enumerate(
55 | tqdm(metadata, desc="Generating images", unit="prompt", colour="#000080")
56 | ):
57 | prompt = item["generated_scene_description"]
58 | print(f"Generating image for prompt: {prompt}")
59 | output_filename = f"image_{idx}.png"
60 | result = client.images.generate(
61 | model="dall-e-3", # the name of your DALL-E 3 deployment
62 | prompt="A photo-realistic image of " + prompt,
63 | n=1,
64 | )
65 | image_data = json.loads(result.model_dump_json())["data"][0]
66 | image_url = image_data["url"]
67 |
68 |         # Log the generated image URL before downloading it.
69 | print(f"Generated image '{output_filename}' from URL: {image_url}")
70 |
71 | # Save the image to the output directory
72 | image_path = os.path.join(OUTPUT_DIR, output_filename)
73 | with open(image_path, "wb") as file:
74 | file.write(requests.get(image_url).content)
75 | print(f"Saved image to '{image_path}'")
76 |
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SRBench
2 |
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://www.python.org/downloads/release/python-3120/)
5 | [](https://huggingface.co/datasets/stogian/srbench)
6 | [](https://arxiv.org/abs/2503.19707)
7 |
8 | SRBench is a benchmark for evaluating spatial reasoning in vision-language models, introduced in the paper "Mind the Gap: Benchmarking Spatial Reasoning in Vision-Language Models". This repository contains the code for generating the benchmark data and for evaluating models on it. Below is an overview of the repository structure, installation, usage instructions, and contribution guidelines.
9 |
10 | ## Overview
11 |
12 | This repository is divided into several modules that cover various aspects of the project, including:
13 | - **Data Creation**: Scripts for generating the benchmark images, prompts, and Hugging Face datasets (`src/data/`).
14 | - **Evaluation**: Scripts for running open-weight and OpenAI vision-language models on the benchmark (`src/eval.py`, `src/eval_openai.py`).
15 | - **Scoring**: Answer extraction and accuracy computation (`src/eval/acc.py`, `scripts/run_clean_eval.sh`).
16 |
17 | ## Repository Structure
18 |
19 | The project is organized as follows:
20 | ```
21 | SRBench/
22 | ├── bin/                     # Storage for model binaries
23 | ├── scripts/                 # Bash scripts for running the project
24 | │   ├── run.sh               # Main evaluation script
25 | │   └── run_clean_eval.sh    # Accuracy computation script
26 | ├── src/                     # Source code of the project
27 | │   ├── data/                # Data, prompt, and image creation scripts
28 | │   ├── eval/
29 | │   │   └── acc.py           # Answer extraction and accuracy
30 | │   ├── utils/
31 | │   │   └── vlm_wrapper.py   # Wrapper for the VLM models
32 | │   ├── eval.py              # Evaluation script for open-weight VLMs
33 | │   └── eval_openai.py       # Evaluation script for OpenAI models
34 | ├── .gitignore               # Files and directories to ignore
35 | ├── requirements.txt         # Required packages
36 | ├── LICENSE                  # MIT License file
37 | └── README.md                # Project documentation
42 | ```
43 |
44 | ## Installation
45 |
46 | 1. Clone the repository:
47 | ```bash
48 | git clone https://github.com/stogiannidis/srbench.git
49 | cd srbench
50 | ```
51 | 2. Create a virtual environment:
52 | ```bash
53 | python3 -m venv venv
54 | source venv/bin/activate
55 | ```
56 | or using `conda`:
57 | ```bash
58 | conda create -n srbench python=3.12
59 | conda activate srbench
60 | ```
61 | 3. Install the required packages:
62 | ```bash
63 | pip install -r requirements.txt
64 | ```
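65 | 4. (Optional) Create a `.env` file for the scripts that call external services. The variable names below are the ones read in this repository (`src/eval_openai.py`, `src/data/images_dalle.py`, `src/data/create_prompts.py`); the values are placeholders:
66 | ```bash
67 | AZURE_OPENAI_API_KEY=<your-azure-openai-key>
68 | AZURE_OPENAI_ENDPOINT=<your-gpt-4o-endpoint>
69 | O1_ENDPOINT=<your-o1-endpoint>
70 | DALLE_ENDPOINT=<your-dall-e-3-endpoint>
71 | NEPTUNE_API_TOKEN=<your-neptune-api-token>
72 | ```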
65 |
66 | ## Usage
67 |
68 | To run the project, follow these steps:
69 | 1. Fetch the dataset from Hugging Face:
70 | ```bash
71 | huggingface-cli login
72 | huggingface-cli download stogian/srbench --repo-type dataset
73 | ```
74 | 2. Run the script:
75 | ```bash
76 | bash scripts/run.sh
77 | ```
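78 | 3. (Optional) Score the resulting CSV files with the accuracy script. The flags below are the ones defined in `src/eval/acc.py`; the input pattern assumes the per-model CSVs written by `src/eval_openai.py`, so adjust it to wherever your results live (see also `scripts/run_clean_eval.sh`):
79 | ```bash
80 | python src/eval/acc.py \
81 |     --input-pattern "output/evaluations/srbench2/*.csv" \
82 |     --response-col response \
83 |     --correct-col answer \
84 |     --split-col split \
85 |     --output accuracy_results.csv
86 | ```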
78 | ## Citation
79 | ```
80 | @misc{stogiannidis2025mindgapbenchmarkingspatial,
81 | title={Mind the Gap: Benchmarking Spatial Reasoning in Vision-Language Models},
82 | author={Ilias Stogiannidis and Steven McDonagh and Sotirios A. Tsaftaris},
83 | year={2025},
84 | eprint={2503.19707},
85 | archivePrefix={arXiv},
86 | primaryClass={cs.CV},
87 | url={https://arxiv.org/abs/2503.19707},
88 | }
89 | ```
90 |
91 | ## Contributing
92 |
93 | Contributions are welcome! Please follow these steps:
94 | - Fork the repository.
95 | - Create a new branch (`git checkout -b feature/your_feature`).
96 | - Commit your changes (`git commit -am 'Add new feature'`).
97 | - Push to the branch (`git push origin feature/your_feature`).
98 | - Open a Pull Request.
99 |
100 | ## License
101 |
102 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more information.
103 |
104 | ## Contact
105 |
106 | For questions or feedback, please open an issue or contact me via email.
107 |
--------------------------------------------------------------------------------
/src/eval_openai.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import logging
4 | import os
5 | import time
6 | from functools import wraps
7 | from io import BytesIO
8 | from typing import List
9 |
10 | import dotenv
11 | import pandas as pd
12 | from datasets import load_dataset
13 | from openai import AzureOpenAI
14 | from PIL import Image
15 | from tqdm import tqdm
16 |
17 | # Load the environment variables
18 | dotenv.load_dotenv()
19 |
20 | # Set up logging
21 | logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO)
22 |
23 |
24 | def retry_with_exponential_backoff(max_retries=5, initial_delay=5):
25 | """
26 |     Retry decorator that retries a function call with exponential backoff.
27 |
28 | Args:
29 | max_retries (int): Maximum number of retries.
30 | initial_delay (int): Initial delay in seconds.
31 |
32 | Returns:
33 | Callable: Decorated function with retry logic.
34 | """
35 |
36 | def decorator(func):
37 | @wraps(func)
38 | def wrapper(*args, **kwargs):
39 | delay = initial_delay
40 | for _ in range(max_retries):
41 | try:
42 | return func(*args, **kwargs)
43 | except Exception as e:
44 | if "429" in str(e):
45 | print(f"Rate limit exceeded. Retrying in {delay} seconds...")
46 | time.sleep(delay)
47 | delay *= 2 # exponential backoff
48 | else:
49 | raise e
50 | raise Exception("Max retries exceeded")
51 |
52 | return wrapper
53 |
54 | return decorator
55 |
56 |
57 | def image_to_base64(image: Image.Image) -> str:
58 | """
59 | Convert a PIL Image to a base64 encoded string.
60 |
61 | Args:
62 | image (Image.Image): The image to convert.
63 |
64 | Returns:
65 | str: A base64 encoded string of the image.
66 | """
67 | buffered = BytesIO()
68 | image.save(buffered, format="PNG")
69 | img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
70 | return img_str
71 |
72 |
73 | @retry_with_exponential_backoff()
74 | def infer(prompts: List[str], images: List[Image.Image], model: str) -> List[str]:
75 | """
76 | Infer responses using Azure OpenAI from provided prompts and images.
77 |
78 | Args:
79 | prompts (List[str]): A list of prompts/questions.
80 | images (List[Image.Image]): A list of images corresponding to the prompts.
81 | model (str): The model identifier to use for inference.
82 |
83 | Returns:
84 | List[str]: A list of responses from the model.
85 | """
86 | if not os.getenv("AZURE_OPENAI_API_KEY") or not os.getenv("AZURE_OPENAI_ENDPOINT"):
87 | raise EnvironmentError(
88 | "Azure OpenAI API key or endpoint not set in environment variables."
89 | )
90 |
91 | endpoint = (
92 | os.getenv("AZURE_OPENAI_ENDPOINT")
93 | if model == "gpt-4o"
94 | else os.getenv("O1_ENDPOINT")
95 | )
96 |
97 | client = AzureOpenAI(
98 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
99 | api_version="2024-07-01-preview",
100 | azure_endpoint=endpoint,
101 | )
102 |
103 | contents = []
104 | for prompt, image in zip(prompts, images):
105 | if isinstance(image, Image.Image):
106 | image_content = image_to_base64(image)
107 | image_content = f"data:image/png;base64,{image_content}"
108 | else:
109 | image_content = str(image)
110 |
111 |
112 | # Merge the image and text into one message
113 | messages = [
114 | {"role": "system", "content": "You are a helpful AI assistant."},
115 | {
116 | "role": "user",
117 | "content": [
118 | {"type": "image_url", "image_url": {"url": image_content}},
119 | {"type": "text", "text": prompt},
120 | ],
121 | },
122 | ]
123 |
124 | response = client.chat.completions.create(
125 | model=model,
126 | messages=messages,
127 | seed=69,
128 | temperature=0.5,
129 | max_tokens=64,
130 | )
131 | content = response.choices[0].message.content.strip()
132 | contents.append(content)
133 |
134 | return contents
135 |
136 |
137 | def main():
138 | """
139 | Main function to process the dataset and perform inference based on
140 | provided command-line arguments for dataset name and model.
141 | """
142 | parser = argparse.ArgumentParser(
143 | description="Process dataset and perform inference using Azure OpenAI."
144 | )
145 | parser.add_argument(
146 | "--dataset",
147 | type=str,
148 | required=True,
149 | help="Dataset name in Hugging Face format (e.g., stogian/mrt_pf_mix)",
150 | )
151 | parser.add_argument(
152 | "--model", type=str, required=True, help="Model name (e.g., gpt-4o)"
153 | )
154 | parser.add_argument(
155 | "--batch_size", type=int, default=16, help="Batch size for processing abstracts"
156 | )
157 | args = parser.parse_args()
158 |
159 | dataset_name = args.dataset
160 | short_name = dataset_name.split("/")[-1]
161 |
162 | model = args.model
163 | model_name = model.split("/")[-1]
164 |
165 | # Load the specified dataset
166 | dataset = load_dataset(dataset_name, split="train")
167 |
168 | all_responses = []
169 |
170 | for i in tqdm(
171 | range(0, len(dataset), args.batch_size),
172 | desc="Processing batches",
173 | unit="batch",
174 | leave=False,
175 | colour="magenta",
176 | ):
177 | batch = dataset[i : i + args.batch_size]
178 | prompts = batch["question"]
179 | images = batch["image"]
180 |
181 | responses = infer(prompts, images, model)
182 | all_responses.extend(responses)
183 |
184 | results_df = pd.DataFrame(
185 | {
186 | "question": dataset["question"],
187 | "response": all_responses,
188 | "answer": dataset["answer"],
189 | "split": dataset["split"],
190 | }
191 | )
192 | results_dir = f"output/evaluations/{short_name}/"
193 | os.makedirs(results_dir, exist_ok=True)
194 |
195 | results_df.to_csv(os.path.join(results_dir, f"{model_name}.csv"), index=False)
196 | print(f"Results saved to {os.path.join(results_dir, 'results.csv')}")
197 |
198 |
199 | if __name__ == "__main__":
200 | main()
201 |
--------------------------------------------------------------------------------
/src/eval/acc.py:
--------------------------------------------------------------------------------
1 | import re
2 | import glob
3 | import os
4 | import pandas as pd
5 | import argparse
6 | import sys
7 |
8 |
9 | def normalize_answer(ans):
10 | try:
11 | if isinstance(ans, str):
12 | ans = ans.strip().upper()
13 | ans = ans.split(".")[0] # Take text before the first period
14 | except Exception as e:
15 | pass
16 | return ans
17 |
18 |
19 | def extract_answer(text):
20 | """
21 | Extracts the final answer from an LLM's output, supporting single letters,
22 | specific words, Markdown bolding, and mixed formats (e.g., "C. Left").
23 |
24 | Args:
25 | text (str): The LLM-generated text.
26 |
27 | Returns:
28 | str: The extracted answer in uppercase (e.g., "A", "YES") or "None" if not found.
29 | """
30 | if pd.isna(text) or text == "":
31 | return "None"
32 |
33 | text = str(text)
34 |
35 | # --- Step 1: Clean up the input text ---
36 | prefixes_to_remove = ["Assistant", "ASSISTANT", "[INST]", "assistant"]
37 | first_prefix_pos = len(text)
38 | for prefix in prefixes_to_remove:
39 | pos = text.find(prefix)
40 | if pos != -1:
41 | first_prefix_pos = min(first_prefix_pos, pos)
42 |
43 | if first_prefix_pos != len(text):
44 | text = text[first_prefix_pos:]
45 |
46 | # --- Step 2: Define answer patterns ---
47 | word_answers = ["yes", "no", "left", "right", "back", "front", "center"]
48 | core_pattern = r"\b(" + "|".join(word_answers) + r"|[A-Z])\b"
49 | answer_pattern = r"(?:\*\*)?" + core_pattern + r"(?:\*\*)?"
50 |
51 | # --- Step 3: Check for answers in different formats, from most to least specific ---
52 |
53 | # A. Check for answers in curly brackets, e.g., {**A**}
54 | curly_pattern = (
55 | r"\{" + r"(?:\*\*)?(" + "|".join(word_answers) + r"|[A-Z])(?:\*\*)?" + r"\}"
56 | )
57 | curly_match = re.search(curly_pattern, text, re.IGNORECASE)
58 | if curly_match:
59 | return curly_match.group(1).upper()
60 |
61 | # B. Check for mixed format like "C. Left" and prioritize the letter.
62 | mixed_pattern = r"\b([A-Z])(?:\.|:|\))\s*(?:" + "|".join(word_answers) + r")\b"
63 | mixed_match = re.search(mixed_pattern, text, re.IGNORECASE)
64 | if mixed_match:
65 | return mixed_match.group(1).upper()
66 |
67 | # C. Check for phrases that typically precede or follow the answer
68 | before_phrases = [
69 | "the answer is",
70 | "i think it's",
71 | "i choose",
72 | "i'll go with",
73 | "it's",
74 | "the correct choice is",
75 | "my answer is",
76 | "i believe it's",
77 | "i select",
78 | "the best answer is",
79 | ]
80 | after_phrases = [
81 | "is the answer",
82 | "is correct",
83 | "is the correct choice",
84 | "is right",
85 | "is the best answer",
86 | "is the right choice",
87 | ]
88 |
89 | before_pattern = (
90 | r"(?:"
91 | + "|".join(re.escape(p) for p in before_phrases)
92 | + r")\s*"
93 | + answer_pattern
94 | )
95 | after_pattern = (
96 | answer_pattern
97 | + r"\s*(?:"
98 | + "|".join(re.escape(p) for p in after_phrases)
99 | + r")"
100 | )
101 |
102 | before_regex = re.compile(before_pattern, re.IGNORECASE)
103 | after_regex = re.compile(after_pattern, re.IGNORECASE)
104 |
105 | matches = list(before_regex.finditer(text)) + list(after_regex.finditer(text))
106 | if matches:
107 | matches.sort(key=lambda m: m.start())
108 | return matches[-1].group(1).upper()
109 |
110 | # D. Check for a direct answer format: "**A**.", "**Yes**:", etc.
111 | direct_match = re.search(
112 | answer_pattern + r"(?:\.|:|\))?(?:\s|$)", text, re.IGNORECASE
113 | )
114 | if direct_match:
115 | return direct_match.group(1).upper()
116 |
117 | # E. Fallback: find the last standalone answer word/letter in the text
118 | fallback_matches = re.findall(answer_pattern, text, re.IGNORECASE)
119 | if fallback_matches:
120 | return fallback_matches[-1].upper()
121 |
122 | return "None"
123 |
124 |
125 | def exact_match(pred, gold):
126 | """Compare normalized predicted and gold answers"""
127 | return normalize_answer(pred) == normalize_answer(gold)
128 |
129 |
130 | def process_csv_files(
131 | pattern, response_col, correct_col, split_col=None, output_file=None
132 | ):
133 | """
134 | Process multiple CSV files and calculate accuracy per split.
135 |
136 | Args:
137 | pattern (str): Glob pattern for CSV files
138 | response_col (str): Name of column containing model responses
139 | correct_col (str): Name of column containing correct answers
140 | split_col (str, optional): Name of column containing splits
141 | output_file (str, optional): Path to save results CSV
142 |
143 | Returns:
144 | pd.DataFrame: DataFrame with results per model and split
145 | """
146 | files = glob.glob(pattern)
147 |
148 | if not files:
149 | print(f"No files found matching pattern: {pattern}")
150 | return None
151 |
152 | print(f"Found {len(files)} files to process")
153 |
154 | # Dictionary to store results for all models
155 | results = {}
156 | detailed_results = []
157 |
158 | for file in files:
159 | print(f"Processing: {file}")
160 |
161 | try:
162 | df = pd.read_csv(file)
163 | print(f" Loaded {len(df)} rows")
164 | except Exception as e:
165 | print(f" Error reading {file}: {e}")
166 | continue
167 |
168 | # Extract model name from filename
169 | model_name = os.path.basename(file).replace(".csv", "")
170 |
171 | # Check if required columns exist
172 | if response_col not in df.columns:
173 | print(f" Warning: Column '{response_col}' not found in {file}")
174 | continue
175 | if correct_col not in df.columns:
176 | print(f" Warning: Column '{correct_col}' not found in {file}")
177 | continue
178 | if split_col and split_col not in df.columns:
179 | print(f" Warning: Column '{split_col}' not found in {file}")
180 | continue
181 |
182 | # Extract answers from responses
183 | df["extracted_answer"] = (
184 | df[response_col].fillna("").astype(str).apply(extract_answer)
185 | )
186 |
187 | # Calculate correctness using exact match with extracted answers
188 | try:
189 | df["correct"] = df.apply(
190 | lambda row: exact_match(row["extracted_answer"], row[correct_col]),
191 | axis=1,
192 | )
193 | except Exception as e:
194 | print(f" Error calculating correctness for {file}: {e}")
195 | continue
196 |
197 | # Calculate overall accuracy
198 | overall_accuracy = df["correct"].mean()
199 |
200 | # Initialize results for this model
201 | results[model_name] = {"overall": overall_accuracy}
202 |
203 | # Calculate accuracy per split if split column exists
204 | if split_col and split_col in df.columns:
205 | try:
206 | accuracy_per_split = df.groupby(split_col)["correct"].mean()
207 | results[model_name].update(accuracy_per_split.to_dict())
208 |
209 | print(f" Accuracy per split:")
210 | for split, acc in accuracy_per_split.items():
211 | print(f" {split}: {acc:.4f} ({acc * 100:.2f}%)")
212 |
213 | except Exception as e:
214 | print(f" Error calculating accuracy per split for {file}: {e}")
215 |
216 | print(
217 | f" Overall accuracy: {overall_accuracy:.4f} ({overall_accuracy * 100:.2f}%)"
218 | )
219 |
220 | # Store detailed results for this file
221 | file_details = {
222 | "model": model_name,
223 | "total_questions": len(df),
224 | "correct_answers": int(df["correct"].sum()),
225 | "extraction_failed": int((df["extracted_answer"] == "None").sum()),
226 | "overall_accuracy": overall_accuracy,
227 | }
228 | detailed_results.append(file_details)
229 |
230 | print()
231 |
232 | if not results:
233 | print("No valid results found")
234 | return None
235 |
236 | # Create DataFrame with models as rows and splits as columns
237 | results_df = pd.DataFrame(results).T
238 | results_df = results_df.sort_index()
239 | results_df.index.name = "model"
240 |
241 | # Round all values to 4 decimal places
242 | results_df = results_df.round(4)
243 |
244 | # Save results
245 | if output_file is None:
246 | output_file = "model_accuracy_by_split.csv"
247 |
248 | results_df.to_csv(output_file)
249 | print(f"Results saved to {output_file}")
250 |
251 | # Save detailed results
252 | detailed_df = pd.DataFrame(detailed_results)
253 | detailed_output = output_file.replace(".csv", "_detailed.csv")
254 | detailed_df.to_csv(detailed_output, index=False)
255 | print(f"Detailed results saved to {detailed_output}")
256 |
257 | # Print summary
258 | print("\n" + "=" * 60)
259 | print("SUMMARY RESULTS")
260 | print("=" * 60)
261 | print(results_df)
262 | print("=" * 60)
263 |
264 | return results_df
265 |
266 |
267 | def main():
268 | parser = argparse.ArgumentParser(
269 | description="Calculate accuracy from multiple CSV files"
270 | )
271 | parser.add_argument(
272 | "--input-pattern",
273 | "-i",
274 | required=True,
275 | help='Glob pattern for CSV files (e.g., "data/*.csv")'
276 | )
277 | parser.add_argument(
278 | "--response-col",
279 | "-r",
280 | required=True,
281 | help="Name of column containing model responses",
282 | )
283 | parser.add_argument(
284 | "--correct-col",
285 | "-c",
286 | required=True,
287 | help="Name of column containing correct answers",
288 | )
289 | parser.add_argument(
290 | "--split-col", "-s", help="Name of column containing splits (optional)"
291 | )
292 | parser.add_argument(
293 | "--output",
294 | "-o",
295 | help="Path to save results CSV (default: model_accuracy_by_split.csv)",
296 | )
297 |
298 | args = parser.parse_args()
299 |
300 | # Process files
301 | results_df = process_csv_files(
302 |         args.input_pattern, args.response_col, args.correct_col, args.split_col, args.output
303 | )
304 |
305 | if results_df is None:
306 | sys.exit(1)
307 |
308 |
309 | if __name__ == "__main__":
310 | main()
311 |
--------------------------------------------------------------------------------
/src/data/mrt_blender.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import bpy
4 | import bmesh
5 | import os
6 | import sys
7 | import random
8 | import argparse
9 | import json
10 | import mathutils
11 | from mathutils import Vector, Matrix, Euler
12 | import math
13 |
14 | # Import shape definitions from mrt.py
15 | sys.path.append(os.path.dirname(__file__))
16 | from mrt import SHAPES, EASY_SHAPES, COMPLEX_SHAPES, SIMILAR_MAPPING
17 |
18 | def clear_scene():
19 | """Clear all objects from the scene."""
20 | bpy.ops.object.select_all(action='SELECT')
21 | bpy.ops.object.delete(use_global=False)
22 |
23 | def create_cube_at_position(position, size=1.0, name="Cube"):
24 | """Create a cube at the specified position."""
25 | bpy.ops.mesh.primitive_cube_add(size=size, location=position)
26 | cube = bpy.context.active_object
27 | cube.name = name
28 | return cube
29 |
30 | def create_polycube(shape_coords, cube_size=1.0, material=None):
31 | """Create a polycube from coordinate list."""
32 | cubes = []
33 | for i, coord in enumerate(shape_coords):
34 | pos = Vector((coord[0] * cube_size, coord[1] * cube_size, coord[2] * cube_size))
35 | cube = create_cube_at_position(pos, cube_size, f"Cube_{i}")
36 | if material:
37 | cube.data.materials.append(material)
38 | cubes.append(cube)
39 |
40 | # Join all cubes into one object
41 | bpy.ops.object.select_all(action='DESELECT')
42 | for cube in cubes:
43 | cube.select_set(True)
44 | bpy.context.view_layer.objects.active = cubes[0]
45 | bpy.ops.object.join()
46 |
47 | polycube = bpy.context.active_object
48 | return polycube
49 |
50 | def create_material(color=(1, 1, 1, 1), name="Material"):
51 | """Create a material with specified color."""
52 | mat = bpy.data.materials.new(name=name)
53 | mat.use_nodes = True
54 | bsdf = mat.node_tree.nodes["Principled BSDF"]
55 | bsdf.inputs[0].default_value = color # Base Color
56 | bsdf.inputs[7].default_value = 0.2 # Roughness
57 | bsdf.inputs[15].default_value = 1.0 # Specular
58 | return mat
59 |
60 | def setup_lighting():
61 | """Set up clean lighting for the scene."""
62 | # Add sun light
63 | bpy.ops.object.light_add(type='SUN', location=(5, 5, 10))
64 | sun = bpy.context.active_object
65 | sun.data.energy = 3.0
66 | sun.rotation_euler = (0.785, 0, 0.785) # 45 degrees
67 |
68 | # Add fill light
69 | bpy.ops.object.light_add(type='AREA', location=(-5, -5, 8))
70 | fill = bpy.context.active_object
71 | fill.data.energy = 1.0
72 | fill.data.size = 5.0
73 |
74 | def setup_camera(target_location, distance=8.0):
75 | """Set up camera to look at target location."""
76 | # Add camera
77 | bpy.ops.object.camera_add(location=(distance, -distance, distance))
78 | camera = bpy.context.active_object
79 |
80 | # Point camera at target
81 | direction = target_location - camera.location
82 | rot_quat = direction.to_track_quat('-Z', 'Y')
83 | camera.rotation_euler = rot_quat.to_euler()
84 |
85 | # Set as active camera
86 | bpy.context.scene.camera = camera
87 | return camera
88 |
89 | def get_object_bounds(obj):
90 | """Get the bounding box of an object."""
91 | bbox_corners = [obj.matrix_world @ Vector(corner) for corner in obj.bound_box]
92 | min_coord = Vector((min(v.x for v in bbox_corners),
93 | min(v.y for v in bbox_corners),
94 | min(v.z for v in bbox_corners)))
95 | max_coord = Vector((max(v.x for v in bbox_corners),
96 | max(v.y for v in bbox_corners),
97 | max(v.z for v in bbox_corners)))
98 | center = (min_coord + max_coord) / 2
99 | size = max_coord - min_coord
100 | return center, size
101 |
102 | def transform_rotate_blender(obj, difficulty="easy"):
103 | """Apply rotation transformation to object."""
104 | if difficulty == "easy":
105 | # Single axis rotation with simple angles
106 | axis = random.choice(['X', 'Y', 'Z'])
107 | angle = math.radians(random.choice([-90, 90, 180]))
108 |
109 | if axis == 'X':
110 | obj.rotation_euler = (angle, 0, 0)
111 | elif axis == 'Y':
112 | obj.rotation_euler = (0, angle, 0)
113 | else: # Z
114 | obj.rotation_euler = (0, 0, angle)
115 | else: # complex
116 | # Multi-axis rotation
117 | angle_x = math.radians(random.choice([0, 60, 90, 120]))
118 | angle_y = math.radians(random.choice([0, 60, 90, 120]))
119 | angle_z = math.radians(random.choice([0, 60, 90, 120]))
120 | obj.rotation_euler = (angle_x, angle_y, angle_z)
121 |
122 | def transform_mirror_blender(obj, difficulty="easy"):
123 | """Apply mirror transformation to object."""
124 | if difficulty == "easy":
125 | # Mirror across Z axis
126 | obj.scale[2] = -1
127 | else: # complex
128 | # Mirror across random axis
129 | axis = random.choice([0, 1, 2])
130 | scale = [1, 1, 1]
131 | scale[axis] = -1
132 | obj.scale = scale
133 |
134 | def render_image(output_path, resolution=(512, 512)):
135 | """Render the current scene to an image."""
136 | scene = bpy.context.scene
137 | scene.render.resolution_x = resolution[0]
138 | scene.render.resolution_y = resolution[1]
139 | scene.render.filepath = output_path
140 | scene.render.image_settings.file_format = 'PNG'
141 |
142 | # Set render engine to Cycles for better quality
143 | scene.render.engine = 'CYCLES'
144 | scene.cycles.samples = 64
145 |
146 | bpy.ops.render.render(write_still=True)
147 |
148 | def create_composite_image(original_obj, candidates, output_path, difficulty="easy"):
149 | """Create composite image with original and candidates."""
150 | # Position objects for composite layout
151 | if difficulty == "easy":
152 | positions = [
153 | Vector((0, 0, 4)), # Original (top center)
154 | Vector((-4, 0, 0)), # Candidate A
155 | Vector((0, 0, 0)), # Candidate B
156 | Vector((4, 0, 0)), # Candidate C
157 | ]
158 | else: # complex
159 | positions = [
160 | Vector((0, 0, 6)), # Original (top center)
161 | Vector((-6, 0, 0)), # Candidate A
162 | Vector((-2, 0, 0)), # Candidate B
163 | Vector((2, 0, 0)), # Candidate C
164 | Vector((6, 0, 0)), # Candidate D
165 | ]
166 |
167 | # Position original
168 | original_obj.location = positions[0]
169 |
170 | # Position candidates
171 | for i, (_, candidate_obj) in enumerate(candidates):
172 | candidate_obj.location = positions[i + 1]
173 |
174 | # Set up camera to capture all objects
175 | all_objects = [original_obj] + [obj for _, obj in candidates]
176 |
177 | # Calculate scene bounds
178 | min_coords = Vector((float('inf'), float('inf'), float('inf')))
179 | max_coords = Vector((float('-inf'), float('-inf'), float('-inf')))
180 |
181 | for obj in all_objects:
182 | for corner in obj.bound_box:
183 | world_corner = obj.matrix_world @ Vector(corner)
184 | min_coords.x = min(min_coords.x, world_corner.x)
185 | min_coords.y = min(min_coords.y, world_corner.y)
186 | min_coords.z = min(min_coords.z, world_corner.z)
187 | max_coords.x = max(max_coords.x, world_corner.x)
188 | max_coords.y = max(max_coords.y, world_corner.y)
189 | max_coords.z = max(max_coords.z, world_corner.z)
190 |
191 | scene_center = (min_coords + max_coords) / 2
192 | scene_size = max(max_coords - min_coords)
193 |
194 | # Position camera
195 | camera_distance = scene_size * 1.5
196 | setup_camera(scene_center, camera_distance)
197 |
198 | # Render
199 | render_image(output_path)
200 |
201 | def generate_one_image_blender(index, difficulty="easy", facecolor="white", outdir="data/mrt"):
202 | """Generate a single MRT image using Blender."""
203 | clear_scene()
204 |
205 | # Create material
206 | if facecolor == "white":
207 | color = (0.9, 0.9, 0.9, 1.0)
208 | else:
209 | # Simple color mapping - extend as needed
210 | color_map = {
211 | "red": (0.8, 0.2, 0.2, 1.0),
212 | "blue": (0.2, 0.2, 0.8, 1.0),
213 | "green": (0.2, 0.8, 0.2, 1.0),
214 | }
215 | color = color_map.get(facecolor, (0.9, 0.9, 0.9, 1.0))
216 |
217 | material = create_material(color, "PolycubeMaterial")
218 |
219 | # Select shapes based on difficulty
220 | shapes_list = EASY_SHAPES if difficulty == "easy" else COMPLEX_SHAPES
221 | shape_name = random.choice(shapes_list)
222 | shape_coords = SHAPES[shape_name]
223 |
224 | # Create original polycube
225 | original_obj = create_polycube(shape_coords, material=material)
226 | original_obj.name = "Original"
227 |
228 | # Generate candidates
229 | candidates = []
230 |
231 | # Correct candidate (rotation)
232 | correct_obj = create_polycube(shape_coords, material=material)
233 | correct_obj.name = "Rotate"
234 | transform_rotate_blender(correct_obj, difficulty)
235 | candidates.append(("rotate", correct_obj))
236 |
237 | # Mirror candidate
238 | mirror_obj = create_polycube(shape_coords, material=material)
239 | mirror_obj.name = "Mirror"
240 | transform_mirror_blender(mirror_obj, difficulty)
241 | transform_rotate_blender(mirror_obj, difficulty)
242 | candidates.append(("mirror", mirror_obj))
243 |
244 | # Similar shape candidate
245 | if shape_name in SIMILAR_MAPPING:
246 | similar_candidates = SIMILAR_MAPPING[shape_name][:]
247 | random.shuffle(similar_candidates)
248 | similar_shape_name = similar_candidates[0] if similar_candidates else shape_name
249 | else:
250 | similar_shape_name = shape_name
251 |
252 | similar_coords = SHAPES[similar_shape_name]
253 | similar_obj = create_polycube(similar_coords, material=material)
254 | similar_obj.name = "Similar"
255 | transform_rotate_blender(similar_obj, difficulty)
256 | candidates.append(("similar", similar_obj))
257 |
258 | # Add second mirror for complex mode
259 | if difficulty == "complex":
260 | mirror2_obj = create_polycube(shape_coords, material=material)
261 | mirror2_obj.name = "Mirror2"
262 | transform_mirror_blender(mirror2_obj, difficulty)
263 | transform_rotate_blender(mirror2_obj, difficulty)
264 | candidates.append(("mirror2", mirror2_obj))
265 |
266 | # Shuffle candidates
267 | random.shuffle(candidates)
268 | correct_candidate_index = [
269 | i for i, cand in enumerate(candidates) if cand[0] == "rotate"
270 | ][0]
271 |
272 | # Set up lighting
273 | setup_lighting()
274 |
275 | # Create composite image
276 | filename = f"{shape_name}_{index}.png"
277 | output_path = os.path.join(outdir, filename)
278 | create_composite_image(original_obj, candidates, output_path, difficulty)
279 |
280 | # Save metadata
281 | metadata = {
282 | "filename": filename,
283 | "difficulty": difficulty,
284 | "shape": shape_name,
285 | "candidate_order": [tag for tag, _ in candidates],
286 | "answer": chr(65 + correct_candidate_index),
287 | }
288 |
289 | metadata_path = os.path.join(outdir, "metadata.jsonl")
290 | with open(metadata_path, "a") as f:
291 | f.write(json.dumps(metadata) + "\n")
292 |
293 | def main():
294 | parser = argparse.ArgumentParser(
295 | description="Generate mental rotation test images using Blender."
296 | )
297 | parser.add_argument(
298 | "--difficulty", "-d", type=str, choices=["easy", "complex"], default="easy",
299 | help="Difficulty level: 'easy' or 'complex'"
300 | )
301 | parser.add_argument(
302 | "--num_images", "-n", type=int, default=1,
303 | help="Number of images to generate"
304 | )
305 | parser.add_argument(
306 | "--color", "-c", type=str, default="white",
307 | help="Color for the polycubes"
308 | )
309 | parser.add_argument(
310 | "--seed", "-s", type=int, default=69,
311 | help="Seed for reproducible results"
312 | )
313 | parser.add_argument(
314 | "--outdir", "-o", type=str, default=None,
315 | help="Output directory (defaults to data/mrt_blender/{difficulty})"
316 | )
317 |
318 | args = parser.parse_args()
319 |
320 | # Set default output directory based on difficulty
321 | if args.outdir is None:
322 | args.outdir = f"data/mrt_blender/{args.difficulty}"
323 |
324 | os.makedirs(args.outdir, exist_ok=True)
325 |
326 | # Set seed for reproducibility
327 | if args.seed is not None:
328 | random.seed(args.seed)
329 |
330 | # Generate images
331 | for i in range(args.num_images):
332 | generate_one_image_blender(i, difficulty=args.difficulty,
333 | facecolor=args.color, outdir=args.outdir)
334 |
335 | print(f"Generated {args.num_images} {args.difficulty} MRT images in {args.outdir}")
336 |
337 | # Run directly in Blender
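338 | # Example invocation (arguments after "--" are forwarded to this script):
339 | #   blender --background --python src/data/mrt_blender.py -- --num_images 5 --difficulty easy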
338 | if __name__ == "__main__":
339 | # When running in Blender, parse sys.argv differently
340 | if "--" in sys.argv:
341 | argv = sys.argv[sys.argv.index("--") + 1:]
342 | sys.argv = [sys.argv[0]] + argv
343 | main()
344 |
--------------------------------------------------------------------------------
/src/data/folding_pil.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageFont
2 | import random
3 | import argparse
4 | import os
5 | import json
6 | import math
7 | from collections import defaultdict
8 |
9 | OUTPUT_DIR = "data/pf"
10 | os.makedirs(OUTPUT_DIR, exist_ok=True)
11 |
12 | # Cache font loading
13 | _font_cache = {}
14 |
15 |
16 | def get_font(size=18):
17 | """Get cached font instance."""
18 | if size not in _font_cache:
19 | try:
20 | _font_cache[size] = ImageFont.truetype("Arial", size=size)
21 | except OSError:
22 | _font_cache[size] = ImageFont.load_default()
23 | return _font_cache[size]
24 |
25 |
26 | def draw_paper(draw):
27 | """Draw a 100x100 white square paper with a black outline."""
28 | draw.rectangle((10, 10, 110, 110), outline="black", fill="white")
29 |
30 |
31 | def draw_holes(draw, holes):
32 | """
33 | Draw holes as small circles. If duplicate coordinates occur,
34 | offset them slightly in a circular pattern so that all are visible.
35 | """
36 | if not holes:
37 | return
38 |
39 | groups = defaultdict(list)
40 | # Group holes by rounded coordinates.
41 | for h in holes:
42 | key = (round(h[0], 1), round(h[1], 1))
43 | groups[key].append(h)
44 |
45 | for group in groups.values():
46 | n = len(group)
47 | if n == 1:
48 | x, y = group[0]
49 | draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="black")
50 | else:
51 | # Distribute duplicates in a circle of small radius.
52 | radius_offset = 2
53 | angle_step = 2 * math.pi / n
54 | for i, h in enumerate(group):
55 | angle = angle_step * i
56 | offset_x = radius_offset * math.cos(angle)
57 | offset_y = radius_offset * math.sin(angle)
58 | x, y = h[0] + offset_x, h[1] + offset_y
59 | draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="black")
60 |
61 |
62 | # Reflection maps for unfolding: the paper spans (10, 10)-(110, 110), so each fold reflects hole coordinates across a line through its centre (60, 60).
63 | FOLD_REFLECTIONS = {
64 | "V": lambda p: (120 - p[0], p[1]),
65 | "H": lambda p: (p[0], 120 - p[1]),
66 | "D": lambda p: (p[1], p[0]),
67 | "N": lambda p: (120 - p[1], 120 - p[0]),
68 | }
69 |
70 |
71 | def compute_all_layers(punched_holes, folds):
72 | """
73 | Compute unfolded holes by doubling the layers for each fold.
74 | """
75 | if not punched_holes:
76 | return []
77 |
78 | layers = [punched_holes]
79 | for fold in folds:
80 | reflection_func = FOLD_REFLECTIONS[fold]
81 | new_layers = []
82 | for layer in layers:
83 | new_layers.append(layer)
84 | new_layers.append([reflection_func(h) for h in layer])
85 | layers = new_layers
86 |
87 | # Flatten layers
88 | return [h for layer in layers for h in layer]
89 |
90 |
91 | def point_in_poly(x, y, poly):
92 | """Optimized point-in-polygon test using ray casting."""
93 | if not poly:
94 | return False
95 |
96 | inside = False
97 | j = len(poly) - 1
98 |
99 | for i in range(len(poly)):
100 | xi, yi = poly[i]
101 | xj, yj = poly[j]
102 |
103 | if ((yi > y) != (yj > y)) and (x < (xj - xi) * (y - yi) / (yj - yi) + xi):
104 | inside = not inside
105 | j = i
106 |
107 | return inside
108 |
109 |
110 | def generate_hole_in_poly(poly, margin=5):
111 | """Generate a random point within the polygon with margin."""
112 | if not poly:
113 | return (60, 60) # Default center point
114 |
115 | xs = [p[0] for p in poly]
116 | ys = [p[1] for p in poly]
117 | bx_min, bx_max = min(xs) + margin, max(xs) - margin
118 | by_min, by_max = min(ys) + margin, max(ys) - margin
119 |
120 | # Ensure valid bounds
121 | if bx_min >= bx_max or by_min >= by_max:
122 | return (int((min(xs) + max(xs)) / 2), int((min(ys) + max(ys)) / 2))
123 |
124 | for _ in range(100): # Limit attempts to avoid infinite loop
125 | x = random.uniform(bx_min, bx_max)
126 | y = random.uniform(by_min, by_max)
127 | if point_in_poly(x, y, poly):
128 | return (int(round(x)), int(round(y)))
129 |
130 | # Fallback to polygon centroid
131 | cx = sum(p[0] for p in poly) / len(poly)
132 | cy = sum(p[1] for p in poly) / len(poly)
133 | return (int(round(cx)), int(round(cy)))
134 |
135 |
136 | def clip_polygon(poly, fold):
137 | """Optimized polygon clipping."""
138 | if not poly:
139 | return []
140 |
141 | # Pre-compute bounding box
142 | xs = [p[0] for p in poly]
143 | ys = [p[1] for p in poly]
144 | bx_min, bx_max = min(xs), max(xs)
145 | by_min, by_max = min(ys), max(ys)
146 |
147 | if fold == "V":
148 | mid = (bx_min + bx_max) / 2
149 | inside = lambda p: p[0] >= mid
150 | intersect = lambda p1, p2: (
151 | mid,
152 | p1[1] + (mid - p1[0]) * (p2[1] - p1[1]) / (p2[0] - p1[0]),
153 | )
154 | elif fold == "H":
155 | mid = (by_min + by_max) / 2
156 | inside = lambda p: p[1] >= mid
157 | intersect = lambda p1, p2: (
158 | p1[0] + (mid - p1[1]) * (p2[0] - p1[0]) / (p2[1] - p1[1]),
159 | mid,
160 | )
161 | elif fold == "D":
162 | cx, cy = (bx_min + bx_max) / 2, (by_min + by_max) / 2
163 | b = cy - cx
164 | inside = lambda p: p[1] <= p[0] + b
165 |
166 | def intersect(p1, p2):
167 | denom = (p2[0] - p1[0]) - (p2[1] - p1[1])
168 | if abs(denom) < 1e-10:
169 | return p1
170 | t = (p1[1] - p1[0] - b) / denom
171 | return (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1]))
172 | elif fold == "N":
173 | cx, cy = (bx_min + bx_max) / 2, (by_min + by_max) / 2
174 | inside = lambda p: (p[0] + p[1]) <= (cx + cy)
175 |
176 | def intersect(p1, p2):
177 | denom = (p2[0] - p1[0]) + (p2[1] - p1[1])
178 | if abs(denom) < 1e-10:
179 | return p1
180 | t = ((cx + cy) - (p1[0] + p1[1])) / denom
181 | return (p1[0] + t * (p2[0] - p1[0]), p1[1] + t * (p2[1] - p1[1]))
182 | else:
183 | return poly
184 |
185 | new_poly = []
186 | for i in range(len(poly)):
187 | curr = poly[i]
188 | prev = poly[i - 1]
189 |
190 | if inside(curr):
191 | if not inside(prev):
192 | new_poly.append(intersect(prev, curr))
193 | new_poly.append(curr)
194 | elif inside(prev):
195 | new_poly.append(intersect(prev, curr))
196 |
197 | return new_poly
198 |
199 |
200 | # Global list to collect final view images for each fold.
201 | process_images = []
202 |
203 |
204 | def recursive_fold(current_folds, idx, poly):
205 | """Recursively generate the final folded view for each fold."""
206 | global process_images
207 |
208 | if idx == 0:
209 | img_unfolded = Image.new("RGB", (120, 120), "white")
210 | draw_unfolded = ImageDraw.Draw(img_unfolded)
211 | draw_paper(draw_unfolded)
212 | process_images.append((img_unfolded, "Unfolded"))
213 |
214 | if idx >= len(current_folds):
215 | return poly
216 |
217 | fold = current_folds[idx]
218 | new_poly = clip_polygon(poly, fold)
219 |
220 | img_result = Image.new("RGB", (120, 120), "white")
221 | draw_result = ImageDraw.Draw(img_result)
222 |
223 | if poly:
224 | draw_result.polygon(poly, fill="lightgray", outline="black")
225 | if new_poly:
226 | draw_result.polygon(new_poly, fill="white", outline="black")
227 |
228 | process_images.append((img_result, f"Fold {idx + 1}"))
229 | return recursive_fold(current_folds, idx + 1, new_poly)
230 |
231 |
232 | def generate_wrong_options(holes):
233 | """Generate all wrong options at once for efficiency."""
234 | if not holes:
235 | return [[], []]
236 |
237 | # Option 1: Remove one hole or shift if only one hole
238 | option1 = holes.copy()
239 | if len(option1) > 1:
240 | option1.pop(random.randrange(len(option1)))
241 | else:
242 | x, y = option1[0]
243 | dx, dy = random.randint(-3, 3), random.randint(-3, 3)
244 | option1[0] = (x + dx, y + dy)
245 |
246 | # Option 2: Mirror transformation
247 | option2 = [(120 - x, 120 - y) for x, y in holes]
248 |
249 | # If mirror is same as original, try rotation
250 | if set(option2) == set(holes):
251 | option2 = [
252 | (ty + 60, -tx + 60) for x, y in holes for tx, ty in [(x - 60, y - 60)]
253 | ]
254 | # If rotation is also same, shift horizontally
255 | if set(option2) == set(holes):
256 | option2 = [(min(x + 5, 110), y) for x, y in holes]
257 |
258 | return [option1, option2]
259 |
260 |
261 | def generate_test_image(folds, test_number, num_folds, num_holes):
262 | """Generate a test image with optimized processing."""
263 | global process_images
264 | process_images = []
265 |
266 | font_bigger = get_font(18)
267 |
268 | initial_poly = [(10, 10), (110, 10), (110, 110), (10, 110)]
269 | final_poly = recursive_fold(folds, 0, initial_poly)
270 |
271 | # Generate punched holes
272 | punched_holes = [
273 | generate_hole_in_poly(final_poly, margin=5) for _ in range(num_holes)
274 | ]
275 | unfolded_holes = compute_all_layers(punched_holes, folds)
276 |
277 | # Create final view image
278 | img_final = Image.new("RGB", (120, 120), "white")
279 | draw_final = ImageDraw.Draw(img_final)
280 | if final_poly:
281 | draw_final.polygon(final_poly, fill="white", outline="black")
282 | draw_holes(draw_final, punched_holes)
283 | process_images.append((img_final, "Final view"))
284 |
285 | # Create candidate images
286 | img_correct = Image.new("RGB", (120, 120), "white")
287 | draw_correct = ImageDraw.Draw(img_correct)
288 | draw_paper(draw_correct)
289 | draw_holes(draw_correct, unfolded_holes)
290 |
291 | wrong_options = generate_wrong_options(unfolded_holes)
292 | candidate_images = [("correct", img_correct)]
293 |
294 | for wrong_holes in wrong_options:
295 | img_wrong = Image.new("RGB", (120, 120), "white")
296 | draw_wrong = ImageDraw.Draw(img_wrong)
297 | draw_paper(draw_wrong)
298 | draw_holes(draw_wrong, wrong_holes)
299 | candidate_images.append(("wrong", img_wrong))
300 |
301 | # Shuffle and assign labels
302 | random.shuffle(candidate_images)
303 | option_labels = ["A", "B", "C"]
304 | correct_label = None
305 |
306 | for i, (kind, img) in enumerate(candidate_images):
307 | if kind == "correct":
308 | correct_label = option_labels[i]
309 |
310 | # Create final composite image
311 | small_size = 120
312 | top_width = len(process_images) * small_size + (len(process_images) - 1) * 10
313 | bottom_width = 3 * small_size + 2 * 10
314 | total_width = max(top_width, bottom_width)
315 | total_height = 320
316 |
317 | total_img = Image.new("RGB", (total_width, total_height), "white")
318 | draw_total = ImageDraw.Draw(total_img)
319 |
320 | # Draw top row
321 | start_x = (total_width - top_width) // 2
322 | for i, (img, label) in enumerate(process_images):
323 | x = start_x + i * (small_size + 10)
324 | total_img.paste(img, (x, 30))
325 | draw_total.text((x + 15, 10), label, fill="black", font=font_bigger)
326 |
327 | # Draw bottom row
328 | start_x_bottom = (total_width - bottom_width) // 2
329 | for i, (_, img) in enumerate(candidate_images):
330 | x = start_x_bottom + i * (small_size + 10)
331 | total_img.paste(img, (x, 190))
332 | draw_total.text((x + 45, 180), option_labels[i], fill="black", font=font_bigger)
333 |
334 | # Save image
335 | f_name = f"{test_number}_fold-{num_folds}_holes-{num_holes}.png"
336 | out_path = os.path.join(OUTPUT_DIR, f_name)
337 | total_img.save(out_path)
338 |
339 | return out_path, correct_label
340 |
341 |
342 | if __name__ == "__main__":
343 | parser = argparse.ArgumentParser(
344 | description="Generate folded paper images with holes."
345 | )
346 | parser.add_argument(
347 | "-n",
348 | "--num-images",
349 | type=int,
350 | default=10,
351 | help="Number of images to generate (default: 10)",
352 | )
353 | parser.add_argument(
354 | "-s",
355 | "--seed",
356 | type=int,
357 | default=42,
358 | help="Seed for random number generator (default: 42)",
359 | )
360 | parser.add_argument(
361 | "-f",
362 | "--num-folds",
363 | type=int,
364 | default=2,
365 | choices=range(1, 10),
366 | help="Number of folds (minimum 1, maximum 9) (default: 2)",
367 | )
368 | parser.add_argument(
369 | "-H", "--num-holes", type=int, default=1, help="Number of holes (default: 1)"
370 | )
371 | parser.add_argument(
372 | "-m",
373 | "--metadata-file",
374 | type=str,
375 | default="metadata.jsonl",
376 | help="Metadata JSONL file to append to (default: metadata.jsonl)",
377 | )
378 | args = parser.parse_args()
379 |
380 | random.seed(args.seed)
381 | metadata_path = os.path.join(OUTPUT_DIR, args.metadata_file)
382 | with open(metadata_path, "a", encoding="utf-8") as metaf:
383 | for i in range(1, args.num_images + 1):
384 | fold_group = random.choice(["VH", "Diagonal"])
385 | if fold_group == "VH":
386 | folds = [random.choice(["V", "H"]) for _ in range(args.num_folds)]
387 | else:
388 | folds = [random.choice(["D", "N"]) for _ in range(args.num_folds)]
389 | image_path, correct_option = generate_test_image(
390 | folds, i, args.num_folds, args.num_holes
391 | )
392 | rel_image_path = os.path.basename(image_path)
393 | metadata_obj = {
394 | "filename": rel_image_path,
395 | "correct_option": correct_option,
396 | }
397 | metaf.write(json.dumps(metadata_obj) + "\n")
398 |
--------------------------------------------------------------------------------
/src/data/create_prompts.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | This script demonstrates how to:
4 | 1. Define a list of simple objects with attributes and a set of spatial test tasks.
5 | 2. For each spatial task and each model, generate multiple scene descriptions using an LLM.
6 | 3. Post-process the model's output to remove internal commentary.
7 | 4. Log each generated prompt and metadata to both WandB and Neptune.ai.
8 | 5. Save each generated scene description and metadata to a JSONL file.
9 | """
10 |
11 | import re
12 | import os
13 | import random
14 | import json
15 | import logging
16 | import glob
17 | from typing import List, Dict, Tuple, Any
18 |
19 | import neptune.utils
20 | import torch
21 | import pandas as pd
22 | from transformers import pipeline
23 | from dotenv import load_dotenv
24 | import wandb
25 | import neptune
26 |
27 | # Load environment variables from the .env file.
28 | load_dotenv(override=True)
29 |
30 | logger = logging.getLogger(__name__)
31 | logging.basicConfig(filename="debug_prompts.log", encoding="utf-8", level=logging.DEBUG)
32 |
33 | # -------------------------------
34 | # Type Aliases
35 | # -------------------------------
36 | ObjectType = Dict[str, Any]
37 | SpatialTaskType = Dict[str, str]
38 | MessageType = Dict[str, str]
39 |
40 | # -------------------------------
41 | # Constants
42 | # -------------------------------
43 | SAVE_DIR = "output/prompts/"
44 | TRIAL_NAME = "fiveshotv6-prompts-beam"
45 | OUTPUT_FILENAME = f"{SAVE_DIR}{TRIAL_NAME}.jsonl" # All prompts will be appended here.
46 | PROMPTS_PER_TASK = 25
47 | # -------------------------------
48 | # 7. List of Models to Use
49 | # -------------------------------
50 | MODELS: List[str] = glob.glob("bin/models/llms/*")
51 |
52 | # -------------------------------
53 | # Initialize WandB
54 | # -------------------------------
55 | wandb.init(
56 | project="SpatialScenePrompts",
57 | name=TRIAL_NAME,
58 | config={
59 | "num_prompts_per_task": PROMPTS_PER_TASK,
60 | "models": MODELS,
61 | },
62 | )
63 |
64 | # -------------------------------
65 | # Initialize Neptune.ai Run
66 | # -------------------------------
67 | neptune_run = neptune.init_run(
68 | project="stogiannidis/create-prompts", # Replace with your project name.
69 | api_token=os.getenv("NEPTUNE_API_TOKEN"), # Replace with your Neptune API token.
70 | name=TRIAL_NAME,
71 | )
72 | neptune_run["config/num_prompts_per_task"] = PROMPTS_PER_TASK
73 | neptune_run["config/models"] = neptune.utils.stringify_unsupported(MODELS)
74 |
75 | # -------------------------------
76 | # 1. List of Simple Objects with Attributes
77 | # -------------------------------
78 | OBJECTS: List[str] = [
79 | "bicycle",
80 | "motorcycle",
81 | "scooter",
82 | "car",
83 | "truck",
84 | "bus",
85 | "train",
86 | "airplane",
87 | "helicopter",
88 | "boat",
89 | "ship",
90 | "dog",
91 | "cat",
92 | "bird",
93 | "fish",
94 | "rabbit",
95 | ]
96 |
97 | # -------------------------------
98 | # 2. List of Spatial Test Tasks
99 | # -------------------------------
100 | # Each task now uses a concise definition and an example.
101 | SPATIAL_TASKS: List[SpatialTaskType] = [
102 | {
103 | "task_type": "mental rotation",
104 | "examples": [
105 | "A photo-realistic image of a car, viewed from the front, with the wheels turned to the left, parked in a driveway, with a clear blue sky in the background.",
106 | "A photo-realistic image of a mug, viewed from the side, with the handle on the right, filled with steaming coffee, on a wooden table, with a window in the background showing a sunny day.",
107 | "A photo-realistic image of a laptop, viewed from the back, with the screen open",
108 | "A photo-realistic image of a bicycle, viewed from the side, with the front wheel turned to the right, parked on a cobblestone street, with a row of colorful houses in the background.",
109 | "A photo-realistic image of a cat, viewed from the front, with the tail curled to the right, sitting on a windowsill, with a potted plant in the background.",
110 | ],
111 | },
112 | # {
113 | # "task_type": "perspective taking",
114 | # "examples": [
115 | # "A photograph of a family of four, two adults and two children, standing in a row, with the camera positioned at the height of the children, looking up at the adults, who are smiling down at them.",
116 | # "Photograph of street artist, vividly painting a vibrant mural, surrounded by captivated pedestrians, in a stencil-like graffiti style, with a gritty urban setting, drenched in chiaroscuro lighting for a dramatic and lively atmosphere.",
117 | # "Photograph of vibrant street vendors, laden with an array of ripe fruits, amidst the lively hustle of a farmers market - captured in the style of a vivid, Impressionist oil painting, with warm sunlight filtering through a cloud-speckled sky.",
118 | # "Photograph of scuba diver capturing a vibrant, up-close moment with a majestic sea turtle among intricately detailed, luminous coral reef, in the style of a high-definition underwater photograph blending vivid hues and soft shadows, with a serene, lively atmosphere.",
119 | # "Photograph of toddler and playful puppy in sunlit backyard, chasing iridescent bubbles in whimsical Impressionist style, vibrant colors, tender atmosphere, capturing the joy of childhood and canine companionship.",
120 | # ],
121 | # },
122 | ]
123 |
124 | # -------------------------------
125 | # 3. Enhanced System Prompt
126 | # -------------------------------
127 | # This prompt instructs the LLM to generate a test prompt for spatial reasoning.
128 | SYSTEM_PROMPT: str = """**Role:** Text-Based Image Description Generation Assistant
129 |
130 | **Objective:** To generate high-quality text image descriptions for generative models.
131 |
132 | **Input (Textual):**
133 | * A list of objects (provided as text by the user).
134 | * Example image descriptions (provided as text by the user).
135 |
136 | **Output (Textual):** New image descriptions (as text).
137 |
138 | **Task:** Generate new text image descriptions that meet the following criteria:
139 | 1. **Style Mimicry:** Replicate the writing style, sentence structure, and vocabulary used in the example text descriptions.
140 | 2. **Object Novelty:** Feature the provided list of objects, ensuring they are different from objects explicitly mentioned in the example text descriptions.
141 | 3. **Setting Novelty:** Describe the objects in new and different settings or contexts compared to those presented in the example text descriptions.
142 | 4. **Logical Coherence & Realism:** Ensure all generated text descriptions are logically sound, realistic, and portray plausible scenarios in text. Avoid nonsensical or physically impossible descriptions in text.
143 | """
144 |
145 | # -------------------------------
146 | # 4. Function to Construct a Prompt for a Given Task
147 | # -------------------------------
148 |
149 |
150 | def construct_prompt_for_task(task: SpatialTaskType) -> Tuple[str, str]:
151 | """
152 |     Constructs a test prompt using the task's examples and the full list of
153 |     simple objects with attributes.
154 |
155 | Args:
156 |         task: A spatial test task containing "task_type" and "examples".
157 |
158 | Returns:
159 | A tuple with the task type and the constructed user prompt.
160 | """
161 | task_type = task["task_type"]
162 | examples = task["examples"]
163 | user_prompt = (f"""Please generate a text-to-image prompt inspired by the following examples and using these objects:
164 | **Example Prompts:**
165 | {examples}
166 |
167 | **Objects to Include:**
168 | {OBJECTS}
169 |
170 | Respond with the brief yet concise generated scene description."""
171 | )
172 | return task_type, user_prompt
173 |
174 |
175 | # -------------------------------
176 | # 5. Post-Processing Function
177 | # -------------------------------
178 | def post_process_output(output_text: str) -> str:
179 | """
180 |     Removes any text between <think> and </think> tags and strips whitespace.
181 |
182 | Args:
183 | output_text: The raw output text from the model.
184 |
185 | Returns:
186 | The cleaned output text.
187 | """
188 |     cleaned_text: str = re.sub(r"<think>.*?</think>", "", output_text, flags=re.DOTALL)
189 | return cleaned_text.strip()
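    |     # Example (assuming the model wraps its internal reasoning in <think> ... </think> tags):
    |     #   post_process_output("<think>internal notes</think> A photo of a boat.") -> "A photo of a boat."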
190 |
191 |
192 | # -------------------------------
193 | # 6. Helper Functions to Adapt to Different Pipeline Variants
194 | # -------------------------------
195 | def supports_system_prompt(model_name: str) -> bool:
196 | """
197 |     Heuristic: if the model id contains "gemma" (case-insensitive), we assume it does not support a separate system prompt.
198 | """
199 | return "gemma" not in model_name.lower()
200 |
201 |
202 | def prepare_pipeline_input(messages: List[MessageType], model_name: str) -> Any:
203 | """
204 | Prepares the input for the pipeline call based on whether the model supports system prompts.
205 | """
206 | if supports_system_prompt(model_name):
207 | return messages
208 | else:
209 | combined_prompt = "\n\n".join(msg["content"].strip() for msg in messages)
210 | return [{"role": "user", "content": combined_prompt.strip()}]
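    |     # Sketch of both branches (model paths are illustrative):
    |     #   prepare_pipeline_input(msgs, "bin/models/llms/Llama-3.1-8B") -> msgs unchanged (system role kept)
    |     #   prepare_pipeline_input(msgs, "bin/models/llms/gemma-2-9b")   -> one user turn with system + user text joined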
211 |
212 |
213 | def extract_generated_text(outputs: Any) -> str:
214 | """
215 | Extracts the generated text from the output of the pipeline.
216 |
217 | Supports both nested (chat-style) and flat output.
218 | """
219 | if isinstance(outputs, list) and outputs:
220 | output_item = outputs[0]
221 | if "generated_text" in output_item:
222 | gen = output_item["generated_text"]
223 | if isinstance(gen, list):
224 | last = gen[-1]
225 | if isinstance(last, dict) and "content" in last:
226 | return last["content"].strip()
227 | elif isinstance(last, str):
228 | return last.strip()
229 | elif isinstance(gen, str):
230 | return gen.strip()
231 | return ""
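    |     # Handles both pipeline output shapes, e.g. (illustrative):
    |     #   [{"generated_text": "A photo of ..."}]                                     -> "A photo of ..."
    |     #   [{"generated_text": [{"role": "assistant", "content": "A photo of ..."}]}] -> "A photo of ..."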
232 |
233 |
234 | # -------------------------------
235 | # 8. Main Routine: Iterate Over All Model and Task Combinations
236 | # -------------------------------
237 | def main() -> None:
238 | logger.info("\n\nStarting main routine with WandB and Neptune.ai logging.")
239 |
240 | for model_name in MODELS:
241 | logger.info("Processing model: %s", model_name)
242 | # Get the original name of the model without the path.
243 | sanitized_model: str = model_name.split("/")[-1]
244 | logger.info("Sanitized model name: %s", sanitized_model)
245 |
246 | # Initialize the text-generation pipeline for the current model.
247 | llm_pipe: Any = pipeline(
248 | "text-generation",
249 | model=model_name,
250 | device_map="auto",
251 | torch_dtype=torch.bfloat16,
252 | return_full_text=False,
253 | do_sample=True,
254 | temperature=1,
255 | top_k=10,
256 | top_p=0.9,
257 | num_beams=10,
258 | )
259 | eos_token_id = llm_pipe.tokenizer.eos_token_id
260 | logger.info("Pipeline initialized for %s", sanitized_model)
261 |
262 | for task in SPATIAL_TASKS:
263 | # Generate a number of prompts for each task.
264 | for i in range(PROMPTS_PER_TASK):
265 | logger.info(
266 | "Generating prompt %d for task: %s", i + 1, task["task_type"]
267 | )
268 | task_type, user_prompt = construct_prompt_for_task(task)
269 | logger.info("User prompt generated for task: %s", task_type)
270 |
271 | # Prepare the messages.
272 | messages: List[MessageType] = [
273 | {"role": "system", "content": SYSTEM_PROMPT},
274 | {"role": "user", "content": user_prompt},
275 | ]
276 | prompt_input = prepare_pipeline_input(messages, model_name)
277 |
278 | # Generate the scene description using the LLM.
279 | outputs: Any = llm_pipe(
280 | prompt_input,
281 | max_new_tokens=1024,
282 | pad_token_id=eos_token_id,
283 | )
284 | generated_text: str = extract_generated_text(outputs)
285 | if generated_text:
286 | cleaned_text = post_process_output(generated_text)
287 |
288 | # Prepare output metadata.
289 | output_data: Dict[str, Any] = {
290 | "model": sanitized_model,
291 | "task_type": task_type,
292 | "user_prompt": user_prompt,
293 | "generated_scene_description": cleaned_text,
294 | "iteration": i + 1,
295 | }
296 |
297 | # Log the metadata to WandB.
298 | wandb.log(output_data)
299 |
300 | # Log the metadata to Neptune.
301 | neptune_run[
302 | f"prompt/{sanitized_model}/{task_type}/iteration_{i + 1}"
303 | ] = output_data
304 |
305 | # Save the generated description and metadata to a JSONL file.
306 | with open(OUTPUT_FILENAME, "a", encoding="utf-8") as f:
307 | f.write(json.dumps(output_data, ensure_ascii=False) + "\n")
308 |
309 |     # Wrap up: build the results table and finish logging on both platforms.
310 | logger.info("Main routine completed.")
311 |
312 |     # Collect all generated prompts into a results table.
313 | results_df: pd.DataFrame = pd.read_json(OUTPUT_FILENAME, lines=True)
314 |
315 | # Log the Table of Contents to both platforms.
316 | wandb.log({"results": wandb.Table(data=results_df)})
317 | neptune_run["output"].upload(neptune.types.File.as_html(results_df))
318 |
319 | # Finish the run for both platforms.
320 | wandb.finish()
321 | neptune_run.stop()
322 |
323 |
324 | if __name__ == "__main__":
325 | main()
326 |
--------------------------------------------------------------------------------
/src/data/mrt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import random
4 | import argparse
5 | import json
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
9 | from matplotlib.gridspec import GridSpec
10 |
11 | # Global dictionary of polycube shapes
12 | SHAPES = {
13 | "Snake": [
14 | (0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1),
15 | (1, 2, 1), (2, 2, 1), (2, 2, 2), (2, 3, 2),
16 | ],
17 | "Zigzag": [
18 | (0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0),
19 | (2, 1, 1), (2, 2, 1), (3, 2, 1), (3, 2, 2),
20 | ],
21 | "SnakeComplex1": [
22 | (0, 0, 0), (1, 0, 0), (2, 0, 0), (2, 1, 0),
23 | (2, 1, 1), (2, 2, 1), (1, 2, 1), (1, 3, 1), (1, 3, 2),
24 | ],
25 | "HookedCorner": [
26 | (0, 0, 0), (1, 0, 0), (2, 0, 0), (0, 1, 0),
27 | (0, 2, 0), (0, 2, 1), (0, 2, 2),
28 | ],
29 | "TopPlate": [
30 | (0, 0, 0), (0, 1, 0), (0, 2, 0),
31 | (0, 2, 1), (1, 2, 1), (2, 2, 1),
32 | ],
33 | "CornerStaircase": [
34 | (0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0),
35 | (1, 1, 0), (2, 1, 0), (3, 1, 0), (3, 2, 0), (3, 3, 0),
36 | ],
37 | "TripleArm": [
38 | (3, -1, 0), (3, -1, 1), (3, -1, 2), (0, 0, 0),
39 | (1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 2, 0),
40 | ],
41 | }
42 |
43 | # Shapes available for each difficulty
44 | EASY_SHAPES = ["Snake", "HookedCorner", "TopPlate", "CornerStaircase", "TripleArm"]
45 | COMPLEX_SHAPES = list(SHAPES.keys())
46 |
47 | # Dynamic similar-object mapping
48 | all_shape_keys = list(SHAPES.keys())
49 | SIMILAR_MAPPING = {
50 | key: [s for s in all_shape_keys if s != key] for key in all_shape_keys
51 | }
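    | # e.g. SIMILAR_MAPPING["Snake"] == ["Zigzag", "SnakeComplex1", "HookedCorner",
    | #                                   "TopPlate", "CornerStaircase", "TripleArm"]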
52 |
53 |
54 | def set_axes_equal(ax, all_vertices):
55 | """Make the aspect ratio equal and remove visual distractions."""
56 | all_vertices = np.array(all_vertices)
57 | x_limits = [np.min(all_vertices[:, 0]), np.max(all_vertices[:, 0])]
58 | y_limits = [np.min(all_vertices[:, 1]), np.max(all_vertices[:, 1])]
59 | z_limits = [np.min(all_vertices[:, 2]), np.max(all_vertices[:, 2])]
60 |
61 | x_range = x_limits[1] - x_limits[0]
62 | y_range = y_limits[1] - y_limits[0]
63 | z_range = z_limits[1] - z_limits[0]
64 | max_range = max(x_range, y_range, z_range)
65 |
66 | x_mid = np.mean(x_limits)
67 | y_mid = np.mean(y_limits)
68 | z_mid = np.mean(z_limits)
69 |
70 | ax.set_xlim(x_mid - max_range / 2, x_mid + max_range / 2)
71 | ax.set_ylim(y_mid - max_range / 2, y_mid + max_range / 2)
72 | ax.set_zlim(z_mid - max_range / 2, z_mid + max_range / 2)
73 | ax.set_box_aspect([1, 1, 1])
74 |
75 | # Set ticks and remove labels
76 | ax.set_xticks(np.linspace(x_limits[0], x_limits[1], 5))
77 | ax.set_yticks(np.linspace(y_limits[0], y_limits[1], 5))
78 | ax.set_zticks(np.linspace(z_limits[0], z_limits[1], 5))
79 | ax.set_xticklabels([])
80 | ax.set_yticklabels([])
81 | ax.set_zticklabels([])
82 |
83 | # Remove 3D visual elements for complex mode
84 | if hasattr(ax, '_remove_3d_elements'):
85 | ax.xaxis.pane.set_visible(False)
86 | ax.yaxis.pane.set_visible(False)
87 | ax.zaxis.pane.set_visible(False)
88 | ax.grid(False)
89 | ax._axis3don = False
90 |
91 |
92 | def cube_vertices(origin, size=1.0):
93 | """Return the 8 corner vertices of a cube."""
94 | x, y, z = origin
95 | return np.array([
96 | [x, y, z], [x + size, y, z], [x + size, y + size, z], [x, y + size, z],
97 | [x, y, z + size], [x + size, y, z + size],
98 | [x + size, y + size, z + size], [x, y + size, z + size],
99 | ])
100 |
101 |
102 | def plot_cubes(ax, vertices, title="", facecolor="white", hide_3d_elements=False):
103 | """Plot cubes using their vertices."""
104 | if hide_3d_elements:
105 | ax._remove_3d_elements = True
106 |
107 | n_cubes = len(vertices) // 8
108 | vertices_reshaped = vertices.reshape((n_cubes, 8, 3))
109 |
110 | for cube_verts in vertices_reshaped:
111 | faces = [
112 | [0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4],
113 | [2, 3, 7, 6], [1, 2, 6, 5], [0, 3, 7, 4],
114 | ]
115 | for face in faces:
116 | polygon = Poly3DCollection(
117 | [cube_verts[face]], facecolors=facecolor,
118 | edgecolors="black", alpha=1.0,
119 | )
120 | ax.add_collection3d(polygon)
121 |
122 | ax.set_title(title, fontsize=12)
123 | set_axes_equal(ax, vertices)
124 |
125 |
126 | def generate_shape_vertices(shape_name, cube_size=1.0):
127 | """Generate vertices for a given shape."""
128 | if shape_name not in SHAPES:
129 | raise ValueError(f"Unknown shape {shape_name}")
130 |
131 | cube_origins = SHAPES[shape_name]
132 | all_vertices = []
133 | for origin in cube_origins:
134 | corners = cube_vertices(origin, size=cube_size)
135 | all_vertices.append(corners)
136 | return np.vstack(all_vertices)
137 |
138 |
139 | def get_transformed_candidate(transformation_func, original, max_attempts=10):
140 | """Apply transformation until result differs from original."""
141 | for _ in range(max_attempts):
142 | candidate = transformation_func(original)
143 | if not np.allclose(candidate, original, atol=1e-6):
144 | return candidate
145 | return candidate
146 |
147 |
148 | def transform_rotate(vertices, difficulty="easy"):
149 | """Rotate shape based on difficulty level."""
150 | center = vertices.mean(axis=0)
151 | shifted = vertices - center
152 |
153 | if difficulty == "easy":
154 | # Single axis rotation with simple angles
155 | axis = random.choice(["x", "y", "z"])
156 | angle = np.deg2rad(random.choice([-90, 90, 180]))
157 |
158 | if axis == "x":
159 | R = np.array([
160 | [1, 0, 0],
161 | [0, np.cos(angle), -np.sin(angle)],
162 | [0, np.sin(angle), np.cos(angle)],
163 | ])
164 | elif axis == "y":
165 | R = np.array([
166 | [np.cos(angle), 0, np.sin(angle)],
167 | [0, 1, 0],
168 | [-np.sin(angle), 0, np.cos(angle)],
169 | ])
170 | else: # z axis
171 | R = np.array([
172 | [np.cos(angle), -np.sin(angle), 0],
173 | [np.sin(angle), np.cos(angle), 0],
174 | [0, 0, 1],
175 | ])
176 | else: # complex
177 | # Multi-axis rotation with varied angles
178 | angle_x = np.deg2rad(np.random.choice([0, 60, 90, 120]))
179 | angle_y = np.deg2rad(np.random.choice([0, 60, 90, 120]))
180 | angle_z = np.deg2rad(np.random.choice([0, 60, 90, 120]))
181 |
182 | Rx = np.array([
183 | [1, 0, 0],
184 | [0, np.cos(angle_x), -np.sin(angle_x)],
185 | [0, np.sin(angle_x), np.cos(angle_x)],
186 | ])
187 | Ry = np.array([
188 | [np.cos(angle_y), 0, np.sin(angle_y)],
189 | [0, 1, 0],
190 | [-np.sin(angle_y), 0, np.cos(angle_y)],
191 | ])
192 | Rz = np.array([
193 | [np.cos(angle_z), -np.sin(angle_z), 0],
194 | [np.sin(angle_z), np.cos(angle_z), 0],
195 | [0, 0, 1],
196 | ])
197 | R = Rz @ Ry @ Rx
198 |
199 | rotated = (R @ shifted.T).T
200 | return rotated + center
201 |
202 |
203 | def transform_mirror(vertices, difficulty="easy"):
204 | """Mirror shape based on difficulty level."""
205 | center = vertices.mean(axis=0)
206 | shifted = vertices - center
207 | mirrored = shifted.copy()
208 |
209 | if difficulty == "easy":
210 | # Mirror across XY plane (Z-axis)
211 | mirrored[:, 2] = -mirrored[:, 2]
212 | else: # complex
213 | # Mirror across random axis
214 | axis = random.choice([0, 1, 2])
215 | mirrored[:, axis] = -mirrored[:, axis]
216 |
217 | return mirrored + center
218 |
219 |
220 | def get_visually_similar_candidate(chosen_shape_name, original_vertices, cube_size=1.0, difficulty="easy"):
221 | """Get a similar shape candidate."""
222 | if chosen_shape_name in SIMILAR_MAPPING:
223 | similar_candidates = SIMILAR_MAPPING[chosen_shape_name][:]
224 | random.shuffle(similar_candidates)
225 |
226 | for similar_shape_name in similar_candidates:
227 | if similar_shape_name in SHAPES:
228 | candidate_vertices = generate_shape_vertices(similar_shape_name, cube_size)
229 | candidate_vertices = get_transformed_candidate(
230 | lambda v: transform_rotate(v, difficulty), candidate_vertices
231 | )
232 | if (candidate_vertices.shape != original_vertices.shape or
233 | not np.allclose(candidate_vertices, original_vertices, atol=1e-6)):
234 | return candidate_vertices
235 | return None
236 |
237 |
238 | def generate_one_image(index, difficulty="easy", facecolor="white", outdir="data/mrt"):
239 | """Generate a single MRT image based on difficulty."""
240 | cube_size = 1.0
241 |
242 | # Select shapes based on difficulty
243 | shapes_list = EASY_SHAPES if difficulty == "easy" else COMPLEX_SHAPES
244 | shape_name = random.choice(shapes_list)
245 | original_vertices = generate_shape_vertices(shape_name, cube_size=cube_size)
246 |
247 | # Generate correct candidate (rotation)
248 | correct_candidate = get_transformed_candidate(
249 | lambda v: transform_rotate(v, difficulty), original_vertices
250 | )
251 |
252 | # Generate wrong candidates
253 | mirror_candidate = get_transformed_candidate(
254 | lambda v: transform_rotate(transform_mirror(v, difficulty), difficulty),
255 | original_vertices
256 | )
257 |
258 | similar_candidate = get_visually_similar_candidate(
259 | shape_name, original_vertices, cube_size, difficulty
260 | )
261 | if similar_candidate is None:
262 | similar_candidate = mirror_candidate
263 |
264 | # Set up candidates based on difficulty
265 | if difficulty == "easy":
266 | candidates = [
267 | ("rotate", correct_candidate),
268 | ("mirror", mirror_candidate),
269 | ("similar", similar_candidate),
270 | ]
271 | num_candidates = 3
272 | figure_size = (6, 6)
273 | else: # complex
274 | mirror_candidate2 = get_transformed_candidate(
275 | lambda v: transform_rotate(transform_mirror(v, difficulty), difficulty),
276 | original_vertices
277 | )
278 | candidates = [
279 | ("rotate", correct_candidate),
280 | ("mirror", mirror_candidate),
281 | ("similar", similar_candidate),
282 | ("mirror2", mirror_candidate2),
283 | ]
284 | num_candidates = 4
285 | figure_size = (12, 8)
286 |
287 | random.shuffle(candidates)
288 | correct_candidate_index = [
289 | i for i, cand in enumerate(candidates) if cand[0] == "rotate"
290 | ][0]
291 |
292 | # Create figure
293 | fig = plt.figure(figsize=figure_size)
294 | gs = GridSpec(2, num_candidates, height_ratios=[0.5, 1], wspace=0.1, hspace=0.1)
295 |
296 | # Plot original shape
297 | ax_orig = fig.add_subplot(gs[0, :], projection="3d")
298 | plot_cubes(ax_orig, original_vertices, title="Original Shape",
299 | facecolor=facecolor, hide_3d_elements=(difficulty == "complex"))
300 |
301 | # Plot candidates
302 | for i in range(num_candidates):
303 | ax = fig.add_subplot(gs[1, i], projection="3d")
304 | _, candidate_vertices = candidates[i]
305 | plot_cubes(ax, candidate_vertices, title=f"Option {chr(65 + i)}",
306 | facecolor=facecolor, hide_3d_elements=(difficulty == "complex"))
307 |
308 | # Save image
309 | filename = f"{shape_name}_{index}.png"
310 | output_path = os.path.join(outdir, filename)
311 | plt.savefig(output_path, dpi=60, bbox_inches="tight", pad_inches=0)
312 | plt.close(fig)
313 |
314 | # Save metadata
315 | metadata = {
316 | "filename": filename,
317 | "difficulty": difficulty,
318 | "shape": shape_name,
319 | "candidate_order": [tag for tag, _ in candidates],
320 | "answer": chr(65 + correct_candidate_index),
321 | }
322 | with open(os.path.join(outdir, "metadata.jsonl"), "a") as f:
323 | f.write(json.dumps(metadata) + "\n")
324 |
325 |
326 | def main():
327 | parser = argparse.ArgumentParser(
328 | description="Generate mental rotation test images with variable difficulty."
329 | )
330 | parser.add_argument(
331 | "--difficulty", "-d", type=str, choices=["easy", "complex"], default="easy",
332 | help="Difficulty level: 'easy' (3 candidates, simple rotations) or 'complex' (4 candidates, complex transformations)"
333 | )
334 | parser.add_argument(
335 | "--num_images", "-n", type=int, default=1,
336 | help="Number of images to generate"
337 | )
338 | parser.add_argument(
339 | "--color", "-c", type=str, default="white",
340 | help="Color for the polycubes"
341 | )
342 | parser.add_argument(
343 | "--seed", "-s", type=int, default=69,
344 | help="Seed for reproducible results"
345 | )
346 | parser.add_argument(
347 | "--outdir", "-o", type=str, default=None,
348 | help="Output directory (defaults to data/mrt/{difficulty})"
349 | )
350 |
351 | args = parser.parse_args()
352 |
353 | # Set default output directory based on difficulty
354 | if args.outdir is None:
355 | args.outdir = f"data/mrt/{args.difficulty}"
356 |
357 | os.makedirs(args.outdir, exist_ok=True)
358 |
359 | # Set seed for reproducibility
360 | if args.seed is not None:
361 | random.seed(args.seed)
362 | np.random.seed(args.seed)
363 |
364 | # Generate images
365 | for i in range(args.num_images):
366 | generate_one_image(i, difficulty=args.difficulty,
367 | facecolor=args.color, outdir=args.outdir)
368 |
369 | print(f"Generated {args.num_images} {args.difficulty} MRT images in {args.outdir}")
370 |
371 |
372 | if __name__ == "__main__":
373 | main()
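    |     # Example invocation (illustrative values):
    |     #   python src/data/mrt.py --difficulty complex --num_images 100 --color lightblue --seed 7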
374 |
--------------------------------------------------------------------------------
/src/data/create_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import logging
5 | import random
6 | import numpy as np
7 | from diffusers import FluxPipeline, StableDiffusion3Pipeline, DiffusionPipeline
8 | import gc
9 | from tqdm import tqdm
10 |
11 | """
12 | Image generation module for creating synthetic images using diffusion models.
13 |
14 | This module provides a wrapper for diffusion models and utilities for batch processing
15 | image generation tasks. It supports Stable Diffusion 3.5 and FLUX.1-dev models with
16 | optimized memory management and error handling.
17 | """
18 |
19 | # Configure logging
20 | LOG_DIR = "logs"
21 | os.makedirs(LOG_DIR, exist_ok=True)
22 | LOG_FILE = os.path.join(LOG_DIR, "image_creation.log")
23 | logging.basicConfig(
24 | level=logging.INFO,
25 | format="%(asctime)s - %(levelname)s - %(message)s",
26 | handlers=[logging.FileHandler(LOG_FILE)],
27 | )
28 |
29 | def set_seed(seed: int = 42):
30 | """
31 | Set random seeds for reproducibility across all libraries.
32 |
33 | Args:
34 | seed (int): The random seed to use. Defaults to 42.
35 | """
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 | # Make CuDNN deterministic
42 | torch.backends.cudnn.deterministic = True
43 | torch.backends.cudnn.benchmark = False
44 | logging.info(f"Random seed set to {seed} for reproducibility")
45 |
46 |
47 | class DiffusionPipelineWrapper:
48 | """
49 | A wrapper class for diffusion model pipelines with optimized memory management.
50 |
51 | This class provides a unified interface for different diffusion models, handles
52 | lazy loading/unloading of models to optimize GPU memory usage, and provides
53 | consistent image generation capabilities.
54 |
55 | Attributes:
56 | model_id (str): Identifier for the diffusion model to use
57 | steps (int): Number of inference steps for image generation
58 | scale (float): Guidance scale for controlling adherence to prompt
59 | pipeline: The loaded diffusion pipeline (None when unloaded)
60 |
61 | Supported Models:
62 | - stabilityai/stable-diffusion-3.5-large
63 | - black-forest-labs/FLUX.1-dev
64 | """
65 |
66 | def __init__(self, model_id: str, steps: int = 40, scale: float = 4.5):
67 | """
68 | Initialize the diffusion pipeline wrapper.
69 |
70 | Args:
71 | model_id (str): The model identifier for the diffusion model
72 | steps (int, optional): Number of inference steps. Defaults to 40.
73 | scale (float, optional): Guidance scale value. Defaults to 4.5.
74 |
75 | Raises:
76 | ValueError: If an unsupported model_id is provided
77 | """
78 | self.model_id = model_id
79 | self.steps = steps
80 | self.scale = scale
81 | self.pipeline = None
82 | logging.info(f"Initializing wrapper for {model_id}")
83 |
84 | def _create_pipeline(self, model_id: str):
85 | """
86 | Create and configure the appropriate diffusion pipeline based on model ID.
87 |
88 | Args:
89 | model_id (str): The model identifier
90 |
91 | Returns:
92 | Union[StableDiffusion3Pipeline, FluxPipeline]: Configured pipeline instance
93 |
94 | Raises:
95 | ValueError: If the model_id is not supported
96 | """
97 | if model_id == "stabilityai/stable-diffusion-3.5-large":
98 | pipe = StableDiffusion3Pipeline.from_pretrained(
99 | "bin/models/diffusers/stable-diffusion-3.5-large",
100 | device_map="balanced",
101 | )
102 | elif model_id == "black-forest-labs/FLUX.1-dev":
103 | pipe = FluxPipeline.from_pretrained(
104 | "bin/models/diffusers/FLUX.1-dev",
105 | torch_dtype=torch.bfloat16,
106 | device_map="balanced",
107 | )
108 | else:
109 | raise ValueError(
110 | "Unsupported model id. Use 'stabilityai/stable-diffusion-3.5-large' or 'black-forest-labs/FLUX.1-dev'."
111 | )
112 | return pipe
113 |
114 | def load_pipeline(self):
115 | """
116 | Lazy load the diffusion pipeline to optimize memory usage.
117 |
118 | This method implements lazy loading, only creating the pipeline when needed.
119 | Subsequent calls return the already loaded pipeline without recreating it.
120 |
121 | Returns:
122 | Union[StableDiffusion3Pipeline, FluxPipeline]: The loaded pipeline
123 | """
124 | if self.pipeline is None:
125 | self.pipeline = self._create_pipeline(self.model_id)
126 | return self.pipeline
127 |
128 | def unload_pipeline(self):
129 | """
130 | Unload the pipeline and free GPU memory.
131 |
132 | This method properly deallocates the pipeline from memory, clears CUDA cache,
133 | and forces garbage collection to maximize memory availability for other models.
134 | Essential for processing multiple models sequentially on limited GPU memory.
135 | """
136 | if self.pipeline is not None:
137 | del self.pipeline
138 | self.pipeline = None
139 | torch.cuda.empty_cache()
140 | gc.collect()
141 | logging.info(f"Pipeline for {self.model_id} unloaded from memory")
142 |
143 | def _call_pipeline(self, prompt: str):
144 | """
145 | Execute the pipeline with appropriate parameters based on model type.
146 |
147 | Args:
148 | prompt (str): The text prompt for image generation
149 |
150 | Returns:
151 | GenerationOutput: The pipeline output containing generated images
152 |
153 | Raises:
154 | ValueError: If the pipeline type is not supported
155 |
156 | Note:
157 | Clips the prompt to 300 characters for CLIP compatibility and uses
158 | the full prompt for secondary prompt parameters when available.
159 | """
160 | clip_prompt = prompt[:300]
161 | if isinstance(self.pipeline, StableDiffusion3Pipeline):
162 | return self.pipeline(
163 | prompt=clip_prompt,
164 | prompt_3=prompt,
165 | negative_prompt="bad anatomy, poorly drawn face, low resolution, blurry, artifacts, bad lighting, bad composition, cartoonish",
166 | num_inference_steps=self.steps,
167 | guidance_scale=self.scale,
168 | )
169 | elif isinstance(self.pipeline, FluxPipeline):
170 | return self.pipeline(
171 | prompt=clip_prompt,
172 | prompt_2=prompt,
173 | negative_prompt="bad anatomy, poorly drawn face, low resolution, blurry, artifacts, bad lighting, bad composition, cartoonish",
174 | num_inference_steps=self.steps,
175 | guidance_scale=self.scale,
176 | )
177 | else:
178 | raise ValueError("Unsupported pipeline type.")
179 |
180 | def generate_image(self, prompt: str, output_filename: str, output_dir: str) -> str:
181 | """
182 | Generate an image from a text prompt and save it to disk.
183 |
184 | Args:
185 | prompt (str): Text description for image generation
186 | output_filename (str): Filename for the saved image
187 | output_dir (str): Directory path where the image will be saved
188 |
189 | Returns:
190 | str: Full path to the saved image file
191 |
192 | Raises:
193 | Exception: Re-raises any exception that occurs during generation or saving
194 |
195 | Note:
196 | Creates the output directory if it doesn't exist and uses torch.no_grad()
197 | context for memory efficiency during inference.
198 | """
199 | try:
200 | self.load_pipeline()
201 | with torch.no_grad():
202 | result = self._call_pipeline(prompt)
203 | image = result.images[0]
204 | os.makedirs(output_dir, exist_ok=True)
205 | image_path = os.path.join(output_dir, output_filename)
206 | image.save(image_path)
207 | logging.info(f"Image saved to {image_path}")
208 | return image_path
209 | except Exception as e:
210 | logging.error(f"Error generating image {output_filename}: {e}")
211 | raise
212 |
213 | def __call__(self, prompt: str, output_filename: str, output_dir: str) -> str:
214 | """
215 | Make the wrapper callable, delegating to generate_image method.
216 |
217 | Args:
218 | prompt (str): Text description for image generation
219 | output_filename (str): Filename for the saved image
220 | output_dir (str): Directory path where the image will be saved
221 |
222 | Returns:
223 | str: Full path to the saved image file
224 | """
225 | return self.generate_image(prompt, output_filename, output_dir)
226 |
227 | @staticmethod
228 | def load_metadata_from_json(json_path: str):
229 | """
230 | Load metadata from a JSONL file containing prompt information.
231 |
232 | Args:
233 | json_path (str): Path to the JSONL file containing metadata
234 |
235 | Returns:
236 | list: List of dictionaries containing prompt metadata
237 |
238 | Note:
239 | Expects JSONL format where each line is a valid JSON object.
240 | Empty lines are automatically skipped.
241 | """
242 | with open(json_path, "r") as f:
243 | data = [json.loads(line.strip()) for line in f if line.strip()]
244 | return data
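    |         # Each record is expected to carry a "generated_scene_description" field, matching
    |         # the JSONL written by src/data/create_prompts.py, e.g. (illustrative values):
    |         #   {"model": "...", "task_type": "mental rotation", "generated_scene_description": "..."}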
245 |
246 |
247 | def process_model_batch(model_config, metadata_list, base_output_dir, output_metadata):
248 | """
249 | Process all prompts for a single diffusion model in batch.
250 |
251 | This function handles the complete workflow for generating images with one model:
252 | initializing the wrapper, processing all prompts with progress tracking,
253 | collecting metadata, and proper cleanup.
254 |
255 | Args:
256 | model_config (dict): Configuration dictionary containing:
257 | - model_id (str): The model identifier
258 | - steps (int, optional): Number of inference steps
259 | - scale (float, optional): Guidance scale
260 | metadata_list (list): List of prompt metadata dictionaries
261 | base_output_dir (str): Base directory for saving images
262 | output_metadata (list): List to append generation metadata to
263 |
264 | Note:
265 | Automatically handles errors for individual prompts and continues processing.
266 | Ensures proper cleanup of GPU memory regardless of success or failure.
267 | """
268 | model_id = model_config["model_id"]
269 | steps = model_config.get("steps", 40)
270 | scale = model_config.get("scale", 4.5)
271 | safe_model_id = model_id.replace("/", "_")
272 | output_dir = os.path.join(base_output_dir, safe_model_id)
273 |
274 | logging.info(f"Processing {len(metadata_list)} prompts for model '{model_id}'")
275 |
276 | wrapper = None
277 | try:
278 | wrapper = DiffusionPipelineWrapper(model_id, steps, scale)
279 |
280 | # Process prompts with progress bar
281 | for idx, item in enumerate(tqdm(metadata_list, desc=f"Generating with {safe_model_id}")):
282 | prompt = item.get("generated_scene_description")
283 | if not prompt:
284 | logging.warning(f"Skipping prompt index {idx} due to missing description.")
285 | continue
286 |
287 | output_filename = f"image_{idx:03d}.png"
288 | try:
289 | image_path = wrapper.generate_image(prompt, output_filename, output_dir)
290 | record = {
291 | "model_id": model_id,
292 | "prompt": prompt,
293 | "image_path": image_path,
294 | "steps": steps,
295 | "scale": scale,
296 | "prompt_index": idx
297 | }
298 | output_metadata.append(record)
299 |
300 | except Exception as e:
301 | logging.error(f"Error generating image for prompt index {idx}: {e}")
302 | continue
303 |
304 | except Exception as e:
305 | logging.error(f"Failed to initialize wrapper for model {model_id}: {e}")
306 | return
307 | finally:
308 | if wrapper:
309 | wrapper.unload_pipeline()
310 |
311 |
312 | def save_metadata_batch(output_metadata, output_file):
313 | """
314 | Save generation metadata to a JSONL file in batch.
315 |
316 | Args:
317 | output_metadata (list): List of metadata dictionaries to save
318 | output_file (str): Path to the output JSONL file
319 |
320 | Note:
321 | Overwrites the output file if it exists. Each metadata record is written
322 | as a separate JSON line for easy parsing and streaming.
323 | """
324 | with open(output_file, "w") as f:
325 | for record in output_metadata:
326 | f.write(json.dumps(record) + "\n")
327 | logging.info(f"Saved {len(output_metadata)} records to '{output_file}'")
328 |
329 |
330 | def main():
331 | """
332 | Main execution function for the image generation pipeline.
333 |
334 | This function orchestrates the complete image generation workflow:
335 | 1. Loads prompt metadata from JSONL file
336 | 2. Processes each configured diffusion model sequentially
337 | 3. Generates images for all prompts with each model
338 | 4. Saves comprehensive metadata about the generation process
339 | 5. Handles memory optimization between models
340 |
341 | The function is designed to be memory-efficient for multi-GPU setups and
342 | provides comprehensive logging for monitoring long-running generation tasks.
343 |
344 | Configuration:
345 | - Supports multiple diffusion models processed sequentially
346 | - Automatic memory cleanup between models
347 | - Progress tracking with tqdm
348 | - Comprehensive error handling and logging
349 |
350 |     Note:
351 |         Unexpected errors are caught and logged via logging.exception rather than re-raised.
352 | """
353 |
354 | set_seed(42) # Set a fixed seed for reproducibility
355 |
356 | logging.info("\n\nStarting image generation process using DiffusionPipelineWrapper.")
357 | try:
358 | diffusion_models = [
359 | {"model_id": "black-forest-labs/FLUX.1-dev", "steps": 50, "scale": 7.5},
360 | {"model_id": "stabilityai/stable-diffusion-3.5-large", "steps": 50, "scale": 7.5},
361 | ]
362 | json_file = "output/prompts/claude3.7-prompt.jsonl"
363 | base_output_dir = "output/images_" + json_file.split("/")[-1].split("-")[0]
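    |         # For "output/prompts/claude3.7-prompt.jsonl" this resolves to "output/images_claude3.7".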
364 |
365 | metadata_list = DiffusionPipelineWrapper.load_metadata_from_json(json_file)
366 | if not metadata_list:
367 | logging.error("No metadata found in the JSON file.")
368 | return
369 |
370 | output_metadata = []
371 |
372 | # Process each model sequentially to optimize memory usage
373 | for model_config in diffusion_models:
374 | process_model_batch(model_config, metadata_list, base_output_dir, output_metadata)
375 |
376 | # Force garbage collection between models
377 | torch.cuda.empty_cache()
378 | gc.collect()
379 |
380 | # Save all metadata at once
381 | output_metadata_file = f"metadata_{json_file.split('/')[-1].split('-')[0]}.jsonl"
382 | save_metadata_batch(output_metadata, output_metadata_file)
383 |
384 | except Exception as err:
385 | logging.exception(f"An unexpected error occurred: {err}")
386 |
387 |
388 | if __name__ == "__main__":
389 | main()
390 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import gc
4 | import json
5 | import random
6 | import hashlib
7 | import pandas as pd
8 | import logging
9 | from typing import List, Dict, Optional
10 | from tqdm import tqdm
11 | from datasets import load_dataset, Dataset
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from pathlib import Path
15 | from datetime import datetime
16 | import numpy as np
17 | from PIL import Image
18 |
19 | from utils.vlm_wrapper import VLMWrapper
20 |
21 | # Configure logging
22 | logging.basicConfig(
23 | level=logging.INFO,
24 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
25 | datefmt="%Y-%m-%d %H:%M:%S",
26 | handlers=[
27 | logging.StreamHandler()
28 | ]
29 | )
30 | logger = logging.getLogger(__name__)
31 |
32 | class EvaluationEngine:
33 | """Evaluation engine for all VLM types with enhanced features."""
34 |
35 | def __init__(self, model_id: str, device_map: str = "auto", seed: int = 42,
36 | use_cot: bool = False, one_shot_example: Optional[Dict] = None):
37 | """
38 | Initialize the evaluation engine.
39 |
40 | Args:
41 | model_id: HuggingFace model identifier
42 | device_map: Device mapping strategy
43 | seed: Random seed for reproducibility
44 | use_cot: Enable Chain-of-Thought prompting
45 | one_shot_example: One-shot example dictionary
46 | """
47 | self.model_id = model_id
48 | self.short_name = self._extract_model_name(model_id)
49 | self.device_map = device_map
50 | self.seed = seed
51 | self.use_cot = use_cot
52 | self.one_shot_example = one_shot_example
53 |
54 | # Initialize reproducibility
55 | self._set_seed(seed)
56 |
57 | # Lazy initialization
58 | self.vlm = None
59 |
60 | # Evaluation metadata
61 | self.eval_metadata = {
62 | "model_id": model_id,
63 | "seed": seed,
64 | "use_cot": use_cot,
65 | "has_one_shot": one_shot_example is not None,
66 | "timestamp": datetime.now().isoformat(),
67 | "torch_version": torch.__version__,
68 | }
69 |
70 | logger.info(f"Initialized EvaluationEngine for model: {model_id}")
71 | logger.info(f"Seed: {seed}, CoT: {use_cot}, One-shot: {one_shot_example is not None}")
72 |
73 | def _set_seed(self, seed: int):
74 | """Set all random seeds for reproducibility."""
75 | random.seed(seed)
76 | np.random.seed(seed)
77 | torch.manual_seed(seed)
78 | if torch.cuda.is_available():
79 | torch.cuda.manual_seed(seed)
80 | torch.cuda.manual_seed_all(seed)
81 | # Make cudnn deterministic
82 | torch.backends.cudnn.deterministic = True
83 | torch.backends.cudnn.benchmark = False
84 |
85 | # Set environment variables for reproducibility
86 | os.environ['PYTHONHASHSEED'] = str(seed)
87 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8' # For deterministic CUDA operations
88 |
89 | logger.info(f"Set all seeds to {seed} for reproducibility")
90 |
91 | def _extract_model_name(self, model_id: str) -> str:
92 | """Extract a clean model name for file naming."""
93 | return model_id.split("/")[-1].replace("-", "_")
94 |
95 | def _compute_dataset_hash(self, dataset_id: str, max_samples: Optional[int] = None) -> str:
96 | """Compute a hash of the dataset configuration for reproducibility tracking."""
97 | config_str = f"{dataset_id}_{max_samples}_{self.seed}"
98 | return hashlib.md5(config_str.encode()).hexdigest()[:8]
99 |
100 | def _load_model(self):
101 | """Load the VLM model lazily to save memory."""
102 | if self.vlm is None:
103 | logger.info(f"Loading model: {self.model_id}")
104 | self.vlm = VLMWrapper(self.model_id, self.device_map, dtype=torch.bfloat16)
105 | logger.info("Model loaded successfully")
106 |
107 | # Add model-specific metadata
108 | self.eval_metadata.update({
109 | "model_type": self.vlm.model_type,
110 | "inference_type": self.vlm.config.inference_type,
111 | "device_map": self.device_map,
112 | "dtype": str(self.vlm.dtype),
113 | })
114 |
115 | def _prepare_messages(self, questions: List[str], images: List[Image.Image]) -> List[List[Dict]]:
116 | """Prepare messages in the required format with optional CoT and one-shot examples."""
117 | messages = []
118 |
119 | for question, image in zip(questions, images):
120 | conversation = []
121 |
122 | # Add system message with model and strategy info
123 | system_message = {
124 | "role": "system",
125 | "content": [{
126 | "type": "text",
127 | "text": ("You are a spatial reasoning AI assistant specialized in analyzing, "
128 | "understanding, and solving problems involving spatial relationships, "
129 | "geometric transformations, and visual-spatial concepts."
130 | )
131 | }]
132 | }
133 | conversation.append(system_message)
134 |
135 | # Add one-shot example if provided
136 | if self.one_shot_example:
137 | # Add the example question with its image
138 | conversation.append({
139 | "role": "user",
140 | "content": [
141 | {"type": "image", "image": self.one_shot_example["image"]},
142 | {"type": "text", "text": self._format_question_with_cot(self.one_shot_example["question"])},
143 | ],
144 | })
145 |
146 | # Add the example response (with CoT reasoning if available)
147 | example_response = self.one_shot_example.get("reasoning", "")
148 | if example_response and self.use_cot:
149 | example_response += f"\n\nTherefore, the answer is: {self.one_shot_example['answer']}"
150 | else:
151 | example_response = self.one_shot_example["answer"]
152 |
153 | conversation.append({
154 | "role": "assistant",
155 | "content": example_response
156 | })
157 |
158 | # Add the actual question
159 | conversation.append({
160 | "role": "user",
161 | "content": [
162 | {"type": "image", "image": image},
163 | {"type": "text", "text": self._format_question_with_cot(question)},
164 | ],
165 | })
166 |
167 | messages.append(conversation)
168 |
169 | return messages
170 |
171 | def _format_question_with_cot(self, question: str) -> str:
172 | """Format question with Chain-of-Thought prompting if enabled."""
173 | if not self.use_cot:
174 | return question.strip()
175 |
176 | cot_prompt = (
177 |             "Please think step by step and explain your reasoning before answering the question.\n"
178 |             "Provide the final answer at the end of your reasoning in curly brackets, e.g., {A} or {yes}.\n\n"
179 | f"Question: {question.strip()}\n\n"
180 | )
181 |
182 | return cot_prompt
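    |         # With use_cot enabled, a compliant response is expected to end with the answer
    |         # wrapped in curly braces, e.g. "... Therefore, the final answer is {B}."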
183 |
184 | def _process_batch(self, batch: Dict[str, List], batch_idx: int, total_batches: int) -> List[Dict[str, str]]:
185 | """
186 | Process a single batch of examples with unified interface.
187 |
188 | Args:
189 | batch: Dictionary containing batch data
190 | batch_idx: Current batch index
191 | total_batches: Total number of batches
192 |
193 | Returns:
194 | List of results for this batch
195 | """
196 | try:
197 | batch_size = len(batch["question"])
198 | logger.debug(f"Processing batch {batch_idx + 1}/{total_batches} with {batch_size} examples")
199 |
200 | # Prepare inputs
201 | messages = self._prepare_messages(batch["question"], batch["image"])
202 |
203 | # Unified batch processing leveraging unified preprocessing in VLMWrapper
204 | results = self._process_batch_standard(messages, batch, batch_idx)
205 |
206 | return results
207 |
208 | except Exception as e:
209 | logger.error(f"Error processing batch {batch_idx + 1}: {e}", exc_info=True)
210 | # Return empty results for failed batch
211 | return [{
212 | "response": f"ERROR: {str(e)}",
213 |                 "gold answer": answer,
214 | "question": question,
215 | "raw_response": f"ERROR: {str(e)}",
216 | "batch_idx": batch_idx,
217 | "example_idx": idx,
218 | } for idx, (question, answer) in enumerate(zip(batch["question"], batch["answer"]))]
219 |
220 | def _process_batch_standard(self, messages: List[List[Dict]], batch: Dict[str, List], batch_idx: int) -> List[Dict[str, str]]:
221 | """Process batch for standard VLM models."""
222 | results = []
223 |
224 | with self.vlm.memory_efficient_mode():
225 | # Preprocess the entire batch
226 | inputs = self.vlm.preprocess(conversation=messages, image_input=batch["image"])
227 |
228 |
229 | generated_ids = self.vlm.generate(
230 | inputs,
231 | max_new_tokens=1024,
232 | do_sample=False
233 | )
234 |
235 | output_texts = self.vlm.decode(generated_ids)
236 |
237 | del inputs, generated_ids # Free memory
238 |
239 | # Process results for each example in the batch
240 | for idx, (raw_prediction, question, ground_truth) in enumerate(zip(output_texts, batch["question"], batch["answer"])):
241 | results.append({
242 | "question": question,
243 | "response": raw_prediction,
244 | "gold answer": ground_truth,
245 | "batch_idx": batch_idx,
246 | "example_idx": idx,
247 | })
248 |
249 | return results
250 |
251 | def evaluate(self, dataset_id: str, batch_size: int = 16, max_samples: Optional[int] = None,
252 | sample_strategy: str = "first", num_workers: int = 4) -> str:
253 | """
254 | Evaluate the model on the specified dataset with reproducible sampling.
255 |
256 | Args:
257 | dataset_id: HuggingFace dataset identifier
258 | batch_size: Number of examples per batch
259 | max_samples: Maximum number of samples to process
260 | sample_strategy: Sampling strategy ("first", "random", "stratified")
261 | num_workers: Number of parallel data loading workers
262 |
263 | Returns:
264 | Path to the saved results CSV file
265 | """
266 | # Load model lazily
267 | self._load_model()
268 |
269 | # Load and sample dataset
270 | try:
271 | data = load_dataset(dataset_id, split="train")
272 | original_size = len(data)
273 |
274 | # Apply sampling strategy
275 | if max_samples and max_samples < original_size:
276 | data = self._sample_dataset(data, max_samples, sample_strategy)
277 |
278 | logger.info(f"Loaded dataset: {original_size} total, {len(data)} selected samples")
279 |
280 | # Add dataset hash to metadata
281 | self.eval_metadata.update({
282 | "dataset_id": dataset_id,
283 | "dataset_hash": self._compute_dataset_hash(dataset_id, max_samples),
284 | "original_size": original_size,
285 | "sampled_size": len(data),
286 | "sample_strategy": sample_strategy,
287 | "max_samples": max_samples,
288 | })
289 |
290 | except Exception as e:
291 | logger.error(f"Failed to load dataset {dataset_id}: {e}")
292 | raise
293 |
294 | # Validate dataset format
295 | required_columns = ["question", "answer", "image"]
296 | missing_columns = [col for col in required_columns if col not in data.column_names]
297 | if missing_columns:
298 | raise ValueError(f"Dataset missing required columns: {missing_columns}")
299 |
300 | # Convert to iterable format for efficient batching
301 | # Use HuggingFace's built-in batching which is more memory efficient
302 | data = data.with_format("python") # Ensure consistent format
303 |
304 | # Process dataset in batches using DataLoader for prefetching
305 | all_results = []
306 | total_batches = (len(data) + batch_size - 1) // batch_size
307 |
308 | # Create DataLoader with prefetching for parallel data loading
309 | # Note: prefetch_factor only valid when num_workers > 0
310 | dataloader_kwargs = {
311 | "batch_size": batch_size,
312 | "shuffle": False,
313 | "num_workers": num_workers,
314 | "pin_memory": torch.cuda.is_available(),
315 | "collate_fn": self._collate_batch,
316 | }
317 | if num_workers > 0:
318 | dataloader_kwargs["prefetch_factor"] = 2
319 |
320 | dataloader = DataLoader(data, **dataloader_kwargs)
321 | with tqdm(total=total_batches, desc="Processing batches", colour="green") as pbar:
322 | for batch_idx, batch in enumerate(dataloader):
323 | batch_results = self._process_batch(batch, batch_idx, total_batches)
324 | all_results.extend(batch_results)
325 |
326 | pbar.update(1)
327 | pbar.set_postfix({
328 | "processed": len(all_results),
329 | "memory": f"{torch.cuda.memory_allocated() / 1e9:.1f}GB" if torch.cuda.is_available() else "N/A",
330 | "success_rate": f"{sum(1 for r in all_results if not r['response'].startswith('ERROR')) / len(all_results) * 100:.1f}%"
331 | })
332 |
333 | gc.collect()
334 | if torch.cuda.is_available():
335 | torch.cuda.empty_cache()
336 |
337 | # Save results with metadata
338 | output_path = self._save_results_with_metadata(all_results, dataset_id)
339 | logger.info(f"Evaluation completed. Results saved to: {output_path}")
340 |
341 | return output_path
342 |
343 | def _collate_batch(self, batch: List[Dict]) -> Dict[str, List]:
344 | """
345 | Custom collate function to convert list of dicts to dict of lists.
346 | Handles PIL images properly without tensor conversion.
347 | """
348 | collated = {
349 | "question": [item["question"] for item in batch],
350 | "answer": [item["answer"] for item in batch],
351 | "image": [item["image"] for item in batch],
352 | }
353 | return collated
354 |
355 | def _sample_dataset(self, dataset: Dataset, max_samples: int, strategy: str) -> Dataset:
356 | """Sample dataset using specified strategy for reproducibility."""
357 | if strategy == "first":
358 | return dataset.select(range(max_samples))
359 | elif strategy == "random":
360 | indices = list(range(len(dataset)))
361 | random.shuffle(indices) # Uses the set seed
362 | selected_indices = indices[:max_samples]
363 | return dataset.select(selected_indices)
364 | elif strategy == "stratified":
365 | # Try to stratify by answer if possible, otherwise fall back to random
366 | try:
367 | answers = dataset["answer"]
368 | unique_answers = list(set(answers))
369 | samples_per_answer = max_samples // len(unique_answers)
370 |
371 | selected_indices = []
372 | for answer in unique_answers:
373 | answer_indices = [i for i, a in enumerate(answers) if a == answer]
374 | random.shuffle(answer_indices)
375 | selected_indices.extend(answer_indices[:samples_per_answer])
376 |
377 | # Fill remaining slots randomly
378 | remaining = max_samples - len(selected_indices)
379 | if remaining > 0:
380 | all_indices = set(range(len(dataset)))
381 | available_indices = list(all_indices - set(selected_indices))
382 | random.shuffle(available_indices)
383 | selected_indices.extend(available_indices[:remaining])
384 |
385 | return dataset.select(selected_indices[:max_samples])
386 | except Exception as e:
387 | logger.warning(f"Stratified sampling failed, falling back to random: {e}")
388 | return self._sample_dataset(dataset, max_samples, "random")
389 | else:
390 | raise ValueError(f"Unknown sampling strategy: {strategy}")
391 |
392 | def _save_results_with_metadata(self, results: List[Dict[str, str]], dataset_id: str) -> str:
393 | """Save results with comprehensive metadata for reproducibility."""
394 | dataset_name = dataset_id.split("/")[-1]
395 |
396 | # Create directory structure with prompting strategy
397 | strategy_dir = self._get_strategy_directory_name()
398 | output_dir = Path("output/evaluations") / dataset_name / strategy_dir
399 | output_dir.mkdir(parents=True, exist_ok=True)
400 |
401 | # Save main results
402 | results_df = pd.DataFrame(results)
403 | output_path = output_dir / f"{self.short_name}.csv"
404 |
405 | try:
406 | results_df.to_csv(output_path, index=False)
407 | logger.info(f"Successfully saved {len(results)} results to {output_path}")
408 | except Exception as e:
409 | logger.error(f"Failed to save results: {e}")
410 | raise
411 |
412 | # Save detailed metadata for reproducibility
413 | metadata_path = output_dir / f"{self.short_name}_metadata.json"
414 |
415 | # Add evaluation statistics to metadata
416 | successful_responses = sum(1 for r in results if not r['response'].startswith('ERROR'))
417 | self.eval_metadata.update({
418 | "total_samples": len(results),
419 | "successful_responses": successful_responses,
420 | "success_rate": successful_responses / len(results) * 100,
421 | "average_response_length": sum(len(r['response']) for r in results) / len(results),
422 | "output_path": str(output_path),
423 | "evaluation_completed_at": datetime.now().isoformat(),
424 | })
425 |
426 | with open(metadata_path, 'w') as f:
427 | json.dump(self.eval_metadata, f, indent=2, default=str)
428 |
429 | logger.info(f"Saved metadata to {metadata_path}")
430 |
431 | return str(output_path)
432 |
433 | def _get_strategy_directory_name(self) -> str:
434 | """Generate directory name based on prompting strategy."""
435 | if self.use_cot and self.one_shot_example:
436 | return "cot_oneshot"
437 | elif self.use_cot:
438 | return "cot_test"
439 | elif self.one_shot_example:
440 | return "oneshot"
441 | else:
442 | return "_normal"
443 |
444 |
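# For orientation, _save_results_with_metadata() above writes two files per run (the
# placeholder names below are derived from the code, not literal paths):
#
#   output/evaluations/<dataset_name>/<strategy_dir>/<short_name>.csv
#   output/evaluations/<dataset_name>/<strategy_dir>/<short_name>_metadata.json
#
# where <strategy_dir> is one of "cot_oneshot", "cot_test", "oneshot", or "_normal".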
445 | def load_one_shot_example(json_path: str) -> Optional[Dict]:
446 | """Load one-shot example from JSON file with image loading."""
447 | if not json_path:
448 | return None
449 |
450 | # Handle relative paths for JSON file
451 | if not os.path.isabs(json_path):
452 | json_path = os.path.join('/users/stogian/srbench', json_path)
453 |
454 | if not os.path.exists(json_path):
455 | logger.warning(f"One-shot JSON file not found: {json_path}")
456 | return None
457 |
458 | try:
459 | with open(json_path, 'r') as f:
460 | data = json.load(f)
461 |
462 | # Handle image path with multiple fallbacks
463 | image_path = data.get('image_path', '')
464 | if not image_path:
465 | logger.warning("No image_path found in one-shot example")
466 | return None
467 |
468 | # Try multiple path variations
469 | possible_paths = [
470 | image_path,
471 | os.path.join(os.path.dirname(json_path), os.path.basename(image_path)),
472 | os.path.join('/users/stogian/srbench', image_path.lstrip('./')),
473 | os.path.join('/users/stogian/srbench/example', os.path.basename(image_path)),
474 | ]
475 |
476 | image_loaded = False
477 | for path in possible_paths:
478 | if os.path.exists(path):
479 | try:
480 | data['image'] = Image.open(path).convert('RGB')
481 | logger.info(f"Successfully loaded one-shot image from: {path}")
482 | image_loaded = True
483 | break
484 | except Exception as e:
485 | logger.warning(f"Failed to load image from {path}: {e}")
486 | continue
487 |
488 | if not image_loaded:
489 | logger.error(f"Could not load image from any of these paths: {possible_paths}")
490 | return None
491 |
492 | # Validate required keys
493 | missing_keys = [key for key in ('question', 'answer') if key not in data]
494 | if missing_keys:
495 | logger.warning(f"One-shot example is missing required keys: {missing_keys}")
496 | return None
497 |
498 | logger.info(f"Loaded one-shot example: {data['question'][:50]}...")
499 | return data
500 |
501 | except Exception as e:
502 | logger.error(f"Failed to load one-shot example: {e}")
503 | return None
504 |
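# A minimal one-shot JSON file accepted by load_one_shot_example() looks like the sketch
# below; the field values are illustrative only, and any additional keys are kept as-is:
#
#   {
#       "question": "Which object is closer to the camera?",
#       "answer": "A",
#       "image_path": "example/one_shot.png"
#   }
#
# The image referenced by "image_path" is loaded into data["image"] as an RGB PIL image.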
505 |
506 | def parse_args():
507 | """Parse command-line arguments with comprehensive options."""
508 | parser = argparse.ArgumentParser(
509 | description="Reproducible evaluation of vision-language models",
510 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
511 | )
512 | parser.add_argument(
513 | "-m", "--model",
514 | type=str, required=True,
515 | help="Model identifier (e.g., 'meta-llama/Llama-3.2-90B-Vision-Instruct')"
516 | )
517 | parser.add_argument(
518 | "-d", "--dataset",
519 | type=str, required=True,
520 | help="Dataset identifier (e.g., 'stogian/sr_test')"
521 | )
522 | parser.add_argument(
523 | "-b", "--batch_size",
524 | type=int, default=16,
525 | help="Batch size for processing"
526 | )
527 | parser.add_argument(
528 | "--num_workers",
529 | type=int, default=4,
530 | help="Number of data loading workers (0 for main process only)"
531 | )
532 | parser.add_argument(
533 | "--max_samples",
534 | type=int, default=None,
535 | help="Maximum samples to process (for testing)"
536 | )
537 | parser.add_argument(
538 | "--sample_strategy",
539 | type=str, default="first", choices=["first", "random", "stratified"],
540 | help="Sampling strategy for subset selection"
541 | )
542 | parser.add_argument(
543 | "--device_map",
544 | type=str, default="auto",
545 | help="Device mapping strategy"
546 | )
547 | parser.add_argument(
548 | "--seed",
549 | type=int, default=42,
550 | help="Random seed for reproducibility"
551 | )
552 | parser.add_argument(
553 | "--cot", "--chain-of-thought",
554 | action="store_true",
555 | help="Enable Chain-of-Thought prompting"
556 | )
557 | parser.add_argument(
558 | "--one-shot",
559 | type=str, default=None,
560 | help="Path to one-shot example JSON file"
561 | )
562 | parser.add_argument(
563 | "-v", "--verbose",
564 | action="store_true",
565 | help="Enable verbose logging"
566 | )
567 |
568 | return parser.parse_args()
569 |
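# Example invocation (a sketch only: the script path assumes this module is src/eval.py,
# and the model/dataset identifiers simply reuse the examples from the help strings):
#
#   python src/eval.py \
#       -m meta-llama/Llama-3.2-90B-Vision-Instruct \
#       -d stogian/sr_test \
#       -b 16 --sample_strategy random --max_samples 200 --cot --seed 42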
570 |
571 | def main():
572 | """Main function with comprehensive error handling."""
573 |
574 | # Print GPU information before running the evaluation
575 | if torch.cuda.is_available():
576 | gpu_count = torch.cuda.device_count()
577 | print(f"Found {gpu_count} GPU(s):")
578 | for i in range(gpu_count):
579 | print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_memory / 1e9:.1f}GB)")
580 | print(f"CUDA version: {torch.version.cuda}")
581 | else:
582 | print("No GPU available, using CPU")
583 |
584 | args = parse_args()
585 |
586 | if args.verbose:
587 | logging.getLogger().setLevel(logging.DEBUG)
588 |
589 | # Load one-shot example if provided
590 | one_shot_example = load_one_shot_example(args.one_shot) if args.one_shot else None
591 |
592 | try:
593 | # Create reproducible evaluation engine
594 | eval_engine = EvaluationEngine(
595 | model_id=args.model,
596 | device_map=args.device_map,
597 | seed=args.seed,
598 | use_cot=args.cot,
599 | one_shot_example=one_shot_example
600 | )
601 |
602 | output_path = eval_engine.evaluate(
603 | dataset_id=args.dataset,
604 | batch_size=args.batch_size,
605 | max_samples=args.max_samples,
606 | sample_strategy=args.sample_strategy,
607 | num_workers=args.num_workers
608 | )
609 |
610 | print("\n✅ Reproducible evaluation completed successfully!")
611 | print(f"📁 Results saved to: {output_path}")
612 | print(f"🔄 Evaluation can be reproduced using seed: {args.seed}")
613 |
614 | except KeyboardInterrupt:
615 | logger.info("Evaluation interrupted by user")
616 | print("\n⚠️ Evaluation interrupted by user")
617 | except Exception as e:
618 | logger.error(f"Evaluation failed: {e}", exc_info=True)
619 | print(f"\n❌ Evaluation failed: {e}")
620 | raise
621 |
622 |
623 | if __name__ == "__main__":
624 | main()
--------------------------------------------------------------------------------
/src/utils/vlm_wrapper.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import requests
4 | import logging
5 | from io import BytesIO
6 | import numpy as np
7 | from typing import List, Dict, Any, Optional, Union, Tuple
8 | from dataclasses import dataclass
9 | from functools import lru_cache
10 | from contextlib import contextmanager
11 | from PIL import Image
12 | import torchvision.transforms as T
13 | from torchvision.transforms.functional import InterpolationMode
14 | from transformers import (
15 | Qwen2VLForConditionalGeneration,
16 | Qwen2_5_VLForConditionalGeneration,
17 | AutoProcessor,
18 | LlavaForConditionalGeneration,
19 | LlavaNextForConditionalGeneration,
20 | AutoModelForCausalLM,
21 | GenerationConfig,
22 | AutoModelForVision2Seq,
23 | AutoModelForImageTextToText,
24 | MllamaForConditionalGeneration,
25 | AutoModel,
26 | AutoTokenizer,
27 | Gemma3ForConditionalGeneration,
28 | Glm4vForConditionalGeneration
29 | )
30 | from qwen_vl_utils import process_vision_info
31 |
32 | # Configure logging
33 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
34 | logger = logging.getLogger(__name__)
35 |
36 | # Constants for InternVL
37 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
38 | IMAGENET_STD = (0.229, 0.224, 0.225)
39 |
40 | @dataclass
41 | class ModelConfig:
42 | """Configuration for different VLM models."""
43 | model_class: type
44 | processor_class: type
45 | requires_trust_remote_code: bool = False
46 | supports_flash_attention: bool = False
47 | padding_side: str = "left"
48 | special_args: Dict[str, Any] = None
49 | processor_args: Dict[str, Any] = None
50 | inference_type: str = "standard" # standard, internvl, minicpm
51 |
52 | def __post_init__(self):
53 | if self.special_args is None:
54 | self.special_args = {}
55 | if self.processor_args is None:
56 | self.processor_args = {}
57 |
58 | # Model configurations registry
59 | MODEL_CONFIGS = {
60 | "qwen": ModelConfig(
61 | model_class=Qwen2_5_VLForConditionalGeneration,
62 | processor_class=AutoProcessor,
63 | supports_flash_attention=True,
64 | processor_args={"use_fast": True},
65 | ),
66 | "kimi": ModelConfig(
67 | model_class=AutoModelForCausalLM,
68 | processor_class=AutoProcessor,
69 | requires_trust_remote_code=True,
70 | supports_flash_attention=True,
71 | processor_args={"use_fast": True, "padding_side": "left"},
72 | ),
73 | "llava": ModelConfig(
74 | model_class=LlavaForConditionalGeneration,
75 | processor_class=AutoProcessor,
76 | special_args={"low_cpu_mem_usage": True},
77 | processor_args={"use_fast": True}
78 | ),
79 | "llava_next": ModelConfig(
80 | model_class=LlavaNextForConditionalGeneration,
81 | processor_class=AutoProcessor,
82 | special_args={"low_cpu_mem_usage": True},
83 | processor_args={"use_fast": True}
84 | ),
85 | "idefics": ModelConfig(
86 | model_class=AutoModelForImageTextToText,
87 | processor_class=AutoProcessor,
88 | supports_flash_attention=True,
89 | processor_args={"use_fast": True}
90 | ),
91 | "smolvlm": ModelConfig(
92 | model_class=AutoModelForImageTextToText,
93 | processor_class=AutoProcessor,
94 | supports_flash_attention=True,
95 | special_args={},
96 | processor_args={"use_fast": True, "padding_side": "left"}
97 | ),
98 | "mllama": ModelConfig(
99 | model_class=MllamaForConditionalGeneration,
100 | processor_class=AutoProcessor,
101 | processor_args={"use_fast": True, "padding_side": "left"}
102 | ),
103 | "minicpm": ModelConfig(
104 | model_class=AutoModel,
105 | processor_class=AutoTokenizer,
106 | requires_trust_remote_code=True,
107 | supports_flash_attention=True,
108 | inference_type="minicpm"
109 | ),
110 | "internvl": ModelConfig(
111 | model_class=AutoModel,
112 | processor_class=AutoTokenizer,
113 | requires_trust_remote_code=True,
114 | supports_flash_attention=True,
115 | inference_type="internvl"
116 | ),
117 | "internvl_hf": ModelConfig(
118 | model_class=AutoModelForImageTextToText,
119 | processor_class=AutoProcessor,
120 | requires_trust_remote_code=True,
121 | supports_flash_attention=True,
122 | processor_args={"use_fast": True},
123 | inference_type="standard" # Use standard HF generation
124 | ),
125 | "gemma3": ModelConfig(
126 | model_class=Gemma3ForConditionalGeneration,
127 | processor_class=AutoProcessor,
128 | supports_flash_attention=True,
129 | processor_args={"use_fast": True},
130 | ),
131 | "glm4v": ModelConfig(
132 | model_class=Glm4vForConditionalGeneration,
133 | processor_class=AutoProcessor,
134 | requires_trust_remote_code=True,
135 | supports_flash_attention=True,
136 | padding_side="left",
137 | processor_args={"use_fast": True}
138 | ),
139 | }
140 |
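# Adding support for a new model family is a two-step change. The sketch below uses a
# hypothetical "my_new_vlm" key; the chosen model/processor classes depend on the family:
#
#   MODEL_CONFIGS["my_new_vlm"] = ModelConfig(
#       model_class=AutoModelForImageTextToText,
#       processor_class=AutoProcessor,
#       supports_flash_attention=True,
#       processor_args={"use_fast": True},
#   )
#
# A matching regex must also be added to VLMWrapper._detect_model_type() so that the
# model_id resolves to the new key.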
141 | class VLMWrapper:
142 | """Unified Vision-Language Model wrapper supporting all model types."""
143 |
144 | def __init__(self, model_id: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16):
145 | """
146 | Initialize unified VLM wrapper with automatic model type detection.
147 |
148 | Args:
149 | model_id: HuggingFace model identifier
150 | device_map: Device mapping strategy
151 | dtype: Model precision
152 | """
153 | self.model_id = model_id
154 | self.model_type = self._detect_model_type(model_id)
155 | self.config = MODEL_CONFIGS[self.model_type]
156 | self.dtype = dtype
157 | self.device_map = self._optimize_device_map(device_map)
158 |
159 | # Lazy initialization
160 | self._model = None
161 | self._processor = None
162 | self._device = None
163 | self._transform = None # For InternVL
164 |
165 | logger.info(f"Initialized VLMWrapper: for {self.model_type} model: {model_id}")
166 |
167 | @property
168 | def model(self):
169 | """Lazy loading of model."""
170 | if self._model is None:
171 | self._model = self._load_model()
172 | return self._model
173 |
174 | @property
175 | def processor(self):
176 | """Lazy loading of processor."""
177 | if self._processor is None:
178 | self._processor = self._load_processor()
179 | return self._processor
180 |
181 | @property
182 | def device(self):
183 | """Get model device."""
184 | if self._device is None:
185 | self._device = self.model.device
186 | return self._device
187 |
188 | @property
189 | def transform(self):
190 | """Lazy load image transform for InternVL."""
191 | if self._transform is None and self.model_type == "internvl":
192 | self._transform = self._build_transform(input_size=448)
193 | return self._transform
194 |
195 | def _optimize_device_map(self, device_map: str) -> str:
196 | """Optimize device mapping based on available hardware."""
197 | if device_map == "auto":
198 | if torch.cuda.is_available():
199 | gpu_count = torch.cuda.device_count()
200 | if gpu_count > 1:
201 | logger.info(f"Multiple GPUs detected ({gpu_count}), using auto device mapping")
202 | return "auto"
203 | else:
204 | return "cuda:0"
205 | else:
206 | logger.warning("CUDA not available, using CPU")
207 | return "cpu"
208 | return device_map
209 |
210 | def _detect_model_type(self, model_id: str) -> str:
211 | """Detect model type from model_id."""
212 | # Check for specific patterns first (more specific before generic)
213 | model_patterns = {
214 | "internvl_hf": r"OpenGVLab/InternVL.*-HF", # HF-native InternVL models (InternVL3-HF, etc.)
215 | "qwen": r"Qwen/",
216 | "llava": r"llava-hf/llava-1\.5",
217 | "llava_next": r"llava-hf/llava-v1\.6",
218 | "idefics": r"HuggingFaceM4/Idefics",
219 | "smolvlm": r"HuggingFaceTB/SmolVLM",
220 | "mllama": r"meta-llama",
221 | "minicpm": r"openbmb/MiniCPM",
222 | # "internvl": r"OpenGVLab/InternVL", # Original InternVL with batch_chat
223 | "gemma3": r"google/gemma-3",
224 | "kimi": r"moonshotai/Kimi-VL",
225 | "glm4v": r"zai-org/GLM-4",
226 | }
227 |
228 | for model_type, pattern in model_patterns.items():
229 | if re.search(pattern, model_id):
230 | return model_type
231 |
232 | raise ValueError(f"Unsupported model_id: {model_id}")
233 |
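# A few illustrative resolutions of the patterns above (model ids are examples only):
#   "Qwen/Qwen2.5-VL-7B-Instruct"   -> "qwen"
#   "llava-hf/llava-1.5-7b-hf"      -> "llava"
#   "OpenGVLab/InternVL3-8B-HF"     -> "internvl_hf"
#   any id matching no pattern      -> ValueError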
234 | def _load_model(self):
235 | """Load model with optimized settings."""
236 | try:
237 | model_args = {
238 | "torch_dtype": self.dtype,
239 | "device_map": self.device_map,
240 | }
241 |
242 | if self.config.requires_trust_remote_code:
243 | model_args["trust_remote_code"] = True
244 |
245 | if self.config.supports_flash_attention and torch.cuda.is_available():
246 | model_args["attn_implementation"] = "flash_attention_2"
247 |
248 | # Add special arguments
249 | model_args.update(self.config.special_args)
250 |
251 | model = self.config.model_class.from_pretrained(self.model_id, **model_args)
252 | # model = torch.compile(model, mode="default") # Compile model for performance
253 | return model.eval()
254 |
255 | except Exception as e:
256 | logger.error(f"Failed to load model {self.model_id}: {e}")
257 | raise
258 |
259 | def _load_processor(self):
260 | """Load processor with optimized settings."""
261 | try:
262 | processor_args = {}
263 |
264 | if self.config.requires_trust_remote_code:
265 | processor_args["trust_remote_code"] = True
266 |
267 | # Add special arguments
268 | processor_args.update(self.config.processor_args)
269 |
270 | processor = self.config.processor_class.from_pretrained(self.model_id, **processor_args)
271 |
272 | # Configure padding
273 | if hasattr(processor, "tokenizer"):
274 | # Ensure padding_side
275 | processor.tokenizer.padding_side = self.config.padding_side
276 | # Ensure pad_token_id exists; fallback to eos if missing
277 | if getattr(processor.tokenizer, "pad_token_id", None) is None:
278 | eos_id = getattr(processor.tokenizer, "eos_token_id", None)
279 | if eos_id is not None:
280 | processor.tokenizer.pad_token_id = eos_id
281 |
282 | # Set pad token on the model's generation config if available; a single
283 | # assignment keeps it in sync with the tokenizer's pad token in either case.
284 | if hasattr(processor, "tokenizer") and hasattr(self.model, "generation_config"):
285 | self.model.generation_config.pad_token_id = processor.tokenizer.pad_token_id
286 |
287 |
288 |
289 | return processor
290 |
291 | except Exception as e:
292 | logger.error(f"Failed to load processor for {self.model_id}: {e}")
293 | raise
294 |
295 | def _build_transform(self, input_size: int = 448) -> T.Compose:
296 | """Create optimized image transform pipeline for InternVL."""
297 | return T.Compose([
298 | T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
299 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
300 | T.ToTensor(),
301 | T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
302 | ])
303 |
304 | @lru_cache(maxsize=128)
305 | def load_image_from_url(self, image_url: str) -> Image.Image:
306 | """Load image from URL with caching."""
307 | try:
308 | response = requests.get(image_url, stream=True, timeout=10)
309 | response.raise_for_status()
310 | return Image.open(response.raw)
311 | except Exception as e:
312 | logger.error(f"Failed to load image from {image_url}: {e}")
313 | raise
314 |
315 | def preprocess(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Any:
316 | """Unified preprocessing entry point for all inference types."""
317 | return self._preprocess_standard(conversation, image_input)
318 |
319 |
320 | def _preprocess_standard(self, batch_conversations: List, batch_images: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Any:
321 | """Unified preprocessing for all VLM models, including internvl and minicpm."""
322 | if self.config.inference_type == "internvl":
323 | return self._preprocess_internvl(batch_conversations, batch_images)
324 | if self.config.inference_type == "minicpm":
325 | return self._preprocess_minicpm(batch_conversations, batch_images)
326 | # batch_conversations, batch_images = self._normalize_inputs(conversation, image_input)
327 | try:
328 | return self._preprocess(batch_conversations, batch_images)
329 |
330 | except Exception as e:
331 | logger.error(f"Preprocessing failed for {self.model_type}: {e}", exc_info=True)
332 | raise
333 |
334 | def _preprocess(self, batch_conversations: List[List], batch_images: Optional[Union[Image.Image, List[Image.Image]]]) -> Dict[str, torch.Tensor]:
335 | """Preprocess input data for the model.
336 | Args:
337 | batch_conversations: List of conversations, each a list of turns.
338 | batch_images: List of images corresponding to each conversation.
339 | Returns:
340 | Dictionary of preprocessed inputs ready for model inference.
341 | """
342 |
343 | # Apply chat template to each conversation
344 | prompts = [
345 | self.processor.apply_chat_template(
346 | conv,
347 | add_generation_prompt=True,
348 | tokenize=False
349 | )
350 | for conv in batch_conversations
351 | ]
352 |
353 | if self.model_type in ["mllama", "smolvlm", "idefics", "gemma3", "glm4v", "internvl_hf"]:
354 | images_to_process = [[img] for img in batch_images]
355 | else:
356 | images_to_process = batch_images
357 |
358 | assert len(prompts) == len(images_to_process), "Number of prompts must match number of image inputs"
359 |
360 | # Preprocess inputs
361 | inputs = self.processor(
362 | text=prompts,
363 | images=images_to_process,
364 | return_tensors="pt",
365 | padding=True,
366 | ).to(self.device, dtype=self.dtype)
367 |
368 |
369 | return inputs
370 |
371 |
372 | def _preprocess_internvl(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Dict[str, Any]:
373 | """
374 | Robust preprocessing for InternVL models.
375 | - Supports batch or single conversations
376 | - Extracts the last user turn to avoid mixing one-shot examples
377 | - Builds a per-sample list of pixel tensors (later concatenated for batch_chat)
378 | - Ensures input dtype matches model parameters
379 | """
380 | # Normalize to list of conversations
381 | conversation_list = conversation if isinstance(conversation[0], list) else [conversation]
382 |
383 | questions: List[str] = []
384 | pixel_values_list: List[torch.Tensor] = []
385 | num_patches_list: List[int] = []
386 |
387 | # Determine model param dtype to avoid dtype mismatch (e.g., bf16 vs fp16)
388 | try:
389 | model_param_dtype = next(self.model.parameters()).dtype
390 | except Exception:
391 | model_param_dtype = self.dtype
392 |
393 | for conv in conversation_list:
394 | # Find last user turn
395 | user_turns = [turn for turn in conv if isinstance(turn, dict) and turn.get("role") == "user"]
396 | if not user_turns:
397 | raise ValueError("No user turn found in conversation for InternVL")
398 | last_user = user_turns[-1]
399 |
400 | # Extract image and text
401 | image = None
402 | text_parts: List[str] = []
403 | for item in last_user.get('content', []):
404 | if isinstance(item, dict) and 'image' in item and image is None:
405 | image = item['image']
406 | elif isinstance(item, dict) and 'text' in item:
407 | text_parts.append(item['text'])
408 |
409 | if image is None and image_input is not None:
410 | # Fallback if image not embedded
411 | image = image_input[0] if isinstance(image_input, list) and image_input else image_input
412 | if image is None:
413 | raise ValueError("No image found in conversation content for InternVL")
414 |
415 | question = " ".join(text_parts).strip()
416 | questions.append(question)
417 |
418 | # Load and preprocess image into patches (N, 3, 448, 448)
419 | patches = self._load_image_internvl(image, max_num=12)
420 | # Match dtype/device to model parameters to avoid conv2d dtype mismatch
421 | patches = patches.to(device=self.device, dtype=model_param_dtype)
422 | pixel_values_list.append(patches)
423 | num_patches_list.append(patches.size(0))
424 |
425 | return {
426 | "pixel_values": pixel_values_list, # keep per-sample; we will concat at generation
427 | "questions": questions,
428 | "num_patches_list": num_patches_list,
429 | }
430 |
431 | def _preprocess_minicpm(self, conversation: List, image_input: Optional[Union[Image.Image, List[Image.Image]]] = None) -> Tuple:
432 | """Preprocessing for MiniCPM models. Extract the last user turn, then find image/text by keys."""
433 | # Normalize to list of conversations
434 | conv_list = conversation if isinstance(conversation[0], list) else [conversation]
435 | msgs_batch: List[List[Dict[str, Any]]] = []
436 | for conv in conv_list:
437 | # Find last user turn (to avoid one-shot example turns)
438 | user_turns = [turn for turn in conv if isinstance(turn, dict) and turn.get("role") == "user"]
439 | if not user_turns:
440 | raise ValueError("No user turn found in conversation")
441 | last_user = user_turns[-1]
442 | # Extract image and text from the content list, regardless of order
443 | img = None
444 | text = ""
445 | for item in last_user.get("content", []):
446 | if isinstance(item, dict) and "image" in item and img is None:
447 | img = item["image"]
448 | elif isinstance(item, dict) and "text" in item and not text:
449 | text = item["text"]
450 | # Fallback to image_input if not embedded in conversation
451 | if img is None and image_input is not None:
452 | if isinstance(image_input, list) and len(image_input) > 0:
453 | img = image_input[0]
454 | else:
455 | img = image_input
456 | if img is None:
457 | # No image in the conversation content and no image_input fallback was provided
458 | raise ValueError("No image found in conversation content for MiniCPM")
459 | # Prepare image and build msgs
460 | np_img = self._prepare_image_minicpm(img)
461 | msgs = [{"role": "user", "content": [np_img, text]}]
462 | msgs_batch.append(msgs)
463 | # Return batch or single
464 | return msgs_batch if len(msgs_batch) > 1 else msgs_batch[0]
465 |
466 | def _load_image_internvl(self, image_input: Any, input_size: int = 448, max_num: int = 12) -> torch.Tensor:
467 | """Load and preprocess image for InternVL with dynamic preprocessing."""
468 | try:
469 | # Handle different image input types
470 | if isinstance(image_input, str):
471 | image = Image.open(image_input).convert("RGB")
472 | elif isinstance(image_input, dict) and "path" in image_input:
473 | image = Image.open(image_input["path"]).convert("RGB")
474 | elif isinstance(image_input, Image.Image):
475 | image = image_input.convert("RGB")
476 | else:
477 | raise ValueError(f"Unsupported image input type: {type(image_input)}")
478 |
479 | # Dynamic preprocessing
480 | images = self._dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
481 |
482 | # Apply transforms
483 | pixel_values = [self.transform(img) for img in images]
484 | return torch.stack(pixel_values)
485 |
486 | except Exception as e:
487 | logger.error(f"Failed to load image for InternVL: {e}")
488 | raise
489 |
490 | def _prepare_image_minicpm(self, image_input: Union[str, Image.Image, Dict]) -> np.ndarray:
491 | """Prepare image for MiniCPM processing."""
492 | try:
493 | # Handle different image input types
494 | if isinstance(image_input, str):
495 | image = Image.open(image_input).convert("RGB")
496 | elif isinstance(image_input, dict) and "path" in image_input:
497 | image = Image.open(image_input["path"]).convert("RGB")
498 | elif isinstance(image_input, Image.Image):
499 | image = image_input.convert("RGB")
500 | else:
501 | image = image_input
502 |
503 | # Convert to numpy array with channel-first format (C x H x W)
504 | np_img = np.array(image)
505 | if len(np_img.shape) == 3:
506 | np_img = np_img.transpose(2, 0, 1)
507 |
508 | return np_img
509 |
510 | except Exception as e:
511 | logger.error(f"Failed to prepare image for MiniCPM: {e}")
512 | raise
513 |
514 | def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12,
515 | image_size: int = 448, use_thumbnail: bool = False) -> List[Image.Image]:
516 | """Dynamic preprocessing for InternVL."""
517 | orig_width, orig_height = image.size
518 | aspect_ratio = orig_width / orig_height
519 |
520 | # Generate target ratios
521 | target_ratios = [
522 | (i, j) for n in range(min_num, max_num + 1)
523 | for i in range(1, n + 1) for j in range(1, n + 1)
524 | if min_num <= i * j <= max_num
525 | ]
526 | target_ratios.sort(key=lambda x: x[0] * x[1])
527 |
528 | target_aspect_ratio = self._find_closest_aspect_ratio(
529 | aspect_ratio, target_ratios, orig_width, orig_height, image_size
530 | )
531 |
532 | target_width = image_size * target_aspect_ratio[0]
533 | target_height = image_size * target_aspect_ratio[1]
534 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
535 |
536 | # Resize image
537 | resized_img = image.resize((target_width, target_height))
538 |
539 | # Extract blocks
540 | processed_images = []
541 | cols = target_width // image_size
542 |
543 | for i in range(blocks):
544 | col = i % cols
545 | row = i // cols
546 | box = (
547 | col * image_size,
548 | row * image_size,
549 | (col + 1) * image_size,
550 | (row + 1) * image_size,
551 | )
552 | processed_images.append(resized_img.crop(box))
553 |
554 | if use_thumbnail and len(processed_images) != 1:
555 | thumbnail_img = image.resize((image_size, image_size))
556 | processed_images.append(thumbnail_img)
557 |
558 | return processed_images
559 |
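# Worked example (illustrative): a 896x448 image has aspect ratio 2.0, so the closest
# target ratio is (2, 1). The image is resized to 896x448 and cut into two 448x448 tiles;
# because more than one tile is produced and use_thumbnail=True, a 448x448 thumbnail of
# the full image is appended, giving 3 tiles in total.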
560 | def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[Tuple[int, int]],
561 | width: int, height: int, image_size: int) -> Tuple[int, int]:
562 | """Find the best matching target aspect ratio for InternVL."""
563 | best_ratio_diff = float("inf")
564 | best_ratio = (1, 1)
565 | area = width * height
566 |
567 | for ratio in target_ratios:
568 | target_aspect_ratio = ratio[0] / ratio[1]
569 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
570 |
571 | if ratio_diff < best_ratio_diff:
572 | best_ratio_diff = ratio_diff
573 | best_ratio = ratio
574 | elif (ratio_diff == best_ratio_diff and
575 | area > 0.5 * image_size * image_size * ratio[0] * ratio[1]):
576 | best_ratio = ratio
577 |
578 | return best_ratio
579 |
580 |
581 |
582 | def decode(self, generated_ids: Any) -> List[str]:
583 | """
584 | Decode generated outputs.
585 | - For standard models: decode token IDs to text (slice per-sample new tokens using prompt lengths).
586 | - For internvl/minicpm: pass through strings/lists.
587 | """
588 | try:
589 | # Pass-through for models that already return strings
590 | if self.config.inference_type in ("internvl", "minicpm"):
591 | if isinstance(generated_ids, list):
592 | return generated_ids
593 | if isinstance(generated_ids, str):
594 | return [generated_ids]
595 | return [str(generated_ids)]
596 |
597 | return self.processor.batch_decode(
598 | generated_ids,
599 | skip_special_tokens=True,
600 | clean_up_tokenization_spaces=True
601 | )
602 |
603 | except Exception as e:
604 | logger.error(f"Decoding failed for {self.model_type}: {e}")
605 | raise
606 |
607 |
608 | @torch.inference_mode()
609 | def generate(self, inputs: Any, **generation_kwargs) -> Any:
610 | """
611 | Generate text using the model with unified interface.
612 |
613 | Args:
614 | inputs: Preprocessed inputs
615 | **generation_kwargs: Generation parameters
616 |
617 | Returns:
618 | Generated outputs (format depends on model type)
619 | """
620 | try:
621 | if self.config.inference_type == "internvl":
622 | return self._generate_internvl(inputs, **generation_kwargs)
623 | elif self.config.inference_type == "minicpm":
624 | return self._generate_minicpm(inputs, **generation_kwargs)
625 | else:
626 | sequences = self.model.generate(
627 | **inputs,
628 | **generation_kwargs
629 | )
630 |
631 | # Slice prompt tokens from generated sequences
632 | prompt_length = inputs["input_ids"].shape[-1]
633 | if sequences.dim() == 1:
634 | generated_ids = sequences[prompt_length:]
635 | else:
636 | generated_ids = sequences[:, prompt_length:]
637 |
638 | return generated_ids
639 |
640 | except Exception as e:
641 | logger.error(f"Generation failed: {e}")
642 | raise
643 |
644 | def _generate_internvl(self, inputs: Dict[str, Any], **generation_kwargs) -> List[str]:
645 | """Generate using InternVL's chat API per-sample to avoid batch misalignment issues."""
646 | pixel_values_in = inputs["pixel_values"]
647 | questions = inputs["questions"]
648 | num_patches_list = inputs["num_patches_list"]
649 |
650 | # Build per-sample pixel tensors list
651 | if isinstance(pixel_values_in, list):
652 | pv_list = pixel_values_in
653 | else:
654 | # Split concatenated tensor by patch counts
655 | pv_list = list(torch.split(pixel_values_in, num_patches_list, dim=0))
656 |
657 | if len(pv_list) != len(questions):
658 | raise ValueError(
659 | f"InternVL mismatch: {len(pv_list)} pixel groups vs {len(questions)} questions"
660 | )
661 |
662 | responses: List[str] = []
663 | for pv, q, n in zip(pv_list, questions, num_patches_list):
664 | # Sanity checks
665 | if pv.dim() != 4 or pv.size(0) != int(n):
666 | raise ValueError(
667 | f"InternVL per-sample mismatch: pixel batch={pv.size(0)} vs n={n}, shape={tuple(pv.shape)}"
668 | )
669 | out = self.model.batch_chat(
670 | self.processor,
671 | pv,
672 | num_patches_list=[int(n)],
673 | questions=[q],
674 | generation_config=generation_kwargs,
675 | )
676 | # Normalize to string
677 | if isinstance(out, list):
678 | if len(out) == 1 and isinstance(out[0], dict) and 'response' in out[0]:
679 | responses.append(out[0]['response'])
680 | elif len(out) == 1 and isinstance(out[0], str):
681 | responses.append(out[0])
682 | else:
683 | responses.append(str(out))
684 | else:
685 | responses.append(str(out))
686 | return responses
687 |
688 | def _generate_minicpm(self, msgs: Any, **generation_kwargs) -> List[str]:
689 | """Generate using MiniCPM's chat interface."""
690 | # Default generation config
691 | generation_config = {
692 | "max_tokens": 128,
693 | "do_sample": False,
694 | }
695 | generation_config.update(generation_kwargs)
696 |
697 | if isinstance(msgs, list) and isinstance(msgs[0], list):
698 | # Batch processing
699 | responses = []
700 | for msg in msgs:
701 | try:
702 | response = self.model.chat(
703 | image=None,
704 | msgs=msg,
705 | tokenizer=self.processor,
706 | **generation_config
707 | )
708 | responses.append(response)
709 | except Exception as e:
710 | logger.warning(f"Individual MiniCPM inference failed: {e}")
711 | responses.append(f"ERROR: {str(e)}")
712 | return responses
713 | else:
714 | # Single processing
715 | return self.model.chat(
716 | image=None,
717 | msgs=msgs,
718 | tokenizer=self.processor,
719 | **generation_config
720 | )
721 |
722 | @contextmanager
723 | def memory_efficient_mode(self):
724 | """Context manager for memory-efficient inference."""
725 | original_grad_state = torch.is_grad_enabled()
726 | try:
727 | torch.set_grad_enabled(False)
728 | if torch.cuda.is_available():
729 | torch.cuda.empty_cache()
730 | yield
731 | finally:
732 | torch.set_grad_enabled(original_grad_state)
733 | if torch.cuda.is_available():
734 | torch.cuda.empty_cache()
735 |
736 | def __del__(self):
737 | """Cleanup resources."""
738 | try:
739 | if torch is not None and torch.cuda.is_available():
740 | torch.cuda.empty_cache()
741 | except Exception:
742 | pass
743 |
744 |
745 | # Alias for backward compatibility
746 | VLM = VLMWrapper
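# Minimal usage sketch, kept commented out so importing this module stays side-effect
# free. The model id, image path, and import path are assumptions for illustration only:
#
#   from PIL import Image
#   from src.utils.vlm_wrapper import VLMWrapper
#
#   wrapper = VLMWrapper("Qwen/Qwen2.5-VL-7B-Instruct")
#   image = Image.open("example.png").convert("RGB")
#   conversation = [{"role": "user", "content": [
#       {"type": "image", "image": image},
#       {"type": "text", "text": "Describe the spatial layout of the scene."},
#   ]}]
#   inputs = wrapper.preprocess([conversation], [image])
#   with wrapper.memory_efficient_mode():
#       generated = wrapper.generate(inputs, max_new_tokens=128, do_sample=False)
#   print(wrapper.decode(generated)[0])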
--------------------------------------------------------------------------------