├── .gitignore
├── images
│   └── overview.png
├── requirements.txt
├── README.md
└── evaluation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | images/.DS_Store
2 |
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JD-GenX/Uni-Layout/HEAD/images/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | transformers==4.38.2
3 | sentencepiece==0.1.99
4 | protobuf
5 | opencv-python-headless
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Uni-Layout
2 | ## Uni-Layout: Integrating Human Feedback in Unified Layout Generation and Evaluation
3 | [ACM MM 2025] Official PyTorch Code for "Uni-Layout: Integrating Human Feedback in Unified Layout Generation and Evaluation"
4 |
5 |
6 |
7 | ## 📢 News
8 |
9 | `[2025-09-02]:` 🚀 CoT data has been released! You can now find it in the ["Dataset for Reward Model" link](https://drive.google.com/drive/folders/1VASp90_mqSwJxJH65v5-iP9Sk3tgr23M?usp=drive_link).
10 |
11 | `[2025-08-04]:` 🎯 Our paper is now available on arXiv! Check it out here: [https://arxiv.org/abs/2508.02374](https://arxiv.org/abs/2508.02374).
12 |
13 | `[2025-07-04]:` 🎉 Exciting news! Our paper has been accepted to ACM MM 2025! Stay tuned for more updates!
14 |
15 | ## 🚀 Code & Weights Notice
16 |
17 | - Layout Evaluator Checkpoints: [Download Link](https://drive.google.com/drive/folders/1evrHmorHW7CBLRhxrV3-3qvFki1ovoJ3?usp=drive_link)
18 |
19 | ### 🧪 Evaluation
20 |
21 | - **Script**: `evaluation.py`
22 |
23 | #### Requirements
24 | - Python >= 3.8 (Anaconda/Miniconda recommended)
25 | - PyTorch >= 2.3.1 with CUDA 11.8 (install from the official wheel index)
26 | - Extra dependencies in `requirements.txt`
27 |
28 | #### Setup
29 | ```bash
30 | conda create -n caig python==3.8.20 -y && conda activate caig
31 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
32 | pip install -r requirements.txt
33 | ```
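Optionally, a quick sanity check (our suggestion, not part of the original instructions) confirms the installed PyTorch version and whether CUDA is visible:

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```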
34 |
35 | #### Run
36 | ```bash
37 | python evaluation.py \
38 | --model_path /path/to/model \
39 | --input_data_path /path/to/input.json \
40 | --output_data_path /path/to/output.json
41 | ```
42 |
43 | #### Notes
44 | - Optional: `--model_base`, `--conv_mode`, generation args (`--temperature`, `--top_p`, `--num_beams`, `--max_new_tokens`, `--generate_nums`), and process args (`--save_interval`, `--batch_size`); see the fuller example below.
45 | - Input JSON follows the dataset format described below; the `image` field is optional.
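For reference, a fuller invocation that also sets the optional generation and process arguments might look like the following; the flag names match `evaluation.py`, while the paths and values are illustrative:

```bash
python evaluation.py \
    --model_path /path/to/model \
    --input_data_path /path/to/input.json \
    --output_data_path /path/to/output.json \
    --temperature 0.8 \
    --top_p 0.95 \
    --num_beams 1 \
    --max_new_tokens 512 \
    --generate_nums 1 \
    --save_interval 5000
```

Each record in the output JSON contains `question`, `label_answer`, `gpt_answer`, and, for image-text samples, `image` (see `process_single_item` in `evaluation.py`).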
46 |
47 | ## 📊 Datasets
48 | ### 1. Dataset for Layout Generator
49 | [Download Link](https://drive.google.com/drive/folders/1OLWRUZSiecpGuG2sUdQHOnmp46P9ojuD?usp=sharing).
50 |
51 | #### Key Fields
52 | - **`sku_id`**: Anonymized sample identifier.
53 | - **`image`**: Path to the image (optional; may be absent for text-only tasks).
54 | - **`conversations`**: List of two messages:
55 |   - **human**: Task description; may include the `<image>` tag, the canvas size, element types, and layout constraints.
56 |   - **gpt**: Layout result; `value` is a string of the form `Layout:{...}`, where bounding boxes are `[x_min, y_min, x_max, y_max]`. An example record is sketched below.
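For concreteness, a single generator record might look roughly like the sketch below; the field names follow the description above, the `from` keys assume the common LLaVA conversation convention, and all paths, IDs, prompt wording, and coordinates are illustrative placeholders rather than values from the released data:

```json
{
  "sku_id": "example_sku_0001",
  "image": "path/to/example.png",
  "conversations": [
    {
      "from": "human",
      "value": "<image>\nCanvas size: 512x512. Elements: logo, title. Constraints: ... Generate the layout."
    },
    {
      "from": "gpt",
      "value": "Layout:{\"logo\": [32, 24, 128, 88], \"title\": [40, 120, 472, 180]}"
    }
  ]
}
```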
57 |
58 | ### 2. Dataset for Layout Evaluator
59 | [Download Link](https://drive.google.com/drive/folders/1VASp90_mqSwJxJH65v5-iP9Sk3tgr23M?usp=drive_link).
60 |
61 | #### Key Fields
62 | - **`image`**: Path to the image.
63 | - **`conversations`**: Single-turn QA pair:
64 | - **human**: Evaluation instruction with candidate layout and constraints; expects a binary decision (0/1).
65 |   - **gpt**: The answer; `value` is the ground-truth label (0 or 1). An example record is sketched below.
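Under the same caveats (placeholder paths and prompt wording; only the field structure and the binary label follow the description above), a single evaluator record might look like:

```json
{
  "image": "path/to/example.png",
  "conversations": [
    {
      "from": "human",
      "value": "<image>\nCandidate layout: Layout:{\"title\": [40, 120, 472, 180]}. Constraints: ... Does this layout satisfy the constraints? Answer 0 or 1."
    },
    {
      "from": "gpt",
      "value": "1"
    }
  ]
}
```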
66 |
67 |
68 | # Copyright & Licensing
69 | © JD.COM. All rights reserved. The datasets and code provided in this repository are licensed exclusively for academic research purposes. Commercial use, reproduction, or distribution requires express written permission from JD.COM. Unauthorized commercial use constitutes a violation of these terms and is strictly prohibited.
70 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | LLaVA Model Evaluation Script
4 |
5 | This script performs inference using LLaVA models on multimodal data.
6 | It supports both image-text and text-only tasks with distributed processing.
7 |
8 | Author: [Your Name]
9 | Date: [Current Date]
10 | """
11 |
12 | import argparse
13 | import json
14 | import os
15 | import re
16 | import time
17 | from typing import Dict, List, Optional, Union
18 | from pathlib import Path
19 |
20 | import torch
21 | from PIL import Image
22 | from tqdm import tqdm
23 | from accelerate import Accelerator
24 | from accelerate.utils import gather_object
25 | from accelerate import InitProcessGroupKwargs
26 | import datetime
27 |
28 | # LLaVA imports
29 | from llava.constants import (
30 | IMAGE_TOKEN_INDEX,
31 | DEFAULT_IMAGE_TOKEN,
32 | DEFAULT_IM_START_TOKEN,
33 | DEFAULT_IM_END_TOKEN,
34 | IMAGE_PLACEHOLDER,
35 | )
36 | from llava.conversation import conv_templates
37 | from llava.model.builder import load_pretrained_model
38 | from llava.utils import disable_torch_init
39 | from llava.mm_utils import (
40 | process_images,
41 | tokenizer_image_token,
42 | get_model_name_from_path,
43 | )
44 |
45 |
46 | def load_image(image_file: str) -> Image.Image:
47 | """Load image from file path or URL."""
48 | if image_file.startswith(("http://", "https://")):
49 | import requests
50 | from io import BytesIO
51 | response = requests.get(image_file)
52 | image = Image.open(BytesIO(response.content)).convert("RGB")
53 | else:
54 | image = Image.open(image_file).convert("RGB")
55 | return image
56 |
57 |
58 | def process_prompt(
59 | caption: str,
60 | args,
61 | mm_use_im_start_end: bool,
62 | model_name: str,
63 | has_image: bool = True
64 | ) -> str:
65 | """Process prompt for model input with image token handling."""
66 | qs = caption
67 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
68 |
69 | if has_image:
70 | if IMAGE_PLACEHOLDER in qs:
71 | if mm_use_im_start_end:
72 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
73 | else:
74 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
75 | else:
76 | if mm_use_im_start_end:
77 | qs = image_token_se + "\n" + qs
78 | else:
79 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
80 | else:
81 | if IMAGE_PLACEHOLDER in qs:
82 | qs = qs.replace(IMAGE_PLACEHOLDER, "")
83 |
84 | # Determine conversation mode based on model name
85 | if "llama-2" in model_name.lower():
86 | conv_mode = "llava_llama_2"
87 | elif "mistral" in model_name.lower():
88 | conv_mode = "mistral_instruct"
89 | elif "v1.6-34b" in model_name.lower():
90 | conv_mode = "chatml_direct"
91 | elif "v1" in model_name.lower():
92 | conv_mode = "llava_v1"
93 | elif "mpt" in model_name.lower():
94 | conv_mode = "mpt"
95 | else:
96 | conv_mode = "llava_v0"
97 |
98 | if args.conv_mode is not None and conv_mode != args.conv_mode:
99 |         print(f"[WARNING] Auto-inferred conversation mode is {conv_mode}, "
100 |               f"while `--conv_mode` is {args.conv_mode}; using {args.conv_mode}")
101 | else:
102 | args.conv_mode = conv_mode
103 |
104 | conv = conv_templates[args.conv_mode].copy()
105 | conv.append_message(conv.roles[0], qs)
106 | conv.append_message(conv.roles[1], None)
107 | prompt = conv.get_prompt()
108 | return prompt
109 |
110 |
111 | def load_data(data_path: str) -> List[Dict]:
112 | """Load data from JSON file."""
113 | with open(data_path, 'r', encoding='utf-8') as f:
114 | data = json.load(f)
115 | return data
116 |
117 |
118 | def process_single_item(
119 | item: Dict,
120 | args,
121 | tokenizer,
122 | model,
123 | image_processor,
124 | mm_use_im_start_end: bool,
125 | model_name: str
126 | ) -> Dict:
127 | """Process a single data item for inference."""
128 | new_sample = {}
129 |
130 | if 'image' in item:
131 | # Process image-text task
132 |         inp = '\n'.join(item["conversations"][0]['value'].split('\n')[1:])  # drop the first line (typically the image placeholder); process_prompt re-adds the image token
133 | prompt = process_prompt(inp, args, mm_use_im_start_end, model_name, has_image=True)
134 | input_ids = tokenizer_image_token(
135 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
136 | ).unsqueeze(0).cuda()
137 |
138 | new_sample['question'] = prompt
139 | new_sample['label_answer'] = item["conversations"][1]['value']
140 | new_sample['image'] = item['image']
141 |
142 | # Load and process image
143 | try:
144 | images = [Image.open(item['image']).convert("RGB")]
145 | image_sizes = [x.size for x in images]
146 | images_tensor = process_images(
147 | images, image_processor, model.config
148 | ).to(model.device, dtype=torch.float16)
149 | except Exception as e:
150 | print(f"Error loading image {item['image']}: {e}")
151 | return None
152 |
153 | # Generate responses
154 | outputs = []
155 | for i in range(args.generate_nums):
156 | with torch.inference_mode():
157 | output_ids = model.generate(
158 | input_ids,
159 | images=images_tensor,
160 | image_sizes=image_sizes,
161 | do_sample=True if args.temperature > 0 else False,
162 | temperature=args.temperature,
163 | top_p=args.top_p,
164 | num_beams=args.num_beams,
165 | max_new_tokens=args.max_new_tokens,
166 | use_cache=True,
167 | )
168 |
169 | output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
170 |             outputs.append(output)
171 |             if len(output) >= 800:  # stop early once an output is already overly long
172 |                 break
173 |
174 | else:
175 | # Process text-only task
176 | inp = item["conversations"][0]['value']
177 | prompt = process_prompt(inp, args, mm_use_im_start_end, model_name, has_image=False)
178 | input_ids = tokenizer_image_token(
179 | prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
180 | ).unsqueeze(0).cuda()
181 |
182 | new_sample['question'] = prompt
183 | new_sample['label_answer'] = item["conversations"][1]['value']
184 |
185 | # Generate responses
186 | outputs = []
187 | for i in range(args.generate_nums):
188 | with torch.inference_mode():
189 | output_ids = model.generate(
190 | input_ids,
191 | images=None,
192 | image_sizes=None,
193 | do_sample=True if args.temperature > 0 else False,
194 | temperature=args.temperature,
195 | top_p=args.top_p,
196 | num_beams=args.num_beams,
197 | max_new_tokens=args.max_new_tokens,
198 | use_cache=True,
199 | )
200 |
201 | output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
202 |             outputs.append(output)
203 |             if len(output) >= 800:  # stop early once an output is already overly long
204 |                 break
205 |
206 | new_sample['gpt_answer'] = outputs if args.generate_nums > 1 else outputs[0]
207 | return new_sample
208 |
209 |
210 | def evaluate_model(args):
211 | """Main evaluation function."""
212 | disable_torch_init()
213 | process_group_kwargs = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=540000))
214 | accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
215 |
216 | if args.model_path == "None":
217 | args.model_path = args.model_base
218 | args.model_base = None
219 | print("Changed args.model_path to args.model_base")
220 |
221 | model_name = get_model_name_from_path(args.model_path)
222 | tokenizer, model, image_processor, context_len = load_pretrained_model(
223 | args.model_path, args.model_base, model_name, device=accelerator.process_index
224 | )
225 | mm_use_im_start_end = model.config.mm_use_im_start_end
226 |
227 | # Load data
228 | data = load_data(args.input_data_path)
229 | print(f"Total number of samples: {len(data)}")
230 |
231 | accelerator.wait_for_everyone()
232 | start = time.time()
233 |
234 | with accelerator.split_between_processes(data) as prompts:
235 | results = []
236 |
237 | for num, item in tqdm(enumerate(prompts), total=len(prompts), desc="Processing"):
238 | try:
239 | result = process_single_item(
240 | item, args, tokenizer, model, image_processor,
241 | mm_use_im_start_end, model_name
242 | )
243 | if result is not None:
244 | results.append(result)
245 | except Exception as e:
246 | print(f"Error processing item {num}: {e}")
247 | continue
248 |
249 |             # Save intermediate results periodically
250 |             if num % args.save_interval == 0 and num != 0:
251 |                 results_gathered = gather_object(results)  # collective call: must run on every process
252 |                 if accelerator.is_main_process:
253 |                     formatted_data = json.dumps(results_gathered, indent=0, ensure_ascii=False)
254 |                     timediff = time.time() - start
255 |                     output_file_path = f"{args.output_data_path}_checkpoint_{num}.json"
256 |                     with open(output_file_path, 'w', encoding='utf-8') as file:
257 |                         file.write(formatted_data)
258 |                     print(f"Checkpoint saved at {num}, time elapsed: {timediff:.2f}s")
259 |                     start = time.time()
260 |
261 | # Save final results
262 | results_gathered = gather_object(results)
263 | formatted_data = json.dumps(results_gathered, indent=0, ensure_ascii=False)
264 |
265 | if accelerator.is_main_process:
266 | timediff = time.time() - start
267 | with open(args.output_data_path, 'w', encoding='utf-8') as file:
268 | file.write(formatted_data)
269 | print(f"Final results saved, total time elapsed: {timediff:.2f}s")
270 | print(f"Processed {len(results_gathered)} samples successfully")
271 |
272 |
273 | def main():
274 | """Main entry point."""
275 | parser = argparse.ArgumentParser(description="LLaVA Model Evaluation")
276 |
277 | # Model arguments
278 | parser.add_argument("--model_path", type=str, required=True,
279 | help="Path to the LLaVA model")
280 | parser.add_argument("--model_base", type=str, default=None,
281 | help="Base model path (for delta weights)")
282 | parser.add_argument("--conv_mode", type=str, default=None,
283 | help="Conversation mode (auto-detected if not specified)")
284 |
285 | # Data arguments
286 | parser.add_argument("--input_data_path", type=str, required=True,
287 | help="Path to input JSON data file")
288 | parser.add_argument("--output_data_path", type=str, required=True,
289 | help="Path to output JSON results file")
290 |
291 | # Generation arguments
292 | parser.add_argument("--temperature", type=float, default=0.8,
293 | help="Sampling temperature")
294 | parser.add_argument("--top_p", type=float, default=None,
295 | help="Top-p sampling parameter")
296 | parser.add_argument("--num_beams", type=int, default=1,
297 | help="Number of beams for generation")
298 | parser.add_argument("--max_new_tokens", type=int, default=512,
299 | help="Maximum number of new tokens to generate")
300 | parser.add_argument("--generate_nums", type=int, default=1,
301 | help="Number of generations per sample")
302 |
303 | # Processing arguments
304 | parser.add_argument("--save_interval", type=int, default=5000,
305 | help="Save intermediate results every N samples")
306 | parser.add_argument("--batch_size", type=int, default=None,
307 | help="Batch size for processing")
308 |
309 | # Optional arguments
310 | parser.add_argument("--image_file", type=str, help="Single image file for testing")
311 | parser.add_argument("--query", type=str, help="Single query for testing")
312 | parser.add_argument("--sep", type=str, default=",", help="Separator for multiple images")
313 | parser.add_argument("--debug", action="store_true", default=False,
314 | help="Enable debug mode")
315 |
316 | args = parser.parse_args()
317 |
318 | if args.debug:
319 | args.temperature = 0
320 | args.generate_nums = 1
321 | print("Debug mode enabled")
322 |
323 | evaluate_model(args)
324 |
325 |
326 | if __name__ == "__main__":
327 | main()
328 |
--------------------------------------------------------------------------------