├── .gitattributes ├── CLIS ├── eval_image.py └── eval_layout.py ├── LICENSE ├── README.md ├── config ├── bbox.json ├── create_dataset.json └── model_config.json ├── infer_image.py ├── inference_single_data.py ├── inputs └── demo.json ├── llm.py ├── prompt.py ├── requirements.txt └── utils ├── parse.py ├── utils.py └── visualize.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /CLIS/eval_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from prompt import get_prompt 5 | from llm import infer 6 | from utils.parse import parse_score 7 | from utils.utils import create_dir, get_attribute, scale_bbox 8 | 9 | from PIL import Image 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def vlm_caption( 14 | img_path: str, 15 | task: str = 'vlm_global_describe', 16 | bbox: list = [], 17 | vlm: AutoModelForCausalLM = None, 18 | vlm_tokenizer: AutoTokenizer = None 19 | ): 20 | 21 | prompt = get_prompt(task=task) 22 | if 'local' in task: 23 | # img = Image.open(img_path).convert("RGB") 24 | width, height = 1000, 1000 25 | xmin = int(bbox[0] * width) - 1 26 | ymin = int(bbox[1] * height) - 1 27 | xmax = int((bbox[0] + bbox[2]) * width) - 1 28 | ymax = int((bbox[1] + bbox[3]) * height) - 1 29 | prompt = prompt.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) 30 | 31 | query = vlm_tokenizer.from_list_format([ 32 | {"image": img_path}, 33 | {"text": prompt}, 34 | ]) 35 | 36 | response, history = vlm.chat(vlm_tokenizer, query=query, history=None) 37 | 38 | return response 39 | 40 | 41 | def llm_align( 42 | text: str, 43 | pred: str, 44 | model: str = 'qwen-1.5-14b', 45 | llm: AutoModelForCausalLM = None, 46 | llm_tokenizer: AutoTokenizer = None 47 | ): 48 | ''' 49 | Function: 50 | use llm to score 51 | ''' 52 | 53 | # init_llm_model() 54 | 55 | prompt = get_prompt(task='llm_align').format(answer=text, pred=pred) 56 | response = infer(model=model, prompt=prompt, h_model=llm, h_tokenizer=llm_tokenizer) 57 | score = parse_score(response) 58 | 59 | return score 60 | 61 | 62 | def clis_img( 63 | img_path: str, 64 | text: str, 65 | local_list: list, 66 | crop_path: str = 'crop/0.png', 67 | local_weights: list = [0.8, 0.2], 68 | weights: list = [0.5, 0.5], 69 | vlm: AutoModelForCausalLM = None, 70 | vlm_tokenizer: AutoTokenizer = None, 71 | llm: AutoModelForCausalLM = None, 72 | llm_tokenizer: AutoTokenizer = None 73 | ): 74 | 75 | global_pred = vlm_caption(img_path, task='vlm_global_describe', vlm=vlm, vlm_tokenizer=vlm_tokenizer) 76 | global_score = llm_align(text, global_pred, llm=llm, llm_tokenizer=llm_tokenizer) 77 | if global_score == -1: 78 | return -1, -1, [] 79 | 80 | local_score_list = [] 81 | for local_item in local_list: 82 | 83 | try: 84 | local_pred = vlm_caption(img_path, task='vlm_local_describe', bbox=local_item['bbox'], vlm=vlm, vlm_tokenizer=vlm_tokenizer) 85 | local_score = llm_align(local_item['text'], local_pred, llm=llm, llm_tokenizer=llm_tokenizer) 86 | if local_score == -1: 87 | return -1, -1, [] 88 | except: 89 | local_score_list.append(0) 90 | continue 91 | 92 | # crop score 93 | with Image.open(img_path) as img: 94 | 95 | try: 96 | crop_img = img.crop(scale_bbox(local_item['bbox'])) 97 | create_dir(crop_path) 98 | crop_img.save(crop_path) 99 | except: 100 | 
local_score_list.append(0) 101 | continue 102 | 103 | local_crop_pred = vlm_caption(crop_path, task='vlm_global_describe', vlm=vlm, vlm_tokenizer=vlm_tokenizer) 104 | local_crop_score = llm_align(local_item['text'], local_crop_pred, llm=llm, llm_tokenizer=llm_tokenizer) 105 | if local_crop_score == -1: 106 | return -1, -1, [] 107 | 108 | local_score_list.append(local_score * local_weights[0] + local_crop_score * local_weights[1]) 109 | 110 | score = global_score * weights[0] + np.mean(local_score_list) * weights[1] 111 | 112 | return score, global_score, local_score_list 113 | 114 | 115 | def eval_image( 116 | syn_data: dict, 117 | vlm: AutoModelForCausalLM = None, 118 | vlm_tokenizer: AutoTokenizer = None, 119 | llm: AutoModelForCausalLM = None, 120 | llm_tokenizer: AutoTokenizer = None 121 | ): 122 | 123 | # initialize task 124 | suffix = 'gc5-seed0-alpha0.8/' 125 | img_dir = f"{syn_data['img_dir']}{suffix}" 126 | 127 | score_list = [] 128 | 129 | for root, dirs, files in os.walk(img_dir): 130 | for file in files: 131 | 132 | if 'xl' not in file: 133 | continue 134 | 135 | # get image 136 | img_path = os.path.join(root, file) 137 | img = Image.open(img_path).convert("RGB") 138 | 139 | text = syn_data['caption'] 140 | 141 | local_list = [] 142 | for i in syn_data['layout']: 143 | if type(i['bbox']) == list and len(i['bbox']) == 4: 144 | local_list.append({ 145 | "bbox": i['bbox'], 146 | "text": get_attribute(i['object'], syn_data) 147 | }) 148 | 149 | # calculate score 150 | score, global_score, local_score_list = clis_img(img_path, text, local_list, crop_path=f"crop/0.png", vlm=vlm, vlm_tokenizer=vlm_tokenizer, llm=llm, llm_tokenizer=llm_tokenizer) 151 | 152 | score_list.append({ 153 | "score": score, 154 | "file_name": img_path, 155 | "img_path": img_path, 156 | "global_score": global_score, 157 | "local_score_list": local_score_list, 158 | }) 159 | 160 | if len(score_list) > 0: 161 | 162 | # update syn_data 163 | max_score_item = max(score_list, key=lambda x: x['score']) 164 | syn_data['file_name'] = max_score_item['file_name'] 165 | syn_data['img_path'] = max_score_item['img_path'] 166 | syn_data['score'] = max_score_item['score'] 167 | syn_data['global_score'] = max_score_item['global_score'] 168 | syn_data['local_score_list'] = max_score_item['local_score_list'] 169 | 170 | return syn_data, score_list 171 | -------------------------------------------------------------------------------- /CLIS/eval_layout.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import argparse 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from typing import Callable 9 | 10 | 11 | def get_pred_layout(pred_item: dict, subject: str, object: str): 12 | 13 | s_bbox = -1 14 | o_bbox = -1 15 | 16 | for i in pred_item['layout']: 17 | if i['object'].split('-')[-1] == subject.split('-')[-1]: 18 | s_bbox = i['bbox'] 19 | if i['object'].split('-')[-1] == object.split('-')[-1]: 20 | o_bbox = i['bbox'] 21 | 22 | return s_bbox, o_bbox 23 | 24 | 25 | ''' 26 | ============================== Penalty Functions ============================== 27 | ''' 28 | 29 | def non_penalty(score: float): 30 | return score 31 | 32 | 33 | def linear_penalty(score: float, penalty_threshold: float = 0.1): 34 | return -(1 + penalty_threshold) * (1 - (score / penalty_threshold)) + penalty_threshold if score < penalty_threshold else score 35 | 36 | 37 | ''' 38 | ============================== Tool Functions ============================== 39 | ''' 40 | 
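# Illustrative note on the tool functions below (a minimal sketch, not part of the evaluation logic):
# all helpers assume boxes in [x, y, w, h] format, i.e. top-left corner plus width and height.
# Worked example for the IoU helper defined below:
#   bbox1 = [0, 0, 2, 2]; bbox2 = [1, 1, 2, 2]
#   intersection = 1 * 1 = 1, union = 4 + 4 - 1 = 7, so cal_iou(bbox1, bbox2) = 1 / 7 ≈ 0.143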
41 | def cal_area(bbox: list): 42 | ''' 43 | Function: 44 | caluate the area of bbox 45 | 46 | Args: 47 | bbox: [x, y, w, h] 48 | ''' 49 | 50 | return bbox[2] * bbox[3] 51 | 52 | 53 | def cal_area_ratio(bbox1: list, bbox2: list): 54 | ''' 55 | Function: 56 | caluate the area ratio of bbox1 and bbox2 57 | ''' 58 | 59 | return cal_area(bbox1) / cal_area(bbox2) 60 | 61 | 62 | def cal_iou(bbox1: list, bbox2: list): 63 | ''' 64 | Function: 65 | calculate the iou of bbox1 and bbox2 66 | 67 | Args: 68 | bbox: [x, y, w, h] 69 | ''' 70 | 71 | x1 = max(bbox1[0], bbox2[0]) 72 | y1 = max(bbox1[1], bbox2[1]) 73 | x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) 74 | y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) 75 | 76 | # calculate inter area 77 | inter_area = max(0, x2 - x1) * max(0, y2 - y1) 78 | 79 | # calculate area of bbox 80 | area1 = cal_area(bbox1) 81 | area2 = cal_area(bbox2) 82 | 83 | # calculate union area 84 | union_area = area1 + area2 - inter_area 85 | 86 | # calculate iou 87 | iou = inter_area / union_area 88 | 89 | return iou 90 | 91 | 92 | def cal_diagonal(bbox: list): 93 | ''' 94 | Function: 95 | calculate the diagonal of bbox 96 | ''' 97 | 98 | return math.sqrt(bbox[2] ** 2 + bbox[3] ** 2) 99 | 100 | 101 | def cal_center(bbox: list): 102 | ''' 103 | Function: 104 | calculate the center of bbox 105 | ''' 106 | 107 | return bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2 108 | 109 | 110 | def cal_dist(x1: float, y1: float, x2: float, y2: float): 111 | return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) 112 | 113 | 114 | def cal_rel_dist(bbox1: list, bbox2: list): 115 | ''' 116 | Function: 117 | calculate the relative distance of bbox1 and bbox2 118 | ''' 119 | 120 | d1 = cal_diagonal(bbox1) 121 | d2 = cal_diagonal(bbox2) 122 | d_avg = (d1 + d2) / 2 123 | 124 | c1 = cal_center(bbox1) 125 | c2 = cal_center(bbox2) 126 | 127 | d = cal_dist(c1[0], c1[1], c2[0], c2[1]) 128 | 129 | d_norm = d / d_avg 130 | 131 | return d_norm 132 | 133 | 134 | def cal_dir(bbox1: list, bbox2: list): 135 | ''' 136 | Function: 137 | calculate the direction vector 138 | ''' 139 | 140 | c1 = cal_center(bbox1) 141 | c2 = cal_center(bbox2) 142 | 143 | return np.array([c2[0] - c1[0], c2[1] - c1[1]]) 144 | 145 | 146 | def cal_sim(A: float, B: float): 147 | ''' 148 | Function: 149 | calculate the similarity between A and B 150 | ''' 151 | 152 | # judge 153 | if A + B < 1e-10: 154 | return 1.0 155 | 156 | # norm 157 | A_norm = A / (A + B) 158 | B_norm = B / (A + B) 159 | 160 | # calculate relative error 161 | relative_error = abs(A_norm - B_norm) / max(A_norm, B_norm) 162 | 163 | # calculate sim 164 | sim = 1 - relative_error 165 | 166 | return sim 167 | 168 | 169 | def update_weight(init_weight: float): 170 | ''' 171 | Function: 172 | update weight 173 | ''' 174 | 175 | if init_weight >= 0.9: 176 | weight = init_weight 177 | elif init_weight < 0.9 and init_weight >= 0.8: 178 | weight = init_weight * init_weight 179 | elif init_weight < 0.8 and init_weight >= 0.7: 180 | weight = init_weight * init_weight * init_weight 181 | elif init_weight < 0.7 and init_weight >= 0.6: 182 | weight = init_weight * init_weight * init_weight * init_weight 183 | else: 184 | weight = init_weight * init_weight * init_weight * init_weight * init_weight 185 | 186 | return weight 187 | 188 | 189 | def avg_score_by_conf(score_list, score_size_list, score_dist_list, score_dir_list, conf_list): 190 | ''' 191 | Function: 192 | average score by conf 193 | ''' 194 | 195 | conf_sum = sum(conf_list) 196 | 197 | if conf_sum == 0: 198 | return 0, 0, 
0, 0 199 | 200 | conf_list = [i / conf_sum for i in conf_list] 201 | 202 | for i in range(len(conf_list)): 203 | score_list[i] *= conf_list[i] 204 | score_size_list[i] *= conf_list[i] 205 | score_dist_list[i] *= conf_list[i] 206 | score_dir_list[i] *= conf_list[i] 207 | 208 | return np.sum(score_list), np.sum(score_size_list), np.sum(score_dist_list), np.sum(score_dir_list) 209 | 210 | 211 | def judge_dir_symmetry(relation: str): 212 | ''' 213 | Function: 214 | judge dir_symmetry according to relation 215 | ''' 216 | 217 | if 'left' in relation.lower() or 'right' in relation.lower(): 218 | return False 219 | 220 | return True 221 | 222 | 223 | def get_eval_example_list_s_o(eval_example_list: list, subject: str, object: str, sim_map: dict, sim_threshold: float = 0.8): 224 | ''' 225 | Function: 226 | get eval_example_list_s_o from eval_example_list 227 | 228 | Args: 229 | eval_example_list = [{'img_id', 'img_path', 's_bbox', 'o_bbox', 'subject', 'object'}] 230 | ''' 231 | 232 | eval_example_list_s_o = [] 233 | 234 | for i in eval_example_list: 235 | 236 | try: 237 | s_sim = sim_map[convert_name(subject)][i['subject'].lower()] 238 | o_sim = sim_map[convert_name(object)][i['object'].lower()] 239 | except: 240 | s_sim = 0 241 | o_sim = 0 242 | 243 | if s_sim >= sim_threshold and o_sim >= sim_threshold: 244 | 245 | eval_example_list_s_o.append({ 246 | "img_id": i['info']['img_id'], 247 | 'img_path': i['info']['img_path'], 248 | 'caption': i['info']['caption'], 249 | 's_bbox': i['s_bbox'], 250 | 'o_bbox': i['o_bbox'], 251 | 'subject': i['subject'], 252 | 'object': i['object'], 253 | 's_sim': s_sim, 254 | 'o_sim': o_sim 255 | }) 256 | 257 | return eval_example_list_s_o 258 | 259 | 260 | def convert_name(name: str): 261 | ''' 262 | Function: 263 | convert name 264 | ''' 265 | 266 | index = name.split('-')[-1] 267 | name = name.replace(f'-{index}', '') 268 | name = name.replace('_', ' ') 269 | 270 | if '(' in name: 271 | index = name.find('(') 272 | if name[index - 1] != ' ': 273 | name = name.replace('(', ' (') 274 | 275 | return name 276 | 277 | 278 | ''' 279 | ============================== Eval Functions ============================== 280 | ''' 281 | 282 | def eval_size(pred_s_bbox: list, pred_o_bbox: list, example_s_bbox: list, example_o_bbox: list, weight: float = 1.0, penalty_threshold: float = 0.1, penalty_func: Callable = linear_penalty): 283 | ''' 284 | Function: 285 | Evaluate size similarity 286 | ''' 287 | 288 | # calculate area ratio 289 | A = cal_area_ratio(pred_s_bbox, pred_o_bbox) 290 | B = cal_area_ratio(example_s_bbox, example_o_bbox) 291 | 292 | # calculate sim 293 | score = cal_sim(A, B) * weight 294 | 295 | # penalty 296 | score = penalty_func(score, penalty_threshold) 297 | 298 | return score 299 | 300 | 301 | def eval_dist(pred_s_bbox: list, pred_o_bbox: list, example_s_bbox: list, example_o_bbox: list, weight: float = 1.0, penalty_threshold: float = 0.1, penalty_func: Callable = linear_penalty, use_balance: bool = False): 302 | ''' 303 | Function: 304 | Evaluate dist similarity 305 | ''' 306 | 307 | # calculate iou 308 | iou1 = cal_iou(pred_s_bbox, pred_o_bbox) 309 | iou2 = cal_iou(example_s_bbox, example_o_bbox) 310 | 311 | # calculate relative distance 312 | rel_dist1 = cal_rel_dist(pred_s_bbox, pred_o_bbox) 313 | rel_dist2 = cal_rel_dist(example_s_bbox, example_o_bbox) 314 | 315 | # calculate sim 316 | score_iou = cal_sim(iou1, iou2) * weight 317 | score_rel_dist = cal_sim(rel_dist1, rel_dist2) * weight 318 | 319 | # penalty 320 | score_iou = penalty_func(score_iou, 
penalty_threshold) 321 | score_rel_dist = penalty_func(score_rel_dist, penalty_threshold) 322 | 323 | # assign weights 324 | weights = [0.5, 0.5] 325 | if use_balance: 326 | if score_iou > 0.95: 327 | weights = [0.2, 0.8] 328 | elif score_rel_dist > 0.95: 329 | weights = [0.8, 0.2] 330 | 331 | score = score_iou * weights[0] + score_rel_dist * weights[1] 332 | 333 | return score 334 | 335 | 336 | def eval_dir(pred_s_bbox: list, pred_o_bbox: list, example_s_bbox: list, example_o_bbox: list, mode: str = 'oppo', weight: float = 1.0, penalty_threshold: float = 0.1, penalty_func: Callable = linear_penalty, use_symmetry: bool = False): 337 | ''' 338 | Function: 339 | Evaluate direction similarity 340 | 341 | Args: 342 | mode: str, ['perp', 'oppo']. 'perp': score lower when reaching 90 degrees. 'oppo': score lower when reaching 180 degrees. 343 | ''' 344 | 345 | # calculate dir vector 346 | dir1 = cal_dir(pred_s_bbox, pred_o_bbox) 347 | dir2 = cal_dir(example_s_bbox, example_o_bbox) 348 | if use_symmetry: 349 | dir3 = np.array([-dir2[0], dir2[1]]) 350 | 351 | # process 0 vector 352 | if np.all(dir1 == 0) and np.all(dir2 == 0): 353 | return 1 354 | elif np.all(dir1 == 0) or np.all(dir2 == 0): 355 | return 0 356 | 357 | # calculate cos 358 | cos = np.dot(dir1, dir2) / (np.linalg.norm(dir1) * np.linalg.norm(dir2)) 359 | if use_symmetry: 360 | cos_symmetry = np.dot(dir1, dir3) / (np.linalg.norm(dir1) * np.linalg.norm(dir3)) 361 | cos = max(cos, cos_symmetry) 362 | 363 | # convert 364 | cos = abs(cos) if mode == 'perp' else (cos + 1) / 2 365 | cos *= weight 366 | 367 | cos = penalty_func(cos, penalty_threshold) 368 | 369 | return cos 370 | 371 | 372 | def rule_eval_single_item(pred_s_bbox: list, pred_o_bbox: list, example_list: list, weights: list, size_penalty_threshold: float = 0.03, dist_penalty_threshold: float = 0.03, dir_penalty_threshold: float = 0.03, use_balance: bool = True, use_symmetry: bool = True): 373 | ''' 374 | Function: 375 | eval single item according to rules designed 376 | ''' 377 | 378 | # initialize the task 379 | ret = [] 380 | index_len = 0 381 | conf = 0.0 382 | 383 | # evaluate 384 | for example in example_list: 385 | 386 | # update weight 387 | s_weight = update_weight(example['s_sim']) 388 | o_weight = update_weight(example['o_sim']) 389 | weight = s_weight * o_weight 390 | 391 | # update conf 392 | if example['s_sim'] >= 0.8 and example['o_sim'] >= 0.8: 393 | conf += weight 394 | index_len += 1 395 | else: 396 | conf += (weight * 0.0001) 397 | 398 | # calculate score 399 | score_size = eval_size(pred_s_bbox, pred_o_bbox, example['s_bbox'], example['o_bbox'], weight=weight, penalty_threshold=size_penalty_threshold) 400 | score_dist = eval_dist(pred_s_bbox, pred_o_bbox, example['s_bbox'], example['o_bbox'], weight=weight, penalty_threshold=dist_penalty_threshold, use_balance=use_balance) 401 | score_dir = eval_dir(pred_s_bbox, pred_o_bbox, example['s_bbox'], example['o_bbox'], weight=weight, penalty_threshold=dir_penalty_threshold, use_symmetry=use_symmetry) 402 | 403 | score = (score_size * weights[0] + score_dist * weights[1] + score_dir * weights[2]) / sum(weights) 404 | 405 | # update ret 406 | ret.append({ 407 | "img_id": example['img_id'], 408 | "img_path": example['img_path'], 409 | "pred_s_bbox": pred_s_bbox, 410 | "pred_o_bbox": pred_o_bbox, 411 | "example_s_bbox": example['s_bbox'], 412 | "example_o_bbox": example['o_bbox'], 413 | "score_size": score_size, 414 | "score_dist": score_dist, 415 | "score_dir": score_dir, 416 | "score": score, 417 | "sim": weight, 418 | }) 419
| 420 | # order 421 | ret.sort(key=lambda x: x['score'], reverse=True) 422 | 423 | # calculate index 424 | index = int(index_len / 50) 425 | 426 | return ret, conf, ret[index] 427 | 428 | 429 | def rule_eval_offline(pred_path: str, example_path: str = 'config/eval/relations_one_to_one.json', weights: list = [0.5, 1.0, 0.8], use_balance: bool = True, use_symmetry: bool = True, sim_threshold: float = 0.8, sim_map_path: str = 'config/eval/sim_map.json'): 430 | ''' 431 | Function: 432 | eval layout according to rules designed 433 | 434 | Args: 435 | pred_path: predictions from LLMs 436 | example_path: examples to refer to in evaluation 437 | 438 | ''' 439 | 440 | # get pred list 441 | with open(pred_path, 'r') as f: 442 | pred_list = json.load(f) 443 | 444 | # get example_list 445 | with open(example_path, 'r') as f: 446 | example_list = json.load(f) 447 | 448 | # get sim_map 449 | with open(sim_map_path, 'r') as f: 450 | sim_map = json.load(f) 451 | 452 | # initialize the task 453 | eval_fail_cnt = 0 # record the number of eval fail 454 | pred_fail_cnt = 0 # record the number of pred fail 455 | ret_list = [] 456 | score_list = [] 457 | score_size_list = [] 458 | score_dist_list = [] 459 | score_dir_list = [] 460 | conf_list = [] 461 | 462 | # evaluate 463 | for i in tqdm(range(len(pred_list)), desc="Evaluating"): 464 | 465 | # get item from list 466 | pred_item = pred_list[i] 467 | 468 | # get relations 469 | relation_list = pred_item['relations'] 470 | 471 | # initialize the task 472 | eval_fail = False 473 | pred_fail = False 474 | score_relation_list = [] 475 | score_size_relation_list = [] 476 | score_dist_relation_list = [] 477 | score_dir_relation_list = [] 478 | conf_relation_list = [] 479 | 480 | # evaluate relations 481 | for relation_item_list in relation_list: 482 | 483 | # check eval_fail and pred_fail 484 | if eval_fail or pred_fail: 485 | break 486 | 487 | for relation_item in relation_item_list['relations']: 488 | 489 | # check eval_fail and pred_fail 490 | if eval_fail or pred_fail: 491 | break 492 | 493 | try: 494 | # get info from relation_item 495 | relation = relation_item['relation'] 496 | subject_list = relation_item['subject'] 497 | object_list = relation_item['object'] 498 | 499 | assert type(subject_list) == list and type(object_list) == list 500 | 501 | # get example_list to evaluate 502 | eval_example_list = example_list[relation] 503 | print(len(eval_example_list)) 504 | except: 505 | eval_fail_cnt += 1 506 | eval_fail = True 507 | break 508 | 509 | # set dir_symmetry 510 | dir_symmetry = judge_dir_symmetry(relation) 511 | 512 | # evaluate subject - object pair 513 | for s in subject_list: 514 | 515 | # check eval_fail and pred_fail 516 | if eval_fail or pred_fail: 517 | break 518 | 519 | for o in object_list: 520 | 521 | # get example_list according to subject and object 522 | eval_example_list_s_o = get_eval_example_list_s_o(eval_example_list, s, o, sim_map, sim_threshold=sim_threshold) 523 | 524 | # check len 525 | if len(eval_example_list_s_o) == 0: 526 | eval_fail_cnt += 1 527 | eval_fail = True 528 | break 529 | 530 | # get pred layout 531 | pred_s_bbox, pred_o_bbox = get_pred_layout(pred_item, s, o) 532 | 533 | # check pred layout 534 | if pred_s_bbox == -1 or pred_o_bbox == -1 or len(pred_s_bbox) != 4 or len(pred_o_bbox) != 4: 535 | pred_fail_cnt += 1 536 | pred_fail = True 537 | break 538 | 539 | # evaluate 540 | _, conf, ret = rule_eval_single_item(pred_s_bbox, pred_o_bbox, eval_example_list_s_o, weights=weights, use_balance=use_balance, 
use_symmetry=use_symmetry) 541 | 542 | # update ret_list 543 | ret['caption'] = pred_item['caption'] 544 | ret_list.append(ret) 545 | 546 | # update score_list, conf_list 547 | score_relation_list.append(ret['score']) 548 | score_size_relation_list.append(ret['score_size']) 549 | score_dist_relation_list.append(ret['score_dist']) 550 | score_dir_relation_list.append(ret['score_dir']) 551 | conf_relation_list.append(conf) 552 | 553 | if eval_fail or pred_fail or len(score_relation_list) == 0: 554 | score_list.append(0) 555 | score_size_list.append(0) 556 | score_dist_list.append(0) 557 | score_dir_list.append(0) 558 | conf_list.append(0) 559 | else: 560 | score, score_size, score_dist, score_dir = avg_score_by_conf(score_relation_list, score_size_relation_list, score_dist_relation_list, score_dir_relation_list, conf_relation_list) 561 | score_list.append(score) 562 | score_size_list.append(score_size) 563 | score_dist_list.append(score_dist) 564 | score_dir_list.append(score_dir) 565 | conf_list.append(min(conf_relation_list)) 566 | 567 | print("pred_fail_cnt: ", pred_fail_cnt) 568 | print("eval_fail_cnt: ", eval_fail_cnt) 569 | 570 | save_rec = { 571 | "ret_list": ret_list, 572 | "score_list": score_list, 573 | "score_size_list": score_size_list, 574 | "score_dist_list": score_dist_list, 575 | "score_dir_list": score_dir_list, 576 | "pred_fail_cnt": pred_fail_cnt, 577 | "eval_fail_cnt": eval_fail_cnt, 578 | "conf_list": conf_list, 579 | } 580 | 581 | file_name = os.path.basename(pred_path) 582 | folder_name = os.path.dirname(pred_path) 583 | 584 | tar_dir = os.path.join(folder_name, 'eval/') 585 | os.makedirs(tar_dir, exist_ok=True) 586 | 587 | prefix = f"{sim_threshold}_" 588 | if use_balance: 589 | prefix += 'use_balance_' 590 | if use_symmetry: 591 | prefix += 'use_symmetry_' 592 | file_name = prefix + file_name 593 | 594 | # print(tar_dir) 595 | # print(file_name) 596 | 597 | with open(os.path.join(tar_dir, file_name), 'w') as f: 598 | json.dump(save_rec, f) 599 | 600 | return score_list, score_size_list, score_dist_list, score_dir_list, conf_list, ret_list 601 | 602 | 603 | if __name__ == '__main__': 604 | 605 | parser = argparse.ArgumentParser() 606 | 607 | parser.add_argument('--pred_path', type=str, default='lvis/debug.json') 608 | parser.add_argument('--example_path', type=str, default='flickr/parsed/combine/relations_one_to_one.json') 609 | parser.add_argument('--use_balance', action='store_true') 610 | parser.add_argument('--use_symmetry', action='store_true') 611 | parser.add_argument('--sim_threshold', type=float, default=0.6) 612 | 613 | args = parser.parse_args() 614 | 615 | score_list, score_size_list, score_dist_list, score_dir_list, conf_list, ret_list = rule_eval_offline( 616 | pred_path=args.pred_path, 617 | example_path=args.example_path, 618 | use_balance=args.use_balance, 619 | use_symmetry=args.use_symmetry, 620 | sim_threshold=args.sim_threshold 621 | ) 622 | 623 | score, score_size, score_dist, score_dir = avg_score_by_conf(score_list, score_size_list, score_dist_list, score_dir_list, conf_list) 624 | print(score) 625 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auto Cherry-Picker: Learning from High-quality Generative Data Driven by Language 2 | Yicheng Chen, Xiangtai Li, Yining Li, Yanhong Zeng, Jianzong Wu, Xiangyu Zhao, Kai Chen 3 | 4 | ## Updates 5 | * **[2024/06]** Our paper [Auto Cherry-Picker: Learning from High-quality Generative Data Driven by Language](https://arxiv.org/pdf/2406.20085) is released. 6 | 7 | ## Introduction 8 | Auto Cherry-Picker is an innovative framework designed to synthesize training samples for both perception and multi-modal reasoning tasks from a simple object list in natural language. It employs a newly designed metric, CLIS, to ensure the quality of the synthetic data. 9 | 10 | ## Main Results 11 | 12 | ### Long-tailed Instance Segmentation Benchmark 13 | | Method | Backbone | $AP_r^{mask}$ | $AP^{mask}$ | 14 | | ---- | ---- | ---- | ---- | 15 | | Mask R-CNN | ResNet-50 | 9.3 | 21.7 | 16 | | Mask R-CNN w. ACP | ResNet-50 | 14.5(+5.2) | 22.8(+1.1)| 17 | | CenterNet2 w. Copy-Paste | Swin-B | 29.3 | 39.3 | 18 | | CenterNet2 w. ACP | Swin-B | 30.7(+1.4) | 39.6(+0.3)| 19 | 20 | ### Open-vocabulary Object Detection Benchmark 21 | | Dataset | Method | Backbone | $AP_{novel}^{box}$ | $AP^{box}$ | 22 | | ---- | ---- | ---- | ---- | ---- | 23 | | LVIS | Grounding-DINO | Swin-T | 31.7 | 48.7 | 24 | | LVIS | Grounding-DINO w. ACP | Swin-T | 33.0(+1.3) | 49.2 | 25 | | COCO | Grounding-DINO | Swin-T | 60.4 | 57.1 | 26 | | COCO | Grounding-DINO w. ACP | Swin-T | 60.8(+0.4) | 56.9 | 27 | 28 | ### Multi-modal Image-based Benchmarks 29 | | Method | LLM Backbone | MME | GQA | 30 | | ---- | ---- | ---- | ---- | 31 | | LLaVA-1.5 | Vicuna-7B | 1434.4 | 58.9 | 32 | | LLaVA-1.5 | Vicuna-13B | 1438.3 | 60.7 | 33 | | LLaVA-1.5 | Llama-3-8B | 1445.3 | 60.1 | 34 | | LLaVA-1.5 w. 
ACP | Vicuna-7B | 1514.5(+80.1) | 59.3(+0.4) | 35 | 36 | ## Installation 37 | 38 | ### Requirements 39 | Python 3.10 40 | 41 | PyTorch 2.3.0 42 | 43 | ### Conda Environment Setup 44 | ``` 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | ### Prepare Scene Graph Generator 49 | Download Qwen1.5-14B-Chat 50 | ``` 51 | git clone https://huggingface.co/Qwen/Qwen1.5-14B-Chat 52 | ``` 53 | You can try other LLMs as the Scene Graph Generator by adding them to `config/model_config.json`. 54 | 55 | 56 | 57 | ### Prepare Image Generator 58 | * Step 1: Download InstanceDiffusion 59 | 60 | ``` 61 | git clone https://github.com/frank-xwang/InstanceDiffusion.git 62 | ``` 63 | * Step 2: Download model weights 64 | 65 | Please download the pretrained InstanceDiffusion from [Hugging Face](https://huggingface.co/xudongw/InstanceDiffusion/tree/main) or [Google Drive](https://drive.google.com/drive/folders/1Jm3bsBmq5sHBnaN5DemRUqNR0d4cVzqG?usp=sharing) and [SD1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt), and place them under the `InstanceDiffusion/pretrained` folder. 66 | 67 | Then, create a soft link under the `ACP` folder. 68 | ``` 69 | ln -s InstanceDiffusion/pretrained ./pretrained 70 | ``` 71 | * Step 3: Download CLIP 72 | ``` 73 | git clone https://huggingface.co/openai/clip-vit-large-patch14 74 | ``` 75 | 76 | * Step 4: Download SDXL Refiner (Optional) 77 | ``` 78 | git clone https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0 79 | ``` 80 | To disable the SDXL Refiner, set `args.cascade_strength` in `infer_image.py` to `0`. 81 | 82 | ### Prepare Image Filter 83 | Please download Qwen-VL-Chat 84 | ``` 85 | git clone https://huggingface.co/Qwen/Qwen-VL-Chat 86 | ``` 87 | 88 | ### Prepare Layout Filter 89 | Please construct the example pool for CLIS-L. 90 | 91 | Download [sim_map.json](https://drive.google.com/uc?export=download&id=1vccyYDSUhoOM17k4W1vJL8v64R7IWh9m) and [relations_one_to_one.json](https://drive.google.com/uc?export=download&id=1AhXIJNxBEwO9a6MpLTkKz8dgkItGdNSe) and place them under `config/eval/`. 92 | 93 | ### Prepare Segmentor 94 | 95 | Download the SAM model weights from [GitHub](https://github.com/facebookresearch/segment-anything#model-checkpoints). 96 | 97 | ``` 98 | mkdir sam 99 | cd sam 100 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 101 | ``` 102 | 103 | ## Quick Start 104 | 105 | ``` 106 | python inference_single_data.py 107 | ``` 108 | You can customize the object list in `inputs/demo.json`. The generated images are saved under `images/`, and the synthesized training sample is saved under `syn_data/`. 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /config/bbox.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt": { 4 | "objects": ["airplane-1", "car-2"], 5 | "dense_caption": "A commercial aircraft(1) flies low above a white car(2)." 6 | }, 7 | "output": { 8 | "width": 480, 9 | "height": 640, 10 | "Layout": [ 11 | {"object": "airplane-1", "bbox": [246.23, 276.98, 108.95, 34.85]}, 12 | {"object": "car-2", "bbox": [134.54, 446.69, 187.18, 157.08]} 13 | ] 14 | } 15 | }, 16 | { 17 | "prompt": { 18 | "objects": ["book-1", "computer-2"], 19 | "dense_caption": "A book(1) is to the left of an open laptop(2)." 
20 | }, 21 | "output": { 22 | "width": 375, 23 | "height": 500, 24 | "Layout": [ 25 | {"object": "book-1", "bbox": [5.2, 84.03, 263.39, 181.98]}, 26 | {"object": "computer-2", "bbox": [239.33, 0.94, 135.67, 489.88]} 27 | ] 28 | } 29 | }, 30 | { 31 | "prompt": { 32 | "objects": ["refrigerator-1", "oven-2"], 33 | "dense_caption": "An oven(2) with a black refrigerator(1) next to it." 34 | }, 35 | "output": { 36 | "width": 640, 37 | "height": 384, 38 | "Layout": [ 39 | {"object": "refrigerator-1", "bbox": [111.22, 0.05, 175.04, 378.62]}, 40 | {"object": "oven-2", "bbox": [196.75, 162.23, 222.63, 217.46]} 41 | ] 42 | } 43 | }, 44 | { 45 | "prompt": { 46 | "objects": ["man-1", "man-2", "picture-3", "table-4", "clock-5"], 47 | "dense_caption": "A picture(picture-3) of two men(man-1, man-2) in the service on a table(table-4) below a clock(clock-5)." 48 | }, 49 | "output": "" 50 | } 51 | ] -------------------------------------------------------------------------------- /config/create_dataset.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "objects": { 4 | "objects": ["dog-1", "collar-2", "cowbell-3", "steps-4", "bag-5"] 5 | }, 6 | "output": { 7 | "attributes": [ 8 | {"name": "dog-1", "attributes": "a brown and white pub dog"}, 9 | {"name": "collar-2", "attributes": "a black dog chain collar"}, 10 | {"name": "cowbell-3", "attributes": "a cowbell"}, 11 | {"name": "steps-4", "attributes": "the steps of a building"}, 12 | {"name": "bag-5", "attributes": "a large white bag with black writing"} 13 | ], 14 | "groups": [ 15 | {"id": "group-1", "group": ["dog-1", "collar-2", "cowbell-3", "steps-4"]}, 16 | {"id": "group-2", "group": ["bag-5"]} 17 | ], 18 | "layers of depth": { 19 | "Foreground": ["group-1"], 20 | "Midground": [], 21 | "Background": ["group-2"] 22 | }, 23 | "relations": [ 24 | {"caption": "A dog(dog-1) is sitting on the steps(steps-4) of a building. It is wearing a collar(collar-2). The collar(collar-2) has a cowbell(cowbell-3) attached to it.", "relations": [ 25 | {"subject": ["dog-1"], "object": ["steps-4"], "relation": "sitting on"}, 26 | {"subject": ["dog-1"], "object": ["collar-4"], "relation": "wearing"}, 27 | {"subject": ["cowbell-3"], "object": ["collar-2"], "relation": "attached to"} 28 | ]} 29 | ], 30 | "caption": "A large brown dog(dog-1) is sitting on the steps(steps-4) of a building. It is wearing a black chain dog collar(collar-2). The collar(collar-2) has a cowbell(cowbell-3) attached to it. There is a bag(bag-5) in the background with black writings on it." 
31 | } 32 | }, 33 | { 34 | "objects": { 35 | "objects": ["person-1", "person-2", "cellphone-3", "backpack-4", "jacket-5", "umbrella-6", "backpack-7", "flower-8"] 36 | }, 37 | "output": { 38 | "attributes": [ 39 | {"name": "person-1", "attributes": "a woman"}, 40 | {"name": "person-2", "attributes": "a woman"}, 41 | {"name": "cellphone-3", "attributes": "a cellphone"}, 42 | {"name": "backpack-4", "attributes": "a white and gray backpack"}, 43 | {"name": "jacket-5", "attributes": "a gray jacket"}, 44 | {"name": "umbrella-6", "attributes": "an umbrella"}, 45 | {"name": "backpack-7", "attributes": "a black backpack with a white tag"}, 46 | {"name": "follower-8", "attributes": "a tall planter of flowers"} 47 | ], 48 | "groups": [ 49 | {"id": "group-1", "group": ["woman-1", "jacket-5", "backpack-4", "cellphone-3"]}, 50 | {"id": "group-2", "group": ["woman-2", "umbrella-6", "backpack-7"]}, 51 | {"id": "group-3", "group": ["follower-8"]} 52 | ], 53 | "layers of depth": { 54 | "Foreground": ["group-1", "group-2"], 55 | "Midground": [], 56 | "Background": ["group-3"] 57 | }, 58 | "relations": [ 59 | {"caption": "a woman(woman-1) wearing a jacket(jacket-5), carrying a backpack(backpack-4) and a cellphone(cellphone-3)", "relations": [{"subject": ["woman-1"], "object": ["backpack-4", "cellphone-3"], "relation": "carrying"}, {"subject": ["woman-1"], "object": ["jacket-5"], "relation": "wearing"}]}, 60 | {"caption": "a woman(woman-2) in a gray coat with a backpack(backpack-7) holding an umbrella(umbrella-6)", "relations": [{"subject": ["woman-2"], "object": ["backpack-4"], "relation": "with"}, {"subject": ["woman-2"], "object": ["umbrella-6"], "relation": "holding"}]} 61 | ], 62 | "caption": "Two women(woman-1, woman-2) with backpacks(backpack-4, backpack-7) are taking a selfie with cellphone(cellphone-3) in front of a flower-covered(flower-8) wall, enjoying their time together in the city. One of them(woman-1) is wearing a gray jacket(jacket-5) and a white and gray backpack(backpack-4), while the other(woman-2) is holding a umbrella(umbrella-6) and a black backpack(backpack-7) with a white tag. They are surrounded by potted plants and a tall planter of flowers(flower-8)." 
63 | } 64 | }, 65 | { 66 | "objects": { 67 | "objects": ["rose-1", "rose-2", "rose-3", "rose-4", "rose-5", "rose-6", "rose-7", "rose-8", "rose-9", "rose-10", "rose", "rose-12", "rose-13", "rose-14", "rose-15", "rose-16", "rose-17", "rose-18", "rose-19", "sky-20", "ocean-21", "pathway-22", "bouquet of flowers-23", "plant-24", "plant-25", "plant-26"] 68 | }, 69 | "output": "" 70 | } 71 | ] -------------------------------------------------------------------------------- /config/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "qwen-1.5-6b": "Qwen/Qwen1.5-7B-Chat", 3 | "qwen-1.5-14b": "Qwen/Qwen1.5-14B-Chat", 4 | "qwen-1.5-72b": "Qwen/Qwen1.5-72B-Chat", 5 | "internlm-2-20b": "internlm/internlm2-chat-20b", 6 | "llama-2-7b": "meta-llama/Llama-2-7b-chat-hf", 7 | "llama-2-13b": "meta-llama/Llama-2-13b-chat-hf", 8 | "llama-3-8b": "meta-llama/Meta-Llama-3-8B-Instruct/" 9 | } -------------------------------------------------------------------------------- /infer_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | import numpy as np 6 | 7 | import sys 8 | sys.path.append('./InstanceDiffusion') 9 | 10 | from functools import partial 11 | from omegaconf import OmegaConf 12 | from PIL import Image, ImageDraw 13 | from diffusers import StableDiffusionXLImg2ImgPipeline 14 | 15 | from ldm.util import instantiate_from_config 16 | from ldm.models.diffusion.plms import PLMSSampler 17 | from ldm.models.diffusion.plms_instance import PLMSSamplerInst 18 | from dataset.decode_item import sample_random_points_from_mask, sample_sparse_points_from_mask, decodeToBinaryMask, reorder_scribbles 19 | 20 | from skimage.transform import resize 21 | from InstanceDiffusion.utils.checkpoint import load_model_ckpt 22 | from InstanceDiffusion.utils.input import convert_points, prepare_batch, prepare_instance_meta 23 | from InstanceDiffusion.utils.model import create_clip_pretrain_model, set_alpha_scale, alpha_generator 24 | 25 | device = "cuda" 26 | 27 | def complete_mask(has_mask, max_objs): 28 | mask = torch.ones(1,max_objs) 29 | if has_mask == None: 30 | return mask 31 | 32 | if type(has_mask) == int or type(has_mask) == float: 33 | return mask * has_mask 34 | else: 35 | for idx, value in enumerate(has_mask): 36 | mask[0,idx] = value 37 | return mask 38 | 39 | @torch.no_grad() 40 | def get_model_inputs(meta, model, text_encoder, diffusion, clip_model, clip_processor, config, grounding_tokenizer_input, starting_noise=None, instance_input=False): 41 | if not instance_input: 42 | # update config from args 43 | config.update( vars(args) ) 44 | config = OmegaConf.create(config) 45 | 46 | # prepare a batch of samples 47 | batch = prepare_batch(meta, batch=config.num_images, max_objs=30, model=clip_model, processor=clip_processor, image_size=model.image_size, use_masked_att=True, device="cuda") 48 | context = text_encoder.encode( [meta["prompt"]]*config.num_images ) 49 | 50 | # unconditional input 51 | if not instance_input: 52 | uc = text_encoder.encode( config.num_images*[""] ) 53 | if args.negative_prompt is not None: 54 | uc = text_encoder.encode( config.num_images*[args.negative_prompt] ) 55 | else: 56 | uc = None 57 | 58 | # sampler 59 | if not instance_input: 60 | alpha_generator_func = partial(alpha_generator, type=meta.get("alpha_type")) 61 | if config.mis > 0: 62 | sampler = PLMSSamplerInst(diffusion, model, alpha_generator_func=alpha_generator_func, 
set_alpha_scale=set_alpha_scale, mis=config.mis) 63 | else: 64 | sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale) 65 | steps = 50 66 | else: 67 | sampler, steps = None, None 68 | 69 | # grounding input 70 | grounding_input = grounding_tokenizer_input.prepare(batch, return_att_masks=return_att_masks) 71 | 72 | # model inputs 73 | input = dict(x = starting_noise, timesteps = None, context = context, grounding_input = grounding_input) 74 | return input, sampler, steps, uc, config 75 | 76 | @torch.no_grad() 77 | def run(meta, model, autoencoder, text_encoder, diffusion, clip_model, clip_processor, config, grounding_tokenizer_input, starting_noise=None, guidance_scale=None): 78 | # prepare models inputs 79 | input, sampler, steps, uc, config = get_model_inputs(meta, model, text_encoder, diffusion, clip_model, clip_processor, config, grounding_tokenizer_input, starting_noise, instance_input=False) 80 | if guidance_scale is not None: 81 | config.guidance_scale = guidance_scale 82 | 83 | # prepare models inputs for each instance if MIS is used 84 | if args.mis > 0: 85 | input_all = [input] 86 | for i in range(len(meta['phrases'])): 87 | meta_instance = prepare_instance_meta(meta, i) 88 | input_instance, _, _, _, _ = get_model_inputs(meta_instance, model, text_encoder, diffusion, clip_model, clip_processor, config, grounding_tokenizer_input, starting_noise, instance_input=True) 89 | input_all.append(input_instance) 90 | else: 91 | input_all = input 92 | 93 | # start sampling 94 | shape = (config.num_images, model.in_channels, model.image_size, model.image_size) 95 | with torch.autocast(device_type=device, dtype=torch.float16): 96 | samples_fake = sampler.sample(S=steps, shape=shape, input=input_all, uc=uc, guidance_scale=config.guidance_scale) 97 | samples_fake = autoencoder.decode(samples_fake) 98 | 99 | # define output folder 100 | output_folder = os.path.join( args.output, meta["save_folder_name"]) 101 | os.makedirs( output_folder, exist_ok=True) 102 | 103 | start = len( os.listdir(output_folder) ) 104 | image_ids = list(range(start,start+config.num_images)) 105 | # print(image_ids) 106 | 107 | # visualize the bounding boxes 108 | image_boxes = draw_boxes( meta["locations"], meta["phrases"], meta["prompt"] + ";alpha=" + str(meta['alpha_type'][0]) ) 109 | img_name = os.path.join( output_folder, str(image_ids[0])+'_boxes.png' ) 110 | image_boxes.save( img_name ) 111 | print("saved image with boxes at {}".format(img_name)) 112 | 113 | # if the cascade model is used, refine the generated images with the SDXL Refiner 114 | if config.cascade_strength > 0: 115 | pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( 116 | "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 117 | ) 118 | pipe = pipe.to("cuda:0") 119 | strength, steps = config.cascade_strength, 20 # default setting, needs to be manually tuned. 
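# Note on the saving loop below: each decoded sample lies in [-1, 1], so it is clamped,
# rescaled to [0, 1], and multiplied by 255 before being written out as a PNG; when
# config.cascade_strength > 0, the sample is additionally refined by the SDXL img2img
# pipeline created above (with the strength/steps chosen here) and the refined result
# is saved alongside the base image with an '_xl_*' suffix.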
120 | 121 | # save the generated images 122 | for image_id, sample in zip(image_ids, samples_fake): 123 | img_name = str(int(image_id))+'.png' 124 | sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5 125 | sample = sample.cpu().numpy().transpose(1,2,0) * 255 126 | sample = Image.fromarray(sample.astype(np.uint8)) 127 | if config.cascade_strength > 0: 128 | prompt = meta["prompt"] 129 | refined_image = pipe(prompt, image=sample, strength=strength, num_inference_steps=steps).images[0] 130 | refined_image.save( os.path.join(output_folder, img_name.replace('.png', '_xl_s{}_n{}.png'.format(strength, steps))) ) 131 | sample.save( os.path.join(output_folder, img_name) ) 132 | 133 | def rescale_box(bbox, width, height): 134 | x0 = bbox[0]/width 135 | y0 = bbox[1]/height 136 | x1 = (bbox[0]+bbox[2])/width 137 | y1 = (bbox[1]+bbox[3])/height 138 | return [x0, y0, x1, y1] 139 | 140 | def get_point_from_box(bbox): 141 | x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3] 142 | return [(x0 + x1)/2.0, (y0 + y1)/2.0] 143 | 144 | def rescale_points(point, width, height): 145 | return [point[0]/float(width), point[1]/float(height)] 146 | 147 | def rescale_scribbles(scribbles, width, height): 148 | return [[scribble[0]/float(width), scribble[1]/float(height)] for scribble in scribbles] 149 | 150 | # draw boxes given a lits of boxes: [[top left cornor, top right cornor, width, height],] 151 | # show descriptions per box if descriptions is not None 152 | def draw_boxes(boxes, descriptions=None, caption=None): 153 | width, height = 512, 512 154 | image = Image.new("RGB", (width, height), (255, 255, 255)) 155 | draw = ImageDraw.Draw(image) 156 | boxes = [ [ int(x*width) for x in box ] for box in boxes] 157 | for i, box in enumerate(boxes): 158 | draw.rectangle( ( (box[0], box[1]), (box[2], box[3]) ), outline=(0,0,0), width=2) 159 | if descriptions is not None: 160 | for idx, box in enumerate(boxes): 161 | draw.text((box[0], box[1]), descriptions[idx], fill="black") 162 | if caption is not None: 163 | draw.text((0, 0), caption, fill=(255,102,102)) 164 | return image 165 | 166 | def infer_image( 167 | output, 168 | num_images, 169 | input_json, 170 | ): 171 | 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument("--output", type=str, default="OUTPUT", help="root folder for output") 174 | parser.add_argument("--num_images", type=int, default=8, help="") 175 | parser.add_argument("--guidance_scale", type=float, default=5, help="") 176 | parser.add_argument("--negative_prompt", type=str, default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality', help="") 177 | parser.add_argument("--input_json", type=str, default='/data/home/xudongw/InstanceDiffusion/demos/demo_multiround_r1.json', help="") 178 | parser.add_argument("--ckpt", type=str, default='InstanceDiffusion/pretrained/instancediffusion_sd15.pth', help="") 179 | parser.add_argument("--seed", type=int, default=0, help="random seed") 180 | parser.add_argument("--alpha", type=float, default=0.8, help="the percentage of timesteps using grounding inputs") 181 | parser.add_argument("--mis", type=float, default=0.36, help="the percentage of timesteps using MIS") 182 | parser.add_argument("--cascade_strength", type=float, default=0.3, help="strength of SDXL Refiner.") 183 | parser.add_argument("--test_config", type=str, default="InstanceDiffusion/configs/test_box.yaml", help="config for model inference.") 184 | 185 | global args 186 | args = parser.parse_args() 187 | 188 | 
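# The values passed into infer_image() take precedence over the argparse defaults parsed
# above: output, num_images and input_json are overwritten below, while the remaining
# flags (seed, alpha, mis, cascade_strength, ...) keep their default values.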
args.output = output 189 | args.num_images = num_images 190 | args.input_json = input_json 191 | 192 | global return_att_masks 193 | return_att_masks = False 194 | ckpt = args.ckpt 195 | 196 | seed = args.seed 197 | save_folder_name = f"gc{args.guidance_scale}-seed{seed}-alpha{args.alpha}" 198 | 199 | # read json files 200 | with open(args.input_json) as f: 201 | data = json.load(f) 202 | 203 | # START: READ BOXES AND BINARY MASKS 204 | boxes = [] 205 | binay_masks = [] 206 | # class_names = [] 207 | instance_captions = [] 208 | points_list = [] 209 | scribbles_list = [] 210 | prompt = data['caption'] 211 | crop_mask_image = False 212 | for inst_idx in range(len(data['annos'])): 213 | if "mask" not in data['annos'][inst_idx] or data['annos'][inst_idx]['mask'] == []: 214 | instance_mask = np.zeros((512,512,1)) 215 | else: 216 | instance_mask = decodeToBinaryMask(data['annos'][inst_idx]['mask']) 217 | if crop_mask_image: 218 | # crop the instance_mask to 512x512, centered at the center of the instance_mask image 219 | # get the center of the instance_mask 220 | center = np.array([instance_mask.shape[0]//2, instance_mask.shape[1]//2]) 221 | # get the top left corner of the crop 222 | top_left = center - np.array([256, 256]) 223 | # get the bottom right corner of the crop 224 | bottom_right = center + np.array([256, 256]) 225 | # crop the instance_mask 226 | instance_mask = instance_mask[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] 227 | binay_masks.append(instance_mask) 228 | data['width'] = 512 229 | data['height'] = 512 230 | else: 231 | binay_masks.append(instance_mask) 232 | 233 | if "bbox" not in data['annos'][inst_idx]: 234 | boxes.append([0,0,0,0]) 235 | else: 236 | boxes.append(data['annos'][inst_idx]['bbox']) 237 | if 'point' in data['annos'][inst_idx]: 238 | points_list.append(data['annos'][inst_idx]['point']) 239 | if "scribble" in data['annos'][inst_idx]: 240 | scribbles_list.append(data['annos'][inst_idx]['scribble']) 241 | # class_names.append(data['annos'][inst_idx]['category_name']) 242 | instance_captions.append(data['annos'][inst_idx]['caption']) 243 | # show_binary_mask(binay_masks[inst_idx]) 244 | 245 | # END: READ BOXES AND BINARY MASKS 246 | img_info = {} 247 | img_info['width'] = data['width'] 248 | img_info['height'] = data['height'] 249 | 250 | locations = [rescale_box(box, img_info['width'], img_info['height']) for box in boxes] 251 | phrases = instance_captions 252 | 253 | # get points for each instance, if not provided, use the center of the box 254 | if len(points_list) == 0: 255 | points = [get_point_from_box(box) for box in locations] 256 | else: 257 | points = [rescale_points(point, img_info['width'], img_info['height']) for point in points_list] 258 | 259 | # get binary masks for each instance, if not provided, use all zeros 260 | binay_masks = [] 261 | for i in range(len(locations) - len(binay_masks)): 262 | binay_masks.append(np.zeros((512,512,1))) 263 | 264 | # get scribbles for each instance, if not provided, use random scribbles 265 | if len(scribbles_list) == 0: 266 | for idx, mask_fg in enumerate(binay_masks): 267 | # get scribbles from segmentation if scribble is not provided 268 | scribbles = sample_random_points_from_mask(mask_fg, 20) 269 | scribbles = convert_points(scribbles, img_info) 270 | scribbles_list.append(scribbles) 271 | else: 272 | scribbles_list = [rescale_scribbles(scribbles, img_info['width'], img_info['height']) for scribbles in scribbles_list] 273 | scribbles_list = reorder_scribbles(scribbles_list) 274 | 275 | 
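# The block below first logs how many instance captions, masks, boxes and points were
# collected, then samples up to 256 sparse polygon points from each (possibly all-zero)
# binary mask and resizes the masks to 512x512, so the grounding input always carries
# polygon and segmentation fields even when only boxes are provided.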
print("num of inst captions, masks, boxes and points: ", len(phrases), len(binay_masks), len(locations), len(points)) 276 | 277 | # get polygons for each annotation 278 | polygons_list = [] 279 | segs_list = [] 280 | for idx, mask_fg in enumerate(binay_masks): 281 | # binary_mask = mask_fg[:,:,0] 282 | polygons = sample_sparse_points_from_mask(mask_fg, k=256) 283 | if polygons is None: 284 | polygons = [0 for _ in range(256*2)] 285 | polygons = convert_points(polygons, img_info) 286 | polygons_list.append(polygons) 287 | 288 | segs_list.append(resize(mask_fg.astype(np.float32), (512, 512, 1))) 289 | 290 | segs = np.stack(segs_list).astype(np.float32).squeeze() if len(segs_list) > 0 else segs_list 291 | polygons = polygons_list 292 | scribbles = scribbles_list 293 | 294 | meta_list = [ 295 | # grounding inputs for generation 296 | dict( 297 | ckpt = ckpt, 298 | prompt = prompt, 299 | phrases = phrases, 300 | polygons = polygons, 301 | scribbles = scribbles, 302 | segs = segs, 303 | locations = locations, 304 | points = points, 305 | alpha_type = [args.alpha, 0.0, 1-args.alpha], 306 | save_folder_name=save_folder_name 307 | ), 308 | ] 309 | 310 | # set seed 311 | torch.manual_seed(seed) 312 | starting_noise = torch.randn(args.num_images, 4, 64, 64).to(device) 313 | 314 | model, autoencoder, text_encoder, diffusion, config = load_model_ckpt(meta_list[0]["ckpt"], args, device) 315 | clip_model, clip_processor = create_clip_pretrain_model() 316 | 317 | grounding_tokenizer_input = instantiate_from_config(config['grounding_tokenizer_input']) 318 | model.grounding_tokenizer_input = grounding_tokenizer_input 319 | 320 | for meta in meta_list: 321 | run(meta, model, autoencoder, text_encoder, diffusion, clip_model, clip_processor, config, grounding_tokenizer_input, starting_noise, guidance_scale=args.guidance_scale) 322 | -------------------------------------------------------------------------------- /inference_single_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import copy 5 | import argparse 6 | import numpy as np 7 | 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | from segment_anything import sam_model_registry, SamPredictor 10 | 11 | from prompt import get_prompt 12 | from llm import infer 13 | from utils.parse import parse_str 14 | from utils.utils import convert_example_list, get_query, get_sg, get_attribute, scale_bbox 15 | from utils.visualize import visualize_seg 16 | from infer_image import infer_image 17 | from CLIS.eval_image import eval_image 18 | from CLIS.eval_layout import rule_eval_offline, avg_score_by_conf 19 | 20 | 21 | # ------------------------- Inference ------------------------- # 22 | 23 | def infer_description( 24 | object_list: list, 25 | model: str = 'qwen-1.5-14b', 26 | task: str = 'create_dataset', 27 | example_prefix: str = 'config/', 28 | ): 29 | ''' 30 | Function: 31 | infer description given object list 32 | ''' 33 | 34 | # get example list 35 | with open(example_prefix + f"{task}.json", 'r') as f: 36 | ori_example_list = json.load(f) 37 | 38 | example_list = convert_example_list(task, ori_example_list) 39 | example_list = example_list[:-1] 40 | 41 | # get query 42 | ori_q = get_query(task, object_list) 43 | q = { 44 | "prompt": json.dumps(ori_q, indent=4), 45 | "output": "" 46 | } 47 | 48 | # get prompt 49 | prompt = get_prompt(task, example_list, q) 50 | 51 | # get response 52 | response = infer(model, prompt, h_model=llm, h_tokenizer=llm_tokenizer, 
max_tokens=1024) 53 | parsed_response = parse_str(response) 54 | 55 | return parsed_response 56 | 57 | 58 | def infer_layout( 59 | desc: dict, 60 | model: str = 'qwen-1.5-14b', 61 | task: str = 'bbox', 62 | example_prefix: str = 'config/' 63 | ): 64 | ''' 65 | Function: 66 | infer layout given description 67 | ''' 68 | 69 | with open(example_prefix + f"{task}.json", 'r') as f: 70 | ori_example_list = json.load(f) 71 | 72 | example_list = convert_example_list(task, ori_example_list) 73 | example_list = example_list[:-1] 74 | 75 | ori_q = get_query(task, desc) 76 | q = { 77 | "prompt": json.dumps(ori_q, indent=4), 78 | "output": "" 79 | } 80 | 81 | # get prompt 82 | prompt = get_prompt(task, example_list, q) 83 | 84 | # get response 85 | response = infer(model, prompt, h_model=llm, h_tokenizer=llm_tokenizer, max_tokens=1024) 86 | parsed_response = parse_str(response)['Layout'] 87 | 88 | return parsed_response 89 | 90 | 91 | def gen_image( 92 | sg: dict, 93 | num_images: int = 8, 94 | save_info_dir: str = 'gen_info/', 95 | save_img_dir: str = 'images/', 96 | prefix: str = '/mnt/petrelfs/chenyicheng/workspace/code/ACP/' 97 | ): 98 | 99 | # initialize task 100 | caption = sg['caption'] # global caption 101 | annos= [] 102 | for i in sg['layout']: 103 | i_bbox = i['bbox'] 104 | i_caption = get_attribute(i['object'], sg) 105 | annos.append({ 106 | "bbox": i_bbox, 107 | "caption": i_caption, 108 | }) 109 | 110 | gen_info = { 111 | "caption": caption, 112 | "width": 1.0, 113 | "height": 1.0, 114 | "annos": annos 115 | } 116 | 117 | if not os.path.exists(save_info_dir): 118 | os.makedirs(save_info_dir) 119 | cnt = len([name for name in os.listdir(save_info_dir) if os.path.isfile(os.path.join(save_info_dir, name))]) 120 | input_json = f"{save_info_dir}{cnt}.json" 121 | with open(input_json, 'w') as f: 122 | json.dump(gen_info, f) 123 | 124 | # generate images 125 | if not os.path.exists(save_img_dir): 126 | os.makedirs(save_img_dir) 127 | cnt = len([name for name in os.listdir(save_img_dir) if os.path.isdir(os.path.join(save_img_dir, name))]) 128 | output = f"{save_img_dir}{cnt}/" 129 | 130 | infer_image( 131 | output=f"{prefix}{output}", 132 | num_images=num_images, 133 | input_json=f"{prefix}{input_json}" 134 | ) 135 | 136 | # update 137 | syn_data = copy.deepcopy(sg) 138 | syn_data['img_dir'] = output 139 | 140 | return syn_data 141 | 142 | 143 | def seg( 144 | syn_data: dict, 145 | ): 146 | 147 | # load model 148 | sam = sam_model_registry['vit_h'](checkpoint='sam/sam_vit_h_4b8939.pth') 149 | sam.cuda() 150 | sam_mask_predictor = SamPredictor(sam) 151 | 152 | img_dir = syn_data['img_dir'] 153 | 154 | seg_syn_data_list = [] 155 | 156 | for root, dirs, files in os.walk(img_dir): 157 | for file in files: 158 | 159 | if 'xl' not in file: 160 | continue 161 | 162 | img_path = os.path.join(root, file) 163 | img_bgr = cv2.imread(img_path) 164 | img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 165 | height, width = img_bgr.shape[:2] 166 | 167 | sam_mask_predictor.set_image(img_rgb) 168 | 169 | seg_syn_data = copy.deepcopy(syn_data) 170 | seg_syn_data['layout'] = [] 171 | 172 | for i, l in enumerate(syn_data['layout']): 173 | 174 | bbox = scale_bbox(l['bbox'], width, height) 175 | box = np.array(bbox) 176 | 177 | # get mask 178 | masks, scores, logits = sam_mask_predictor.predict( 179 | box=box, 180 | multimask_output=True 181 | ) 182 | mask = masks[np.argmax(scores)] 183 | 184 | # get contours 185 | contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 186 | 
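# keep only the largest external contour as this instance's segmentation polygon; if SAM returned an empty mask there are no contours and the instance is skipped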
187 | if not contours: 188 | continue 189 | 190 | max_contour = max(contours, key=cv2.contourArea) 191 | segmentation = max_contour.flatten().tolist() 192 | 193 | seg_l = copy.deepcopy(l) 194 | seg_l['segmentation'] = [segmentation] 195 | 196 | # update 197 | seg_syn_data['layout'].append(seg_l) 198 | 199 | seg_syn_data['img_path'] = img_path 200 | 201 | seg_syn_data_list.append(seg_syn_data) 202 | 203 | return seg_syn_data_list 204 | 205 | 206 | if __name__ == '__main__': 207 | 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument("--object_list", type=str, default="inputs/demo.json", help="input json file of object list") 210 | 211 | args = parser.parse_args() 212 | 213 | # initialize object list 214 | with open(args.object_list, 'r') as f: 215 | object_list = json.load(f) 216 | 217 | # initialize model 218 | vlm_tokenizer = AutoTokenizer.from_pretrained('Qwen-VL-Chat', trust_remote_code=True) 219 | vlm = AutoModelForCausalLM.from_pretrained("Qwen-VL-Chat", device_map="cuda", trust_remote_code=True) 220 | vlm.cuda() 221 | vlm.eval() 222 | 223 | with open('config/model_config.json', 'r') as f: 224 | model_pool = json.load(f) 225 | 226 | llm_tokenizer = AutoTokenizer.from_pretrained(model_pool['qwen-1.5-14b'], trust_remote_code=True) 227 | llm = AutoModelForCausalLM.from_pretrained(model_pool['qwen-1.5-14b'], torch_dtype='auto', device_map='cuda', trust_remote_code=True) 228 | llm.cuda() 229 | llm.eval() 230 | 231 | vis_flag = True 232 | eval_layout_flag = True 233 | 234 | # get description 235 | desc= infer_description(object_list) 236 | 237 | # get layout 238 | layout = infer_layout(desc) 239 | 240 | # combine scene graph 241 | sg = get_sg(desc, layout) 242 | 243 | # generate images 244 | syn_data = gen_image(sg, num_images=16) 245 | 246 | # seg 247 | seg_syn_data_list = seg(syn_data) 248 | 249 | # visualize segmentation 250 | if vis_flag: 251 | for seg_syn_data in seg_syn_data_list: 252 | visualize_seg(seg_syn_data) 253 | 254 | # eval 255 | syn_data, score_list = eval_image(syn_data, vlm=vlm, vlm_tokenizer=vlm_tokenizer, llm=llm, llm_tokenizer=llm_tokenizer) 256 | print(syn_data) 257 | 258 | # save 259 | data_dir = 'syn_data/' 260 | if not os.path.exists(data_dir): 261 | os.makedirs(data_dir) 262 | cnt = len([name for name in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, name))]) 263 | 264 | save_path = f"{data_dir}{cnt}.json" 265 | with open(save_path, 'w') as f: 266 | json.dump([syn_data], f) 267 | 268 | if eval_layout_flag: 269 | score_list, score_size_list, score_dist_list, score_dir_list, conf_list, ret_list = rule_eval_offline(pred_path=save_path, sim_threshold=0.4) 270 | score, score_size, score_dist, score_dir = avg_score_by_conf(score_list, score_size_list, score_dist_list, score_dir_list, conf_list) 271 | print(score) 272 | -------------------------------------------------------------------------------- /inputs/demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | "dog", 3 | "frisbee" 4 | ] -------------------------------------------------------------------------------- /llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | 4 | 5 | ''' 6 | define GPT parameters 7 | ''' 8 | gpt_url = "" 9 | gpt_headers = { 10 | 11 | } 12 | 13 | 14 | def infer(model, prompt, max_tokens=512, tp=2, top_p=0.8, top_k=40, temperature=0.8, h_model=None, h_tokenizer=None): 15 | ''' 16 | Function: 17 | Infer response of content from model 18 | 19 | Args: 
20 | model: models to infer, ['gpt-3.5-turbo', 'gpt-4', 'internlm', 'backend', 'qwen'] 21 | content: the prompt to infer 22 | ''' 23 | 24 | # get messages from content 25 | messages = [ 26 | { 27 | "role": "user", 28 | "content": prompt 29 | } 30 | ] 31 | 32 | if h_model and h_tokenizer: 33 | device = "cuda" 34 | 35 | if 'llama' in model: 36 | input_ids = h_tokenizer.apply_chat_template( 37 | messages, 38 | add_generation_prompt=True, 39 | return_tensors="pt" 40 | ).to(device) 41 | 42 | terminators = [ 43 | h_tokenizer.eos_token_id, 44 | h_tokenizer.convert_tokens_to_ids("<|eot_id|>") 45 | ] 46 | 47 | outputs = h_model.generate( 48 | input_ids, 49 | max_new_tokens=max_tokens, 50 | eos_token_id=terminators, 51 | do_sample=True, 52 | temperature=0.6, 53 | top_p=0.9, 54 | ) 55 | response_ori = outputs[0][input_ids.shape[-1]:] 56 | 57 | response = h_tokenizer.decode(response_ori, skip_special_tokens=True) 58 | 59 | return response 60 | 61 | 62 | text = h_tokenizer.apply_chat_template( 63 | messages, 64 | tokenize=False, 65 | add_generation_prompt=True 66 | ) 67 | 68 | model_inputs = h_tokenizer([text], return_tensors="pt").to(device) 69 | 70 | generated_ids = h_model.generate( 71 | model_inputs.input_ids, 72 | max_new_tokens=max_tokens, 73 | ) 74 | generated_ids = [ 75 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 76 | ] 77 | 78 | response = h_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 79 | 80 | return response 81 | 82 | if 'gpt' in model: 83 | payload = { 84 | "model": model, 85 | "messages": messages 86 | } 87 | 88 | response = requests.post(gpt_url, headers=gpt_headers, data=json.dumps(payload)).json()['data'] 89 | 90 | return response.text if 'backend' in model else response['choices'][0]['message']['content'] 91 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | def get_content_template(task): 2 | ''' 3 | Function: 4 | ''' 5 | 6 | if task == 'create_dataset': 7 | content_template = """Please provide a JSON format with "attributes", "groups", "layers of depth", "relations", and "caption" based on the following prompt: {prompt}.\nDesired output:\n""" 8 | elif task == 'bbox': 9 | content_template = """Please provide a json format with 'Layout' based on the following prompt: {prompt}.\nThe layout is a list of json with 'object' and 'bbox'. 'object' refers to the object name in the prompt provided, while 'bbox' is formulated as [x,y,w,h], where "x,y" denotes the top left coordinate of the bounding box. "w" denotes the width, and "h" denotes the height. The bounding boxes should not go beyond the image boundaries. The six values "x,y,w,h,x+w,y+h" are all larger than 0 and smaller than 1.""" 10 | 11 | return content_template 12 | 13 | 14 | def get_task_description(task): 15 | ''' 16 | Function: 17 | Generate task description given different task 18 | ''' 19 | 20 | if task == 'create_dataset': 21 | # generate the data 22 | task_description = """We want to generate a scene graph given a list of objects. The object is in the format of '{object name}-{identifier}', where 'object name' means any object name in the world and 'identifier' is a unique number representing the difference between objects, especially those with the same name.\nPlease provide a JSON format with 'attributes', 'groups', 'layers of depth', 'relations', and 'caption'.\n1. 
'attributes': should be descriptive color or texture of the corresponding object.\n2. 'groups': A group of objects exhibit strong spatial relationships that interact with each other.\n3. 'layers of depth': The scene is divided into different layers based on proximity - 'Immediate Foreground', 'Foreground', 'Midground', and 'Background'. Each layer depicts one or more groups of objects in (2) at that depth in the scene.\n4. 'relations': This section illustrates the interactions or spatial relationships between various objects or groups.\n5. 'caption': A simple and straightforward 1-2 sentence image caption. Please include all the objects in the caption and refer to them in '()'. Create the caption as if you are directly observing the image. Do not mention the use of any source data. Do not use words like 'indicate', 'suggest', 'hint', 'likely', or 'possibly'.\n\n""" 23 | elif task == 'bbox': 24 | # generate bbox given a list of objects and dense caption 25 | task_description = """The provided prompt is a list of object and corresponding description. The object is in the format of '{object name}-{identifier}', where 'object name' means any object name in the world and 'identifier' is a unique number representing the difference between objects, especially those with the same name. The corresponding description shows the relationship between objects.""" 26 | 27 | task_description = task_description + "Please refer to the examples below.\n" 28 | 29 | return task_description 30 | 31 | 32 | def get_norm_examples(example_list): 33 | 34 | example_template = """##Example-{index}\nPrompt: {prompt}\nDesired Output: {output}\n""" 35 | 36 | example_content = [] 37 | for i, example in enumerate(example_list): 38 | example_content.append(example_template.format(index=i, prompt=example['prompt'], output=example['output'])) 39 | 40 | examples = '\n'.join(example_content) 41 | 42 | return examples 43 | 44 | 45 | def get_examples(task, example_list): 46 | 47 | return get_norm_examples(example_list) 48 | 49 | 50 | def get_norm_content(task, q): 51 | 52 | content_template = get_content_template(task) 53 | 54 | content = content_template.format(prompt=q['prompt'], output=q['output']) if 'feedback' in task else content_template.format(prompt=q['prompt']) 55 | 56 | return content 57 | 58 | 59 | def get_content(task, q): 60 | 61 | return get_norm_content(task, q) 62 | 63 | 64 | def get_prompt( 65 | task: str, 66 | example_list: list = [], 67 | q: dict = None, 68 | ): 69 | ''' 70 | Function: 71 | 72 | ''' 73 | 74 | if task == 'vlm_global_describe': 75 | prompt = """You are my assistant to evaluate the correspondence of the image to a given text prompt.\nBriefly describe the image within 50 words, focus on the objects in the image and their attributes (such as color, shape, texture), spatial layout and action relationships.\n""" 76 | 77 | return prompt 78 | 79 | elif task == 'vlm_local_describe': 80 | prompt = """You are my assistant to identify the object and its color (shape, texture) in the ({xmin},{ymin}),({xmax},{ymax}) of the image.\nBriefly describe what it is in the specific part of the image within 50 words.\n""" 81 | 82 | return prompt 83 | 84 | elif task == 'llm_align': 85 | prompt = """You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs.\nYour task is to compare the predicted answer with the correct answer and determine if they match correctly based on the objects, and their actions, relationships. 
Here's how you can accomplish the task:\n------\n##INSTRUCTIONS:\n- Focus on the objects mentioned in the description and their actions and relationships when evaluating the meaningful match.\n- Consider synonyms or paraphrases as valid matches.\n- Evaluate the correctness of the prediction compared to the answer.\n\nPlease evaluate the following answer pair:\n\nCorrect Answer: {answer}\nPredicted Answer: {pred}\n\nProvide your evaluation in JSON format with 'score' and 'explanation' keys. The score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. The explanation should be within 20 words.""" 86 | 87 | return prompt 88 | 89 | template = """Task Description:\n{task_description}\nExamples:\n{examples}\nPlease complete the following one:\n{content}""" 90 | 91 | prompt = template.format( 92 | task_description=get_task_description(task), 93 | examples=get_examples(task, example_list), 94 | content=get_content(task, q), 95 | ) 96 | 97 | return prompt 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | academictorrents==2.3.3 2 | albumentations==1.4.10 3 | diffusers==0.21.0.dev0 4 | einops==0.8.0 5 | kornia==0.5.8 6 | lmdeploy==0.4.2 7 | matplotlib==3.9.0 8 | mmengine==0.10.4 9 | natsort==8.4.0 10 | nltk==3.9 11 | numpy==2.0.0 12 | omegaconf==2.3.0 13 | opencv_python==4.7.0.72 14 | pandas==2.0.2 15 | Pillow==10.3.0 16 | pycocotools==2.0.7 17 | pytorch_lightning==2.3.1 18 | PyYAML==6.0.1 19 | ram==0.1 20 | Requests==2.32.3 21 | scipy==1.14.0 22 | segment_anything==1.0 23 | skimage==0.0 24 | sng_parser==1.3.0 25 | spacy==3.7.4 26 | submitit==1.5.1 27 | tiktoken==0.6.0 28 | timm==0.4.12 29 | torch==2.3.0 30 | torchvision==0.18.0 31 | tqdm==4.65.2 32 | transformers==4.40.2 33 | transformers_stream_generator==0.0.5 34 | vllm==0.4.2 35 | wandb==0.17.3 36 | webdataset==0.2.86 37 | -------------------------------------------------------------------------------- /utils/parse.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | 5 | # ------------------------- Parse ------------------------- # 6 | 7 | def parse_str(response, stop_word=None, prefix_word=None): 8 | ''' 9 | Function: 10 | parse the target dict from a response string 11 | 12 | Args: 13 | stop_word: str, default = None, filter the content of response before stop_word 14 | ''' 15 | 16 | # filter response 17 | if stop_word: 18 | stop_index = response.lower().find(stop_word.lower()) 19 | if stop_index != -1: 20 | response = response[:stop_index] 21 | 22 | if prefix_word: 23 | prefix_index = response.lower().find(prefix_word.lower()) 24 | if prefix_index != -1: 25 | response = response[prefix_index:] 26 | 27 | start = response.find('{') 28 | end = response.rfind('}') 29 | 30 | if start != -1 and end != -1: 31 | try: 32 | processed_response = json.loads(response[start:end+1]) 33 | except: 34 | cleaned_response = re.sub(r"//.*", "", response[start:end+1]) 35 | cleaned_response = re.sub(r'(?