├── .gitignore
├── images
│   └── overview.png
├── requirements.txt
├── README.md
└── evaluation.py

/.gitignore:
--------------------------------------------------------------------------------
1 | images/.DS_Store
2 | 
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JD-GenX/Uni-Layout/HEAD/images/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | transformers==4.38.2
3 | sentencepiece==0.1.99
4 | protobuf
5 | opencv-python-headless
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Uni-Layout
2 | ## Uni-Layout: Integrating Human Feedback in Unified Layout Generation and Evaluation
3 | [ACM MM 2025] Official PyTorch Code for "Uni-Layout: Integrating Human Feedback in Unified Layout Generation and Evaluation"
4 | 
5 | ![overview](images/overview.png)
6 | 
7 | ## 📢 News
8 | 
9 | `[2025-09-02]:` 🚀 CoT data has been released! You can now find it in the ["Dataset for Reward Model" link](https://drive.google.com/drive/folders/1VASp90_mqSwJxJH65v5-iP9Sk3tgr23M?usp=drive_link).
10 | 
11 | `[2025-08-04]:` 🎯 Our paper is now available on arXiv! Check it out here: [https://arxiv.org/abs/2508.02374](https://arxiv.org/abs/2508.02374).
12 | 
13 | `[2025-07-04]:` 🎉 Exciting news! Our paper has been accepted to ACM MM 2025! Stay tuned for more updates!
14 | 
15 | ## 🚀 Code & Weights Notice
16 | 
17 | - Layout Evaluator Checkpoints: [Download Link](https://drive.google.com/drive/folders/1evrHmorHW7CBLRhxrV3-3qvFki1ovoJ3?usp=drive_link)
18 | 
19 | ### 🧪 Evaluation
20 | 
21 | - **Script**: `evaluation.py`
22 | 
23 | #### Requirements
24 | - Python >= 3.8 (recommend Anaconda/Miniconda)
25 | - PyTorch >= 2.3.1 + CUDA 11.8 (install from the official wheel index)
26 | - Extra dependencies in `requirements.txt`
27 | 
28 | #### Setup
29 | ```bash
30 | conda create -n caig python==3.8.20 -y && conda activate caig
31 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
32 | pip install -r requirements.txt
33 | ```
34 | 
35 | #### Run
36 | ```bash
37 | python evaluation.py \
38 |     --model_path /path/to/model \
39 |     --input_data_path /path/to/input.json \
40 |     --output_data_path /path/to/output.json
41 | ```
42 | 
43 | #### Notes
44 | - Optional: `--model_base`, `--conv_mode`, generation args (`--temperature`, `--top_p`, `--num_beams`, `--max_new_tokens`, `--generate_nums`), and process args (`--save_interval`, `--batch_size`).
45 | - Input JSON follows the dataset format below; the `image` field is optional.
46 | 
47 | ## 📊 Datasets
48 | ### 1. Dataset for Layout Generator
49 | [Download Link](https://drive.google.com/drive/folders/1OLWRUZSiecpGuG2sUdQHOnmp46P9ojuD?usp=sharing).
50 | 
51 | #### Key Fields
52 | - **`sku_id`**: Anonymized sample identifier.
53 | - **`image`**: Path to the image (optional; may be absent for text-only tasks).
54 | - **`conversations`**: List of two messages:
55 |   - **human**: Task description; may include the `<image>` tag, canvas size, element types, and layout constraints.
56 |   - **gpt**: Layout result; `value` is a string in the form `Layout:{...}`, where bounding boxes are `[x_min, y_min, x_max, y_max]`. See the example record sketched below.
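For concreteness, here is a minimal, hypothetical sketch of what one generator-style record could look like and how a small `input.json` for `evaluation.py` could be assembled. Only the field names and the `Layout:{...}` string convention come from the description above; the `from`/`value` message keys follow the common LLaVA conversation format (not spelled out in this README), and every concrete value (paths, canvas size, element names, coordinates) is invented for illustration.

```python
import json

# One made-up record following the key fields described above (hypothetical values).
# The "image" entry is optional and can be omitted for text-only samples.
example_record = {
    "sku_id": "sample_0001",            # anonymized identifier (invented)
    "image": "images/sample_0001.png",  # optional image path (invented)
    "conversations": [
        {
            "from": "human",            # message keys assume the usual LLaVA format
            "value": "<image>\nCanvas size: 513x750. Elements: logo, title, subtitle. "
                     "Arrange all elements inside the canvas without overlap.",
        },
        {
            "from": "gpt",
            # Bounding boxes are [x_min, y_min, x_max, y_max]; the exact content
            # inside the braces is illustrative only.
            "value": 'Layout:{"logo": [20, 20, 120, 80], "title": [20, 100, 480, 180], '
                     '"subtitle": [20, 200, 480, 260]}',
        },
    ],
}

# evaluation.py loads the input file with json.load(), so a JSON list of such records works.
with open("input.json", "w", encoding="utf-8") as f:
    json.dump([example_record], f, ensure_ascii=False, indent=2)
```

Putting the `<image>` tag on its own first line matches how `evaluation.py` strips the first line of the human message for image samples.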
57 | 
58 | ### 2. Dataset for Layout Evaluator
59 | [Download Link](https://drive.google.com/drive/folders/1VASp90_mqSwJxJH65v5-iP9Sk3tgr23M?usp=drive_link).
60 | 
61 | #### Key Fields
62 | - **`image`**: Path to the image.
63 | - **`conversations`**: Single-turn QA pair:
64 |   - **human**: Evaluation instruction with candidate layout and constraints; expects a binary decision (0/1).
65 |   - **gpt**: The answer; `value` is the Ground Truth label (0 or 1).
66 | 
67 | 
68 | # Copyright & Licensing
69 | © JD.COM. All rights reserved. The datasets and code provided in this repository are licensed exclusively for academic research purposes. Commercial use, reproduction, or distribution requires express written permission from JD.COM. Unauthorized commercial use constitutes a violation of these terms and is strictly prohibited.
70 | 
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | LLaVA Model Evaluation Script
4 | 
5 | This script performs inference using LLaVA models on multimodal data.
6 | It supports both image-text and text-only tasks with distributed processing.
7 | 
8 | Author: [Your Name]
9 | Date: [Current Date]
10 | """
11 | 
12 | import argparse
13 | import json
14 | import os
15 | import re
16 | import time
17 | from typing import Dict, List, Optional, Union
18 | from pathlib import Path
19 | 
20 | import torch
21 | from PIL import Image
22 | from tqdm import tqdm
23 | from accelerate import Accelerator
24 | from accelerate.utils import gather_object
25 | from accelerate import InitProcessGroupKwargs
26 | import datetime
27 | 
28 | # LLaVA imports
29 | from llava.constants import (
30 |     IMAGE_TOKEN_INDEX,
31 |     DEFAULT_IMAGE_TOKEN,
32 |     DEFAULT_IM_START_TOKEN,
33 |     DEFAULT_IM_END_TOKEN,
34 |     IMAGE_PLACEHOLDER,
35 | )
36 | from llava.conversation import conv_templates
37 | from llava.model.builder import load_pretrained_model
38 | from llava.utils import disable_torch_init
39 | from llava.mm_utils import (
40 |     process_images,
41 |     tokenizer_image_token,
42 |     get_model_name_from_path,
43 | )
44 | 
45 | 
46 | def load_image(image_file: str) -> Image.Image:
47 |     """Load image from file path or URL."""
48 |     if image_file.startswith(("http://", "https://")):
49 |         import requests
50 |         from io import BytesIO
51 |         response = requests.get(image_file)
52 |         image = Image.open(BytesIO(response.content)).convert("RGB")
53 |     else:
54 |         image = Image.open(image_file).convert("RGB")
55 |     return image
56 | 
57 | 
58 | def process_prompt(
59 |     caption: str,
60 |     args,
61 |     mm_use_im_start_end: bool,
62 |     model_name: str,
63 |     has_image: bool = True
64 | ) -> str:
65 |     """Process prompt for model input with image token handling."""
66 |     qs = caption
67 |     image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
68 | 
69 |     if has_image:
70 |         if IMAGE_PLACEHOLDER in qs:
71 |             if mm_use_im_start_end:
72 |                 qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
73 |             else:
74 |                 qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
75 |         else:
76 |             if mm_use_im_start_end:
77 |                 qs = image_token_se + "\n" + qs
78 |             else:
79 |                 qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
80 |     else:
81 |         if IMAGE_PLACEHOLDER in qs:
82 |             qs = qs.replace(IMAGE_PLACEHOLDER, "")
83 | 
84 |     # Determine conversation mode based on model name
85 |     if "llama-2" in model_name.lower():
86 |         conv_mode = "llava_llama_2"
87 |     elif "mistral" in model_name.lower():
88 |         conv_mode = "mistral_instruct"
"mistral_instruct" 89 | elif "v1.6-34b" in model_name.lower(): 90 | conv_mode = "chatml_direct" 91 | elif "v1" in model_name.lower(): 92 | conv_mode = "llava_v1" 93 | elif "mpt" in model_name.lower(): 94 | conv_mode = "mpt" 95 | else: 96 | conv_mode = "llava_v0" 97 | 98 | if args.conv_mode is not None and conv_mode != args.conv_mode: 99 | print(f"[WARNING] Auto inferred conversation mode is {conv_mode}, " 100 | f"while `--conv-mode` is {args.conv_mode}, using {args.conv_mode}") 101 | else: 102 | args.conv_mode = conv_mode 103 | 104 | conv = conv_templates[args.conv_mode].copy() 105 | conv.append_message(conv.roles[0], qs) 106 | conv.append_message(conv.roles[1], None) 107 | prompt = conv.get_prompt() 108 | return prompt 109 | 110 | 111 | def load_data(data_path: str) -> List[Dict]: 112 | """Load data from JSON file.""" 113 | with open(data_path, 'r', encoding='utf-8') as f: 114 | data = json.load(f) 115 | return data 116 | 117 | 118 | def process_single_item( 119 | item: Dict, 120 | args, 121 | tokenizer, 122 | model, 123 | image_processor, 124 | mm_use_im_start_end: bool, 125 | model_name: str 126 | ) -> Dict: 127 | """Process a single data item for inference.""" 128 | new_sample = {} 129 | 130 | if 'image' in item: 131 | # Process image-text task 132 | inp = '\n'.join(item["conversations"][0]['value'].split('\n')[1:]) 133 | prompt = process_prompt(inp, args, mm_use_im_start_end, model_name, has_image=True) 134 | input_ids = tokenizer_image_token( 135 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" 136 | ).unsqueeze(0).cuda() 137 | 138 | new_sample['question'] = prompt 139 | new_sample['label_answer'] = item["conversations"][1]['value'] 140 | new_sample['image'] = item['image'] 141 | 142 | # Load and process image 143 | try: 144 | images = [Image.open(item['image']).convert("RGB")] 145 | image_sizes = [x.size for x in images] 146 | images_tensor = process_images( 147 | images, image_processor, model.config 148 | ).to(model.device, dtype=torch.float16) 149 | except Exception as e: 150 | print(f"Error loading image {item['image']}: {e}") 151 | return None 152 | 153 | # Generate responses 154 | outputs = [] 155 | for i in range(args.generate_nums): 156 | with torch.inference_mode(): 157 | output_ids = model.generate( 158 | input_ids, 159 | images=images_tensor, 160 | image_sizes=image_sizes, 161 | do_sample=True if args.temperature > 0 else False, 162 | temperature=args.temperature, 163 | top_p=args.top_p, 164 | num_beams=args.num_beams, 165 | max_new_tokens=args.max_new_tokens, 166 | use_cache=True, 167 | ) 168 | 169 | output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 170 | if len(output) >= 800: # Prevent overly long outputs 171 | break 172 | outputs.append(output) 173 | 174 | else: 175 | # Process text-only task 176 | inp = item["conversations"][0]['value'] 177 | prompt = process_prompt(inp, args, mm_use_im_start_end, model_name, has_image=False) 178 | input_ids = tokenizer_image_token( 179 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" 180 | ).unsqueeze(0).cuda() 181 | 182 | new_sample['question'] = prompt 183 | new_sample['label_answer'] = item["conversations"][1]['value'] 184 | 185 | # Generate responses 186 | outputs = [] 187 | for i in range(args.generate_nums): 188 | with torch.inference_mode(): 189 | output_ids = model.generate( 190 | input_ids, 191 | images=None, 192 | image_sizes=None, 193 | do_sample=True if args.temperature > 0 else False, 194 | temperature=args.temperature, 195 | top_p=args.top_p, 196 | 
196 |                     num_beams=args.num_beams,
197 |                     max_new_tokens=args.max_new_tokens,
198 |                     use_cache=True,
199 |                 )
200 | 
201 |             output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
202 |             outputs.append(output)
203 |             if len(output) >= 800:  # Prevent overly long outputs
204 |                 break
205 | 
206 |     new_sample['gpt_answer'] = outputs if args.generate_nums > 1 else outputs[0]
207 |     return new_sample
208 | 
209 | 
210 | def evaluate_model(args):
211 |     """Main evaluation function."""
212 |     disable_torch_init()
213 |     process_group_kwargs = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=540000))
214 |     accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
215 | 
216 |     if args.model_path == "None":
217 |         args.model_path = args.model_base
218 |         args.model_base = None
219 |         print("Changed args.model_path to args.model_base")
220 | 
221 |     model_name = get_model_name_from_path(args.model_path)
222 |     tokenizer, model, image_processor, context_len = load_pretrained_model(
223 |         args.model_path, args.model_base, model_name, device=accelerator.process_index
224 |     )
225 |     mm_use_im_start_end = model.config.mm_use_im_start_end
226 | 
227 |     # Load data
228 |     data = load_data(args.input_data_path)
229 |     print(f"Total number of samples: {len(data)}")
230 | 
231 |     accelerator.wait_for_everyone()
232 |     start = time.time()
233 | 
234 |     with accelerator.split_between_processes(data) as prompts:
235 |         results = []
236 | 
237 |         for num, item in tqdm(enumerate(prompts), total=len(prompts), desc="Processing"):
238 |             try:
239 |                 result = process_single_item(
240 |                     item, args, tokenizer, model, image_processor,
241 |                     mm_use_im_start_end, model_name
242 |                 )
243 |                 if result is not None:
244 |                     results.append(result)
245 |             except Exception as e:
246 |                 print(f"Error processing item {num}: {e}")
247 |                 continue
248 | 
249 |             # Save intermediate results periodically. gather_object() is a collective
250 |             # call and must not be guarded by is_main_process, so each process simply
251 |             # writes its own shard here; results are gathered once at the end.
252 |             if num % args.save_interval == 0 and num != 0:
253 |                 formatted_data = json.dumps(results, indent=0, ensure_ascii=False)
254 |                 timediff = time.time() - start
255 |                 output_file_path = f"{args.output_data_path}_checkpoint_{num}_rank{accelerator.process_index}.json"
256 |                 with open(output_file_path, 'w', encoding='utf-8') as file:
257 |                     file.write(formatted_data)
258 |                 print(f"Checkpoint saved at {num}, time elapsed: {timediff:.2f}s")
259 |                 start = time.time()
260 | 
261 |     # Save final results
262 |     results_gathered = gather_object(results)
263 |     formatted_data = json.dumps(results_gathered, indent=0, ensure_ascii=False)
264 | 
265 |     if accelerator.is_main_process:
266 |         timediff = time.time() - start
267 |         with open(args.output_data_path, 'w', encoding='utf-8') as file:
268 |             file.write(formatted_data)
269 |         print(f"Final results saved, total time elapsed: {timediff:.2f}s")
270 |         print(f"Processed {len(results_gathered)} samples successfully")
271 | 
272 | 
273 | def main():
274 |     """Main entry point."""
275 |     parser = argparse.ArgumentParser(description="LLaVA Model Evaluation")
276 | 
277 |     # Model arguments
278 |     parser.add_argument("--model_path", type=str, required=True,
279 |                         help="Path to the LLaVA model")
280 |     parser.add_argument("--model_base", type=str, default=None,
281 |                         help="Base model path (for delta weights)")
282 |     parser.add_argument("--conv_mode", type=str, default=None,
283 |                         help="Conversation mode (auto-detected if not specified)")
284 | 
285 |     # Data arguments
286 |     parser.add_argument("--input_data_path", type=str, required=True,
287 |                         help="Path to input JSON data file")
file") 288 | parser.add_argument("--output_data_path", type=str, required=True, 289 | help="Path to output JSON results file") 290 | 291 | # Generation arguments 292 | parser.add_argument("--temperature", type=float, default=0.8, 293 | help="Sampling temperature") 294 | parser.add_argument("--top_p", type=float, default=None, 295 | help="Top-p sampling parameter") 296 | parser.add_argument("--num_beams", type=int, default=1, 297 | help="Number of beams for generation") 298 | parser.add_argument("--max_new_tokens", type=int, default=512, 299 | help="Maximum number of new tokens to generate") 300 | parser.add_argument("--generate_nums", type=int, default=1, 301 | help="Number of generations per sample") 302 | 303 | # Processing arguments 304 | parser.add_argument("--save_interval", type=int, default=5000, 305 | help="Save intermediate results every N samples") 306 | parser.add_argument("--batch_size", type=int, default=None, 307 | help="Batch size for processing") 308 | 309 | # Optional arguments 310 | parser.add_argument("--image_file", type=str, help="Single image file for testing") 311 | parser.add_argument("--query", type=str, help="Single query for testing") 312 | parser.add_argument("--sep", type=str, default=",", help="Separator for multiple images") 313 | parser.add_argument("--debug", action="store_true", default=False, 314 | help="Enable debug mode") 315 | 316 | args = parser.parse_args() 317 | 318 | if args.debug: 319 | args.temperature = 0 320 | args.generate_nums = 1 321 | print("Debug mode enabled") 322 | 323 | evaluate_model(args) 324 | 325 | 326 | if __name__ == "__main__": 327 | main() 328 | --------------------------------------------------------------------------------