├── csp_func_control.py ├── csp_solver.py ├── data ├── ablation_queries.json ├── counting_queries.json ├── negative_queries.json ├── scannet200_constants.py └── scannetv2_val.txt ├── misc_utils.py ├── output ├── eval_data_gt_nr3d.json ├── eval_data_gt_scanrefer.json ├── eval_data_mask3d_scanrefer.json ├── eval_results_gt_nr3d.json ├── eval_results_gt_scanrefer.json ├── eval_results_mask3d_scanrefer.json └── zsvg3d_eval_results.json ├── program_evaluator.py ├── program_functions_csp.py ├── program_generator.py ├── program_validator.py ├── prompts ├── filter_relevant_objects.txt ├── generate_program_csp.txt ├── generate_program_csp_no_counting.txt ├── generate_program_csp_no_counting_negation.txt ├── generate_program_csp_no_counting_negation_minmax.txt ├── generate_program_csp_no_minmax.txt ├── guess_thresholds.txt └── select_solution.txt ├── readme.md ├── run_scripts ├── scannet_utils.py ├── scope_env.py ├── score_funcs.py ├── visualize_eval_results.py └── visualize_segs.py /csp_func_control.py: -------------------------------------------------------------------------------- 1 | DISABLE_MINMAX = False 2 | DISABLE_NEGATION = False 3 | # DISABLE_COUNTING = False 4 | -------------------------------------------------------------------------------- /csp_solver.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import random 3 | import re 4 | import sys 5 | import textwrap 6 | import time 7 | import uuid 8 | from abc import ABC, abstractmethod 9 | from collections import defaultdict 10 | from typing import Generator 11 | 12 | import numpy as np 13 | from openai import OpenAI 14 | from typing_extensions import Self 15 | 16 | from scannet_utils import ObjInstance 17 | from scope_env import GlobalState, set_target_info 18 | 19 | ID_COUNTER = 0 20 | 21 | 22 | def reset_var_counter(): 23 | global ID_COUNTER 24 | ID_COUNTER = 0 25 | 26 | 27 | class CSPVar: 28 | """a CSP variable that represents a single instance in the 
scene""" 29 | 30 | def __init__(self, labels: list[str] = []): 31 | global ID_COUNTER 32 | ID_COUNTER += 1 33 | 34 | self.labels = labels 35 | self.obj_id = str(ID_COUNTER) 36 | self.is_target = False # only CSPVar can be a target 37 | self.negative = False 38 | 39 | def __hash__(self) -> int: 40 | return hash(self.obj_id) 41 | 42 | def __eq__(self, other: Self) -> bool: 43 | return self.obj_id == other.obj_id 44 | 45 | def set_as_target(self): 46 | self.is_target = True 47 | 48 | def get_identifier(self) -> str: 49 | return f"{self.obj_id}-{'-'.join(self.labels)}".replace(" ", "_") 50 | 51 | 52 | Solution = dict[CSPVar, ObjInstance] 53 | 54 | 55 | def solution_to_str(sol_dict: Solution) -> str: 56 | text = [] 57 | for csp_var, inst in sol_dict.items(): 58 | text.append(f"[{csp_var.get_identifier()} -> {inst.inst_id}: {inst.label}]") 59 | return "{" + " ".join(text) + "}" 60 | 61 | 62 | def load_prompt(file_path): 63 | with open(file_path) as f: 64 | msg_started = False 65 | msg_role = None 66 | msg = None 67 | prompt_dialog = [] 68 | for line in f: 69 | if not msg_started: 70 | line = line.strip() 71 | if line in {"<[SYSTEM]>", "<[USER]>", "<[ASSISTANT]>"}: 72 | msg_started = True 73 | msg_role = line[2:-2].lower().strip() 74 | msg = [] 75 | else: 76 | assert msg_role in {"system", "user", "assistant"} 77 | if line.strip() in {"<[SYSTEM]>", "<[USER]>", "<[ASSISTANT]>"}: 78 | assert msg 79 | prompt_dialog.append( 80 | {"role": msg_role, "content": "\n".join(msg).strip()} 81 | ) 82 | msg_role = line.strip()[2:-2].lower().strip() 83 | msg = [] 84 | else: 85 | msg.append(line.rstrip()) 86 | if msg: 87 | prompt_dialog.append({"role": msg_role, "content": "\n".join(msg).strip()}) 88 | 89 | # validate the dialog 90 | assert len(prompt_dialog) > 0 91 | assert prompt_dialog[0]["role"] == "system" 92 | for i, turn in enumerate(prompt_dialog[1:]): 93 | if i % 2 == 0: 94 | assert turn["role"] == "user" 95 | else: 96 | assert turn["role"] == "assistant" 97 | assert 
prompt_dialog[-1]["role"] in {"assistant", "system"} 98 | 99 | return prompt_dialog 100 | 101 | 102 | MODEL_NAME = "mistral-large-instruct-2407-awq" 103 | MODEL_NAME = "/cvhci/temp/qyuan/" + MODEL_NAME 104 | 105 | PROMPT = load_prompt("./prompts/select_solution.txt") 106 | 107 | 108 | def select_best_solution( 109 | query: str, 110 | csp_desc: str, 111 | valid_solutions: list[Solution], 112 | ) -> Solution: 113 | # def get_completion(prompt): 114 | # assert len(prompt) > 0 115 | # assert prompt[-1]["role"] == "user" 116 | 117 | # openai_client = OpenAI( 118 | # api_key="sk-proj-REDACTED",  # NOTE(review): a real OpenAI secret key was committed here -- revoke it immediately and load keys from the environment, never from source 119 | # # base_url="http://127.0.0.1:2242/v1", 120 | # timeout=3600.0, 121 | # ) 122 | 123 | # while True: 124 | # try: 125 | # completion = openai_client.chat.completions.create( 126 | # # model=MODEL_NAME, 127 | # model="gpt-4o-mini", 128 | # messages=prompt, 129 | # # max_tokens=2048, 130 | # # temperature=0.7, 131 | # ) 132 | 133 | # # print(completion.model) 134 | # answer = completion.choices[0].message.content.strip() 135 | 136 | # except Exception: 137 | # # print("retry...") 138 | # time.sleep(2.0) 139 | # else: 140 | # break 141 | 142 | # return answer 143 | 144 | def get_completion(prompt): 145 | assert len(prompt) > 0 146 | assert prompt[-1]["role"] == "user" 147 | 148 | openai_client = OpenAI( 149 | api_key="db72ad53ea0db1354d46405703546670",  # NOTE(review): hard-coded credential for the local vLLM endpoint; prefer an env var even for local servers 150 | base_url="http://127.0.0.1:2242/v1", 151 | timeout=3600.0, 152 | ) 153 | 154 | completion = openai_client.chat.completions.create( 155 | model=MODEL_NAME, 156 | messages=prompt, 157 | # max_tokens=2048, 158 | # temperature=0.7, 159 | ) 160 | 161 | # print(completion.model) 162 | answer = completion.choices[0].message.content.strip() 163 | return answer 164 | 165 | assert query 166 | assert csp_desc 167 | assert 2 <= len(valid_solutions) <= 4 168 | 169 | objects_str = [] 170 | 
solutions_str = [] 171 | 172 | for csp_var in valid_solutions[0].keys(): 173 | csp_var_id = f"{'_'.join(csp_var.labels).replace(' ', '_')}_{csp_var.obj_id}" 174 | objects_str.append(csp_var_id) 175 | 176 | objects_str = ", ".join(objects_str) 177 | 178 | for i, sol_dict in enumerate(valid_solutions): 179 | solutions_str.append(f"solution {i}") 180 | for csp_var, inst in sol_dict.items(): 181 | csp_var_id = ( 182 | f"{'_'.join(csp_var.labels).replace(' ', '_')}_{csp_var.obj_id}" 183 | ) 184 | solutions_str.append( 185 | f"{csp_var_id}: " 186 | f"center=[{inst.bbox.center[0]:.3f}, {inst.bbox.center[1]:.3f}, {inst.bbox.center[2]:.3f}]; " 187 | f"size=[{inst.bbox.size[0]:.3f}, {inst.bbox.size[1]:.3f}, {inst.bbox.size[2]:.3f}]" 188 | ) 189 | solutions_str.append("") 190 | 191 | solutions_str = "\n".join(solutions_str) 192 | 193 | user_prompt = ( 194 | f"@QUERY\n{query}\n\n" f"@OBJECTS\n{objects_str}\n\n@SOLUTIONS\n{solutions_str}" 195 | ) 196 | 197 | prompt = PROMPT.copy() 198 | prompt.append({"role": "user", "content": user_prompt.strip()}) 199 | 200 | for i in range(5): 201 | response = get_completion(prompt) 202 | 203 | # print() 204 | # print("=" * 40) 205 | # print() 206 | # print(user_prompt) 207 | # print() 208 | # print("=" * 40) 209 | # print() 210 | # print(response) 211 | # print() 212 | # print("=" * 40) 213 | # print() 214 | 215 | matches = re.findall( 216 | r"^@BEGIN\{ANSWER\}\ncorrect solution index: \[(\d+)\]\n@END\{ANSWER\}", 217 | response, 218 | re.MULTILINE, 219 | ) 220 | if len(matches) != 1: 221 | prompt.append({"role": "assistant", "content": response}) 222 | prompt.append( 223 | { 224 | "role": "user", 225 | "content": ( 226 | "Your output is valid. You should include\n" 227 | "@BEGIN{ANSWER}\ncorrect solution index: [index]\n@END{ANSWER}\n" 228 | "exactly once!" 
229 | ), 230 | } 231 | ) 232 | continue 233 | 234 | assert len(matches) == 1 235 | 236 | # print() 237 | # print(matches[0]) 238 | # print() 239 | 240 | break 241 | 242 | return valid_solutions[int(matches[0])] 243 | 244 | 245 | def print_solution(sol_dict: Solution): 246 | print() 247 | print("======= vvv SOLUTION vvv =======") 248 | for csp_var, cand_insts in sol_dict.items(): 249 | print( 250 | f"{type(csp_var)} {csp_var.obj_id} {csp_var.labels} => " 251 | f"{[(inst.inst_id, inst.label) for inst in cand_insts]}" 252 | ) 253 | print("======= ^^^ SOLUTION ^^^ =======") 254 | print() 255 | 256 | 257 | class CSPConstraint(ABC): 258 | """ 259 | base class for all constraints 260 | """ 261 | 262 | def __init__(self): 263 | global ID_COUNTER 264 | ID_COUNTER += 1 265 | 266 | # a unique constraint id 267 | # self.con_id = str(uuid.uuid4()) 268 | self.con_id = "con-" + str(ID_COUNTER) 269 | 270 | # the function name used in the generated program 271 | self.apparent_name: str = "(not set)" 272 | 273 | def __hash__(self) -> int: 274 | return hash(self.con_id) 275 | 276 | def __eq__(self, other: Self) -> bool: 277 | return self.con_id == other.con_id 278 | 279 | def set_apparent_name(self, name: str): 280 | self.apparent_name = name 281 | 282 | @abstractmethod 283 | def get_desc_str(self) -> str: 284 | """return a string describing the variables involved in this constraint""" 285 | 286 | @abstractmethod 287 | def get_target_var(self) -> CSPVar: 288 | """return the target variable of this constraint""" 289 | 290 | @abstractmethod 291 | def get_vars(self) -> set[CSPVar]: 292 | """return all variables invovled in this constraint""" 293 | 294 | @abstractmethod 295 | def check_solution(self, solution_dict: Solution) -> bool: 296 | """check if the given solution satisfies this constraint""" 297 | 298 | 299 | def solution_generator( 300 | csp_vars: list[CSPVar], 301 | ) -> Generator[Solution, None, None]: 302 | """generate all possible combinations of assignments for free 
variables""" 303 | assert all(isinstance(x, CSPVar) for x in csp_vars) 304 | csp_vars.sort(key=lambda x: x.get_identifier()) 305 | 306 | solution_dict: Solution = {} 307 | used_instances: set[ObjInstance] = set() 308 | 309 | def gen_func(i_var: int) -> Generator[Solution, None, None]: 310 | nonlocal solution_dict 311 | 312 | if i_var >= len(csp_vars): 313 | yield solution_dict.copy() 314 | return 315 | 316 | cur_var = csp_vars[i_var] 317 | assert not cur_var.negative 318 | 319 | for inst in GlobalState.get_cand_insts(cur_var.labels): 320 | if inst in used_instances: 321 | continue 322 | 323 | solution_dict[cur_var] = inst 324 | 325 | used_instances.add(inst) 326 | yield from gen_func(i_var + 1) 327 | used_instances.remove(inst) 328 | 329 | yield from gen_func(0) 330 | 331 | 332 | def check_solution(sol_dict: Solution) -> bool: 333 | """check if an instance is assigned to multiple variables""" 334 | used_insts = set() 335 | 336 | for csp_var, cand_inst in sol_dict.items(): 337 | if cand_inst in used_insts: 338 | return False 339 | 340 | used_insts.add(cand_inst) 341 | 342 | return True 343 | 344 | 345 | def get_solution_heuristic_score(sol_dict: Solution) -> tuple[float, str]: 346 | instances = list(sol_dict.values()) 347 | assert len(instances) > 0 348 | 349 | total_dist = 0 350 | for i in range(len(instances)): 351 | for j in range(i): 352 | total_dist += np.linalg.norm( 353 | instances[i].bbox.center - instances[j].bbox.center 354 | ) 355 | 356 | solution_str = "-".join( 357 | [csp_var.get_identifier() for csp_var, _ in sol_dict.items()] 358 | ) 359 | 360 | return (total_dist, solution_str) 361 | 362 | # # compute the center of all instances 363 | # center = np.mean([inst.bbox.center for inst in instances], axis=0) 364 | 365 | # # compute the average distance to the center 366 | # return np.average([np.linalg.norm(inst.bbox.center - center) for inst in instances]) 367 | 368 | 369 | def get_solution_heuristic_score_2( 370 | csp_var_groups: set[tuple[CSPVar, ...]], 
371 | sol_dict: Solution, 372 | ) -> float: 373 | instances = list(sol_dict.values()) 374 | assert len(instances) > 0 375 | 376 | total_dist = 0 377 | for grp in csp_var_groups: 378 | inst_centers = [sol_dict[csp_var].bbox.center for csp_var in grp] 379 | center = np.mean(inst_centers, axis=0) 380 | total_dist += np.sum([np.linalg.norm(c - center) for c in inst_centers]) 381 | 382 | return total_dist 383 | 384 | 385 | class CSPSolver: 386 | """ 387 | the CSP (Constraint Satisfication Problem) whose solution will give the solution to 388 | the 3D visual grounding problem. 389 | """ 390 | 391 | def __init__(self): 392 | self.variables: set[CSPVar] = set() 393 | self.normal_constraints: set[CSPConstraint] = set() 394 | self.minmax_constraints: set[CSPConstraint] = set() 395 | 396 | # # only used to ensure each variable has at most one min/max constraint 397 | # self.minmax_variables: set[CSPVar] = set() 398 | 399 | def create_var(self, labels: list[str] = [], negative: bool = False) -> CSPVar: 400 | csp_var = CSPVar(labels=labels) 401 | csp_var.negative = negative 402 | self.variables.add(csp_var) 403 | return csp_var 404 | 405 | def add_constraint(self, constraint: CSPConstraint): 406 | assert hasattr(constraint, "FUNC_NAME") 407 | 408 | func_names: set[str] = set() 409 | if isinstance(constraint.FUNC_NAME, str): 410 | func_names.add(constraint.FUNC_NAME) 411 | else: 412 | func_names |= set(constraint.FUNC_NAME) 413 | 414 | if func_names & {"CONSTRAINT_MIN_OF", "CONSTRAINT_MAX_OF"}: 415 | self.minmax_constraints.add(constraint) 416 | # # a variable can only have a single min/max constraint 417 | # assert constraint.get_target_var() not in self.minmax_variables 418 | # self.minmax_variables.add(constraint.get_target_var()) 419 | else: 420 | self.normal_constraints.add(constraint) 421 | 422 | def get_constraints(self, csp_var: CSPVar) -> set[CSPConstraint]: 423 | """return all constraints related to the given variable""" 424 | return {con for con in 
self.normal_constraints if csp_var in con.get_vars()} 425 | 426 | def solve( 427 | self, 428 | query: str, 429 | select_solution: str = "min_dist", 430 | verbose: bool = False, 431 | ): 432 | assert query 433 | assert self.variables 434 | # assert self.normal_constraints or self.minmax_constraints 435 | 436 | if verbose: 437 | print() 438 | print("=" * 30) 439 | print() 440 | print("solving csp:") 441 | print(self.get_desc_str()) 442 | 443 | normal_vars = [csp_var for csp_var in self.variables if not csp_var.negative] 444 | negative_vars = [csp_var for csp_var in self.variables if csp_var.negative] 445 | 446 | for con in self.minmax_constraints: 447 | assert all(not csp_var.negative for csp_var in con.get_vars()) 448 | 449 | # find the target variable 450 | target_var_set = set(var for var in self.variables if var.is_target) 451 | assert len(target_var_set) == 1 452 | target_var = next(iter(target_var_set)) 453 | 454 | # if the target variable has only a single candidate, skip the searching 455 | target_var_cand_insts = GlobalState.get_cand_insts(target_var.labels) 456 | if len(target_var_cand_insts) == 1: 457 | if verbose: 458 | print() 459 | print("target is unique.") 460 | print() 461 | print("=" * 30) 462 | print() 463 | 464 | set_target_info( 465 | best_instance=target_var_cand_insts[0], 466 | candidate_instances=target_var_cand_insts, 467 | anchor_instances={}, 468 | csp_desc=self.get_desc_str(), 469 | llm_used=False, 470 | ) 471 | return 472 | 473 | # iterate through possible solutions 474 | valid_solutions: list[Solution] = [] 475 | solution_counter = 0 476 | 477 | for sol_dict in solution_generator(csp_vars=normal_vars): 478 | solution_counter += 1 479 | if solution_counter > 1000: 480 | break 481 | 482 | # discard this solution if an instance is used twice 483 | if not check_solution(sol_dict): 484 | continue 485 | 486 | # we should have an assignment for all variables by now 487 | assert len(sol_dict) + len(negative_vars) == len(self.variables) 488 | 
assert all(var in sol_dict for var in normal_vars) 489 | 490 | con_failed = False 491 | for con in self.normal_constraints: 492 | neg_vars = {csp_var for csp_var in con.get_vars() if csp_var.negative} 493 | 494 | if neg_vars: 495 | assert len(neg_vars) == 1 496 | neg_var = next(iter(neg_vars)) 497 | assert neg_var not in sol_dict 498 | 499 | neg_var_cand_insts = GlobalState.get_cand_insts(neg_var.labels) 500 | assert len(neg_var_cand_insts) >= 1 501 | 502 | for inst in neg_var_cand_insts: 503 | sol_dict[neg_var] = inst 504 | if con.check_solution(sol_dict): 505 | con_failed = True 506 | break 507 | 508 | del sol_dict[neg_var] 509 | 510 | else: 511 | if not con.check_solution(sol_dict): 512 | con_failed = True 513 | 514 | if con_failed: 515 | break 516 | 517 | if not con_failed: 518 | valid_solutions.append(sol_dict) 519 | 520 | if valid_solutions: 521 | assert all(target_var in sol_dict for sol_dict in valid_solutions) 522 | 523 | if verbose: 524 | print() 525 | print("valid solutions:") 526 | for sol in valid_solutions: 527 | print(solution_to_str(sol)) 528 | 529 | minmax_constraints = sorted(self.minmax_constraints, key=lambda x: x.con_id) 530 | normal_constraints = sorted(self.normal_constraints, key=lambda x: x.con_id) 531 | 532 | # print() 533 | # print(self.get_desc_str()) 534 | # print() 535 | 536 | # print() 537 | # print("num minmax_con:", len(minmax_constraints)) 538 | # for con in minmax_constraints: 539 | # print(con.con_id, con.get_desc_str()) 540 | # print() 541 | 542 | # print() 543 | # for sol in valid_solutions: 544 | # for csp_var, inst in sol.items(): 545 | # print( 546 | # csp_var.get_identifier(), 547 | # f"{inst.label}: {inst.inst_id}", 548 | # end=" || ", 549 | # ) 550 | # print() 551 | # print() 552 | 553 | for minmax_con in minmax_constraints: 554 | con_target_var = minmax_con.get_target_var() 555 | con_anchor_vars: set[CSPVar] = set() 556 | 557 | for normal_con in normal_constraints: 558 | if normal_con.get_target_var() == con_target_var: 
559 | con_anchor_vars |= normal_con.get_vars() - {con_target_var} 560 | 561 | anchor_var_groups: dict[tuple[ObjInstance, ...], list[Solution]] = ( 562 | defaultdict(lambda: []) 563 | ) 564 | 565 | # print() 566 | for sol in valid_solutions: 567 | key = tuple( 568 | sorted( 569 | [sol[csp_var] for csp_var in con_anchor_vars], 570 | key=lambda x: x.inst_id, 571 | ) 572 | ) 573 | # print([f"{inst.inst_id}-{inst.label}" for inst in key]) 574 | anchor_var_groups[key].append(sol) 575 | # print() 576 | 577 | # print() 578 | # print("num groups:", len(anchor_var_groups)) 579 | # print() 580 | valid_solutions: list[Solution] = [] 581 | for _, sols in anchor_var_groups.items(): 582 | valid_solutions += minmax_con.filter_solutions(sols) 583 | 584 | if verbose: 585 | print() 586 | print("valid solutions after handling min/max constraints:") 587 | for sol in valid_solutions: 588 | print(solution_to_str(sol)) 589 | 590 | # print() 591 | # for sol in valid_solutions: 592 | # for csp_var, inst in sol.items(): 593 | # print( 594 | # csp_var.get_identifier(), 595 | # f"{inst.label}: {inst.inst_id}", 596 | # end=" || ", 597 | # ) 598 | # print() 599 | # print() 600 | 601 | # select one solution with a heuristic... 
602 | first_valid_solution = valid_solutions[0].copy() 603 | valid_solutions.sort(key=get_solution_heuristic_score) 604 | # print() 605 | # print("2:", len(valid_solutions)) 606 | # print("2:", valid_solutions[0][target_var].inst_id) 607 | # print() 608 | 609 | best_solution_0 = valid_solutions[0] 610 | candidate_instances = {best_solution_0[target_var]} 611 | good_solutions = [best_solution_0] 612 | 613 | if len(valid_solutions) >= 2: 614 | best_solution_1 = valid_solutions[-1] 615 | candidate_instances.add(best_solution_1[target_var]) 616 | good_solutions.append(best_solution_1) 617 | 618 | if len(valid_solutions) >= 3: 619 | best_solution_2 = valid_solutions[1] 620 | candidate_instances.add(best_solution_2[target_var]) 621 | good_solutions.append(best_solution_2) 622 | 623 | # if len(valid_solutions) >= 4: 624 | # best_solution_3 = valid_solutions[-2] 625 | # candidate_instances.add(best_solution_3[target_var]) 626 | # good_solutions.append(best_solution_3) 627 | 628 | # target_var_insts = {sol_dict[target_var] for sol_dict in valid_solutions} 629 | llm_used = False 630 | # query_words = set(query.lower().strip().split(" ")) 631 | # view_dep_words = { 632 | # "front", 633 | # "behind", 634 | # "back", 635 | # "right", 636 | # "left", 637 | # "facing", 638 | # "leftmost", 639 | # "rightmost", 640 | # "looking", 641 | # "across", 642 | # } 643 | # view_dep = len(query_words & view_dep_words) > 0 644 | # if len(target_var_insts) >= 2 and not view_dep: 645 | # best_solution_0 = select_best_solution( 646 | # query=query, 647 | # csp_desc=self.get_desc_str(), 648 | # valid_solutions=good_solutions, 649 | # ) 650 | # llm_used = True 651 | 652 | if select_solution == "max_dist": 653 | best_solution_0, best_solution_1 = best_solution_1, best_solution_0 654 | elif select_solution == "random": 655 | best_solution_0 = random.choice(valid_solutions) 656 | elif select_solution == "first": 657 | best_solution_0 = first_valid_solution 658 | else: 659 | assert select_solution == 
"min_dist" 660 | 661 | if verbose: 662 | print() 663 | print("best solution:") 664 | print(solution_to_str(best_solution_0)) 665 | 666 | set_target_info( 667 | best_instance=best_solution_0[target_var], 668 | candidate_instances=candidate_instances, 669 | anchor_instances={ 670 | csp_var.get_identifier(): inst 671 | for csp_var, inst in best_solution_0.items() 672 | if csp_var != target_var 673 | }, 674 | csp_desc=self.get_desc_str(), 675 | llm_used=llm_used, 676 | ) 677 | 678 | print() 679 | print("=" * 30) 680 | print() 681 | 682 | def solve_naive(self): 683 | raise NotImplementedError() 684 | 685 | def solve_non_csp(self): 686 | def generate_solutions( 687 | csp_vars: list[CSPVar], 688 | csp_var_insts: dict[CSPVar, set[ObjInstance]], 689 | ) -> Generator[Solution, None, None]: 690 | csp_vars = list(set(csp_vars)) 691 | total_sols = 0 692 | 693 | solution_dict: Solution = {} 694 | used_instances: set[ObjInstance] = set() 695 | 696 | def gen_func(i_var: int) -> Generator[Solution, None, None]: 697 | nonlocal solution_dict 698 | nonlocal total_sols 699 | if total_sols >= 1000: 700 | return 701 | 702 | if i_var >= len(csp_vars): 703 | total_sols += 1 704 | yield solution_dict.copy() 705 | return 706 | 707 | cur_var = csp_vars[i_var] 708 | for inst in csp_var_insts[cur_var]: 709 | if inst in used_instances: 710 | continue 711 | 712 | solution_dict[cur_var] = inst 713 | 714 | used_instances.add(inst) 715 | yield from gen_func(i_var + 1) 716 | used_instances.remove(inst) 717 | 718 | yield from gen_func(0) 719 | 720 | csp_var_dep_cons: dict[CSPVar, set[CSPConstraint]] = { 721 | csp_var: { 722 | con 723 | for con in self.normal_constraints 724 | if con.get_target_var() == csp_var 725 | } 726 | for csp_var in self.variables 727 | } 728 | 729 | csp_var_minmax_cons: dict[CSPVar, set[CSPConstraint]] = { 730 | csp_var: { 731 | con 732 | for con in self.minmax_constraints 733 | if con.get_target_var() == csp_var 734 | } 735 | for csp_var in self.variables 736 | } 737 | 738 | 
csp_var_insts: dict[CSPVar, set[ObjInstance]] = { 739 | csp_var: set(GlobalState.get_cand_insts(csp_var.labels)) 740 | for csp_var in self.variables 741 | } 742 | 743 | constraints: set[CSPConstraint] = self.normal_constraints.copy() 744 | 745 | for i in range(100): 746 | if not constraints: 747 | break 748 | 749 | con_processed = False 750 | 751 | processed_cons: set[CSPConstraint] = set() 752 | for con in constraints: 753 | con_target_var = con.get_target_var() 754 | con_anchor_vars = con.get_vars() - {con_target_var} 755 | 756 | if all( 757 | len(csp_var_dep_cons[csp_var]) == 0 for csp_var in con_anchor_vars 758 | ): 759 | con_processed = True 760 | target_var_insts: set[ObjInstance] = set() 761 | 762 | for sol_dict in generate_solutions(con.get_vars(), csp_var_insts): 763 | assert con_target_var in sol_dict 764 | assert all(csp_var in sol_dict for csp_var in con_anchor_vars) 765 | if con.check_solution(sol_dict): 766 | target_var_insts.add(sol_dict[con_target_var]) 767 | 768 | csp_var_insts[con_target_var] = target_var_insts.copy() 769 | csp_var_dep_cons[con_target_var].remove(con) 770 | processed_cons.add(con) 771 | 772 | constraints -= processed_cons 773 | 774 | processed_vars: set[CSPVar] = set() 775 | for csp_var, minmax_cons in csp_var_minmax_cons.items(): 776 | if not minmax_cons or csp_var_dep_cons[csp_var]: 777 | continue 778 | 779 | # assert len(minmax_cons) == 1 780 | minmax_con = next(iter(minmax_cons)) 781 | 782 | con_target_var = minmax_con.get_target_var() 783 | con_anchor_vars = minmax_con.get_vars() - {con_target_var} 784 | 785 | if con_anchor_vars and any( 786 | csp_var_dep_cons[anchor_var] for anchor_var in con_anchor_vars 787 | ): 788 | continue 789 | 790 | con_processed = True 791 | solutions = list( 792 | generate_solutions(minmax_con.get_vars(), csp_var_insts) 793 | ) 794 | solutions: list[Solution] = minmax_con.filter_solutions(solutions) 795 | csp_var_insts[con_target_var] = { 796 | sol[con_target_var] for sol in solutions 797 | } 798 | 
799 | processed_vars.add(csp_var) 800 | 801 | for csp_var in processed_vars: 802 | del csp_var_minmax_cons[csp_var] 803 | 804 | if not con_processed: 805 | return 806 | 807 | target_var_set = set(var for var in self.variables if var.is_target) 808 | assert len(target_var_set) == 1 809 | target_var = next(iter(target_var_set)) 810 | 811 | # if the target variable has only a single candidate, skip the searching 812 | target_var_cand_insts = csp_var_insts[target_var] 813 | if target_var_cand_insts: 814 | set_target_info( 815 | best_instance=next(iter(target_var_cand_insts)), 816 | candidate_instances=target_var_cand_insts, 817 | anchor_instances={}, 818 | csp_desc=self.get_desc_str(), 819 | llm_used=False, 820 | ) 821 | 822 | def get_desc_str(self) -> str: 823 | desc = [] 824 | 825 | for csp_var in self.variables: 826 | desc.append(csp_var.get_identifier()) 827 | if csp_var.is_target: 828 | desc[-1] += " (target)" 829 | if csp_var.negative: 830 | desc[-1] += " (negative)" 831 | 832 | for con in self.normal_constraints: 833 | desc.append(f"{con.apparent_name} {con.get_desc_str()}") 834 | 835 | for con in self.minmax_constraints: 836 | desc.append(f"{con.apparent_name} {con.get_desc_str()}") 837 | 838 | return "\n".join(desc) 839 | -------------------------------------------------------------------------------- /data/counting_queries.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "scene_id": "scene0598_00", 4 | "object_id": "3", 5 | "object_name": "monitor", 6 | "description": "the second monitor from the left." 7 | }, 8 | { 9 | "scene_id": "scene0616_00", 10 | "object_id": "12", 11 | "object_name": "chair", 12 | "description": "the third chair away from the table with a lamp on it." 13 | }, 14 | { 15 | "scene_id": "scene0278_00", 16 | "object_id": "11", 17 | "object_name": "chair", 18 | "description": "the second chair from the copier." 
19 | } 20 | ] -------------------------------------------------------------------------------- /data/negative_queries.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "scene_id": "scene0217_00", 4 | "object_id": "5", 5 | "object_name": "desk", 6 | "description": "the desk without any backpack on it." 7 | }, 8 | { 9 | "scene_id": "scene0217_00", 10 | "object_id": "7", 11 | "object_name": "bed", 12 | "description": "the bed not beside the door." 13 | }, 14 | { 15 | "scene_id": "scene0426_00", 16 | "object_id": "5", 17 | "object_name": "jacket", 18 | "description": "the jacket not on the chair." 19 | }, 20 | { 21 | "scene_id": "scene0565_00", 22 | "object_id": "12", 23 | "object_name": "desk", 24 | "description": "the desk without printers on it." 25 | }, 26 | { 27 | "scene_id": "scene0598_00", 28 | "object_id": "13", 29 | "object_name": "table", 30 | "description": "the table without monitors on it." 31 | } 32 | ] -------------------------------------------------------------------------------- /data/scannet200_constants.py: -------------------------------------------------------------------------------- 1 | ### ScanNet Benchmark constants ### 2 | VALID_CLASS_IDS_20 = ( 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 9, 12 | 10, 13 | 11, 14 | 12, 15 | 14, 16 | 16, 17 | 24, 18 | 28, 19 | 33, 20 | 34, 21 | 36, 22 | 39, 23 | ) 24 | 25 | CLASS_LABELS_20 = ( 26 | "wall", 27 | "floor", 28 | "cabinet", 29 | "bed", 30 | "chair", 31 | "sofa", 32 | "table", 33 | "door", 34 | "window", 35 | "bookshelf", 36 | "picture", 37 | "counter", 38 | "desk", 39 | "curtain", 40 | "refrigerator", 41 | "shower curtain", 42 | "toilet", 43 | "sink", 44 | "bathtub", 45 | "otherfurniture", 46 | ) 47 | 48 | SCANNET_COLOR_MAP_20 = { 49 | 0: (0.0, 0.0, 0.0), 50 | 1: (174.0, 199.0, 232.0), 51 | 2: (152.0, 223.0, 138.0), 52 | 3: (31.0, 119.0, 180.0), 53 | 4: (255.0, 187.0, 120.0), 54 | 5: (188.0, 189.0, 34.0), 55 | 6: (140.0, 86.0, 
75.0), 56 | 7: (255.0, 152.0, 150.0), 57 | 8: (214.0, 39.0, 40.0), 58 | 9: (197.0, 176.0, 213.0), 59 | 10: (148.0, 103.0, 189.0), 60 | 11: (196.0, 156.0, 148.0), 61 | 12: (23.0, 190.0, 207.0), 62 | 14: (247.0, 182.0, 210.0), 63 | 15: (66.0, 188.0, 102.0), 64 | 16: (219.0, 219.0, 141.0), 65 | 17: (140.0, 57.0, 197.0), 66 | 18: (202.0, 185.0, 52.0), 67 | 19: (51.0, 176.0, 203.0), 68 | 20: (200.0, 54.0, 131.0), 69 | 21: (92.0, 193.0, 61.0), 70 | 22: (78.0, 71.0, 183.0), 71 | 23: (172.0, 114.0, 82.0), 72 | 24: (255.0, 127.0, 14.0), 73 | 25: (91.0, 163.0, 138.0), 74 | 26: (153.0, 98.0, 156.0), 75 | 27: (140.0, 153.0, 101.0), 76 | 28: (158.0, 218.0, 229.0), 77 | 29: (100.0, 125.0, 154.0), 78 | 30: (178.0, 127.0, 135.0), 79 | 32: (146.0, 111.0, 194.0), 80 | 33: (44.0, 160.0, 44.0), 81 | 34: (112.0, 128.0, 144.0), 82 | 35: (96.0, 207.0, 209.0), 83 | 36: (227.0, 119.0, 194.0), 84 | 37: (213.0, 92.0, 176.0), 85 | 38: (94.0, 106.0, 211.0), 86 | 39: (82.0, 84.0, 163.0), 87 | 40: (100.0, 85.0, 144.0), 88 | } 89 | 90 | ### ScanNet200 Benchmark constants ### 91 | VALID_CLASS_IDS_200 = ( 92 | 1, 93 | 2, 94 | 3, 95 | 4, 96 | 5, 97 | 6, 98 | 7, 99 | 8, 100 | 9, 101 | 10, 102 | 11, 103 | 13, 104 | 14, 105 | 15, 106 | 16, 107 | 17, 108 | 18, 109 | 19, 110 | 21, 111 | 22, 112 | 23, 113 | 24, 114 | 26, 115 | 27, 116 | 28, 117 | 29, 118 | 31, 119 | 32, 120 | 33, 121 | 34, 122 | 35, 123 | 36, 124 | 38, 125 | 39, 126 | 40, 127 | 41, 128 | 42, 129 | 44, 130 | 45, 131 | 46, 132 | 47, 133 | 48, 134 | 49, 135 | 50, 136 | 51, 137 | 52, 138 | 54, 139 | 55, 140 | 56, 141 | 57, 142 | 58, 143 | 59, 144 | 62, 145 | 63, 146 | 64, 147 | 65, 148 | 66, 149 | 67, 150 | 68, 151 | 69, 152 | 70, 153 | 71, 154 | 72, 155 | 73, 156 | 74, 157 | 75, 158 | 76, 159 | 77, 160 | 78, 161 | 79, 162 | 80, 163 | 82, 164 | 84, 165 | 86, 166 | 87, 167 | 88, 168 | 89, 169 | 90, 170 | 93, 171 | 95, 172 | 96, 173 | 97, 174 | 98, 175 | 99, 176 | 100, 177 | 101, 178 | 102, 179 | 103, 180 | 104, 181 | 105, 182 | 106, 183 | 107, 
184 | 110, 185 | 112, 186 | 115, 187 | 116, 188 | 118, 189 | 120, 190 | 121, 191 | 122, 192 | 125, 193 | 128, 194 | 130, 195 | 131, 196 | 132, 197 | 134, 198 | 136, 199 | 138, 200 | 139, 201 | 140, 202 | 141, 203 | 145, 204 | 148, 205 | 154, 206 | 155, 207 | 156, 208 | 157, 209 | 159, 210 | 161, 211 | 163, 212 | 165, 213 | 166, 214 | 168, 215 | 169, 216 | 170, 217 | 177, 218 | 180, 219 | 185, 220 | 188, 221 | 191, 222 | 193, 223 | 195, 224 | 202, 225 | 208, 226 | 213, 227 | 214, 228 | 221, 229 | 229, 230 | 230, 231 | 232, 232 | 233, 233 | 242, 234 | 250, 235 | 261, 236 | 264, 237 | 276, 238 | 283, 239 | 286, 240 | 300, 241 | 304, 242 | 312, 243 | 323, 244 | 325, 245 | 331, 246 | 342, 247 | 356, 248 | 370, 249 | 392, 250 | 395, 251 | 399, 252 | 408, 253 | 417, 254 | 488, 255 | 540, 256 | 562, 257 | 570, 258 | 572, 259 | 581, 260 | 609, 261 | 748, 262 | 776, 263 | 1156, 264 | 1163, 265 | 1164, 266 | 1165, 267 | 1166, 268 | 1167, 269 | 1168, 270 | 1169, 271 | 1170, 272 | 1171, 273 | 1172, 274 | 1173, 275 | 1174, 276 | 1175, 277 | 1176, 278 | 1178, 279 | 1179, 280 | 1180, 281 | 1181, 282 | 1182, 283 | 1183, 284 | 1184, 285 | 1185, 286 | 1186, 287 | 1187, 288 | 1188, 289 | 1189, 290 | 1190, 291 | 1191, 292 | ) 293 | 294 | CLASS_LABELS_200 = ( 295 | "wall", 296 | "chair", 297 | "floor", 298 | "table", 299 | "door", 300 | "couch", 301 | "cabinet", 302 | "shelf", 303 | "desk", 304 | "office chair", 305 | "bed", 306 | "pillow", 307 | "sink", 308 | "picture", 309 | "window", 310 | "toilet", 311 | "bookshelf", 312 | "monitor", 313 | "curtain", 314 | "book", 315 | "armchair", 316 | "coffee table", 317 | "box", 318 | "refrigerator", 319 | "lamp", 320 | "kitchen cabinet", 321 | "towel", 322 | "clothes", 323 | "tv", 324 | "nightstand", 325 | "counter", 326 | "dresser", 327 | "stool", 328 | "cushion", 329 | "plant", 330 | "ceiling", 331 | "bathtub", 332 | "end table", 333 | "dining table", 334 | "keyboard", 335 | "bag", 336 | "backpack", 337 | "toilet paper", 338 | "printer", 339 
| "tv stand", 340 | "whiteboard", 341 | "blanket", 342 | "shower curtain", 343 | "trash can", 344 | "closet", 345 | "stairs", 346 | "microwave", 347 | "stove", 348 | "shoe", 349 | "computer tower", 350 | "bottle", 351 | "bin", 352 | "ottoman", 353 | "bench", 354 | "board", 355 | "washing machine", 356 | "mirror", 357 | "copier", 358 | "basket", 359 | "sofa chair", 360 | "file cabinet", 361 | "fan", 362 | "laptop", 363 | "shower", 364 | "paper", 365 | "person", 366 | "paper towel dispenser", 367 | "oven", 368 | "blinds", 369 | "rack", 370 | "plate", 371 | "blackboard", 372 | "piano", 373 | "suitcase", 374 | "rail", 375 | "radiator", 376 | "recycling bin", 377 | "container", 378 | "wardrobe", 379 | "soap dispenser", 380 | "telephone", 381 | "bucket", 382 | "clock", 383 | "stand", 384 | "light", 385 | "laundry basket", 386 | "pipe", 387 | "clothes dryer", 388 | "guitar", 389 | "toilet paper holder", 390 | "seat", 391 | "speaker", 392 | "column", 393 | "bicycle", 394 | "ladder", 395 | "bathroom stall", 396 | "shower wall", 397 | "cup", 398 | "jacket", 399 | "storage bin", 400 | "coffee maker", 401 | "dishwasher", 402 | "paper towel roll", 403 | "machine", 404 | "mat", 405 | "windowsill", 406 | "bar", 407 | "toaster", 408 | "bulletin board", 409 | "ironing board", 410 | "fireplace", 411 | "soap dish", 412 | "kitchen counter", 413 | "doorframe", 414 | "toilet paper dispenser", 415 | "mini fridge", 416 | "fire extinguisher", 417 | "ball", 418 | "hat", 419 | "shower curtain rod", 420 | "water cooler", 421 | "paper cutter", 422 | "tray", 423 | "shower door", 424 | "pillar", 425 | "ledge", 426 | "toaster oven", 427 | "mouse", 428 | "toilet seat cover dispenser", 429 | "furniture", 430 | "cart", 431 | "storage container", 432 | "scale", 433 | "tissue box", 434 | "light switch", 435 | "crate", 436 | "power outlet", 437 | "decoration", 438 | "sign", 439 | "projector", 440 | "closet door", 441 | "vacuum cleaner", 442 | "candle", 443 | "plunger", 444 | "stuffed animal", 445 | 
"headphones", 446 | "dish rack", 447 | "broom", 448 | "guitar case", 449 | "range hood", 450 | "dustpan", 451 | "hair dryer", 452 | "water bottle", 453 | "handicap bar", 454 | "purse", 455 | "vent", 456 | "shower floor", 457 | "water pitcher", 458 | "mailbox", 459 | "bowl", 460 | "paper bag", 461 | "alarm clock", 462 | "music stand", 463 | "projector screen", 464 | "divider", 465 | "laundry detergent", 466 | "bathroom counter", 467 | "object", 468 | "bathroom vanity", 469 | "closet wall", 470 | "laundry hamper", 471 | "bathroom stall door", 472 | "ceiling light", 473 | "trash bin", 474 | "dumbbell", 475 | "stair rail", 476 | "tube", 477 | "bathroom cabinet", 478 | "cd case", 479 | "closet rod", 480 | "coffee kettle", 481 | "structure", 482 | "shower head", 483 | "keyboard piano", 484 | "case of water bottles", 485 | "coat rack", 486 | "storage organizer", 487 | "folded chair", 488 | "fire alarm", 489 | "power strip", 490 | "calendar", 491 | "poster", 492 | "potted plant", 493 | "luggage", 494 | "mattress", 495 | ) 496 | 497 | SCANNET_COLOR_MAP_200 = { 498 | 0: (0.0, 0.0, 0.0), 499 | 1: (174.0, 199.0, 232.0), 500 | 2: (188.0, 189.0, 34.0), 501 | 3: (152.0, 223.0, 138.0), 502 | 4: (255.0, 152.0, 150.0), 503 | 5: (214.0, 39.0, 40.0), 504 | 6: (91.0, 135.0, 229.0), 505 | 7: (31.0, 119.0, 180.0), 506 | 8: (229.0, 91.0, 104.0), 507 | 9: (247.0, 182.0, 210.0), 508 | 10: (91.0, 229.0, 110.0), 509 | 11: (255.0, 187.0, 120.0), 510 | 13: (141.0, 91.0, 229.0), 511 | 14: (112.0, 128.0, 144.0), 512 | 15: (196.0, 156.0, 148.0), 513 | 16: (197.0, 176.0, 213.0), 514 | 17: (44.0, 160.0, 44.0), 515 | 18: (148.0, 103.0, 189.0), 516 | 19: (229.0, 91.0, 223.0), 517 | 21: (219.0, 219.0, 141.0), 518 | 22: (192.0, 229.0, 91.0), 519 | 23: (88.0, 218.0, 137.0), 520 | 24: (58.0, 98.0, 137.0), 521 | 26: (177.0, 82.0, 239.0), 522 | 27: (255.0, 127.0, 14.0), 523 | 28: (237.0, 204.0, 37.0), 524 | 29: (41.0, 206.0, 32.0), 525 | 31: (62.0, 143.0, 148.0), 526 | 32: (34.0, 14.0, 130.0), 527 | 33: 
(143.0, 45.0, 115.0), 528 | 34: (137.0, 63.0, 14.0), 529 | 35: (23.0, 190.0, 207.0), 530 | 36: (16.0, 212.0, 139.0), 531 | 38: (90.0, 119.0, 201.0), 532 | 39: (125.0, 30.0, 141.0), 533 | 40: (150.0, 53.0, 56.0), 534 | 41: (186.0, 197.0, 62.0), 535 | 42: (227.0, 119.0, 194.0), 536 | 44: (38.0, 100.0, 128.0), 537 | 45: (120.0, 31.0, 243.0), 538 | 46: (154.0, 59.0, 103.0), 539 | 47: (169.0, 137.0, 78.0), 540 | 48: (143.0, 245.0, 111.0), 541 | 49: (37.0, 230.0, 205.0), 542 | 50: (14.0, 16.0, 155.0), 543 | 51: (196.0, 51.0, 182.0), 544 | 52: (237.0, 80.0, 38.0), 545 | 54: (138.0, 175.0, 62.0), 546 | 55: (158.0, 218.0, 229.0), 547 | 56: (38.0, 96.0, 167.0), 548 | 57: (190.0, 77.0, 246.0), 549 | 58: (208.0, 49.0, 84.0), 550 | 59: (208.0, 193.0, 72.0), 551 | 62: (55.0, 220.0, 57.0), 552 | 63: (10.0, 125.0, 140.0), 553 | 64: (76.0, 38.0, 202.0), 554 | 65: (191.0, 28.0, 135.0), 555 | 66: (211.0, 120.0, 42.0), 556 | 67: (118.0, 174.0, 76.0), 557 | 68: (17.0, 242.0, 171.0), 558 | 69: (20.0, 65.0, 247.0), 559 | 70: (208.0, 61.0, 222.0), 560 | 71: (162.0, 62.0, 60.0), 561 | 72: (210.0, 235.0, 62.0), 562 | 73: (45.0, 152.0, 72.0), 563 | 74: (35.0, 107.0, 149.0), 564 | 75: (160.0, 89.0, 237.0), 565 | 76: (227.0, 56.0, 125.0), 566 | 77: (169.0, 143.0, 81.0), 567 | 78: (42.0, 143.0, 20.0), 568 | 79: (25.0, 160.0, 151.0), 569 | 80: (82.0, 75.0, 227.0), 570 | 82: (253.0, 59.0, 222.0), 571 | 84: (240.0, 130.0, 89.0), 572 | 86: (123.0, 172.0, 47.0), 573 | 87: (71.0, 194.0, 133.0), 574 | 88: (24.0, 94.0, 205.0), 575 | 89: (134.0, 16.0, 179.0), 576 | 90: (159.0, 32.0, 52.0), 577 | 93: (213.0, 208.0, 88.0), 578 | 95: (64.0, 158.0, 70.0), 579 | 96: (18.0, 163.0, 194.0), 580 | 97: (65.0, 29.0, 153.0), 581 | 98: (177.0, 10.0, 109.0), 582 | 99: (152.0, 83.0, 7.0), 583 | 100: (83.0, 175.0, 30.0), 584 | 101: (18.0, 199.0, 153.0), 585 | 102: (61.0, 81.0, 208.0), 586 | 103: (213.0, 85.0, 216.0), 587 | 104: (170.0, 53.0, 42.0), 588 | 105: (161.0, 192.0, 38.0), 589 | 106: (23.0, 241.0, 91.0), 590 | 
107: (12.0, 103.0, 170.0), 591 | 110: (151.0, 41.0, 245.0), 592 | 112: (133.0, 51.0, 80.0), 593 | 115: (184.0, 162.0, 91.0), 594 | 116: (50.0, 138.0, 38.0), 595 | 118: (31.0, 237.0, 236.0), 596 | 120: (39.0, 19.0, 208.0), 597 | 121: (223.0, 27.0, 180.0), 598 | 122: (254.0, 141.0, 85.0), 599 | 125: (97.0, 144.0, 39.0), 600 | 128: (106.0, 231.0, 176.0), 601 | 130: (12.0, 61.0, 162.0), 602 | 131: (124.0, 66.0, 140.0), 603 | 132: (137.0, 66.0, 73.0), 604 | 134: (250.0, 253.0, 26.0), 605 | 136: (55.0, 191.0, 73.0), 606 | 138: (60.0, 126.0, 146.0), 607 | 139: (153.0, 108.0, 234.0), 608 | 140: (184.0, 58.0, 125.0), 609 | 141: (135.0, 84.0, 14.0), 610 | 145: (139.0, 248.0, 91.0), 611 | 148: (53.0, 200.0, 172.0), 612 | 154: (63.0, 69.0, 134.0), 613 | 155: (190.0, 75.0, 186.0), 614 | 156: (127.0, 63.0, 52.0), 615 | 157: (141.0, 182.0, 25.0), 616 | 159: (56.0, 144.0, 89.0), 617 | 161: (64.0, 160.0, 250.0), 618 | 163: (182.0, 86.0, 245.0), 619 | 165: (139.0, 18.0, 53.0), 620 | 166: (134.0, 120.0, 54.0), 621 | 168: (49.0, 165.0, 42.0), 622 | 169: (51.0, 128.0, 133.0), 623 | 170: (44.0, 21.0, 163.0), 624 | 177: (232.0, 93.0, 193.0), 625 | 180: (176.0, 102.0, 54.0), 626 | 185: (116.0, 217.0, 17.0), 627 | 188: (54.0, 209.0, 150.0), 628 | 191: (60.0, 99.0, 204.0), 629 | 193: (129.0, 43.0, 144.0), 630 | 195: (252.0, 100.0, 106.0), 631 | 202: (187.0, 196.0, 73.0), 632 | 208: (13.0, 158.0, 40.0), 633 | 213: (52.0, 122.0, 152.0), 634 | 214: (128.0, 76.0, 202.0), 635 | 221: (187.0, 50.0, 115.0), 636 | 229: (180.0, 141.0, 71.0), 637 | 230: (77.0, 208.0, 35.0), 638 | 232: (72.0, 183.0, 168.0), 639 | 233: (97.0, 99.0, 203.0), 640 | 242: (172.0, 22.0, 158.0), 641 | 250: (155.0, 64.0, 40.0), 642 | 261: (118.0, 159.0, 30.0), 643 | 264: (69.0, 252.0, 148.0), 644 | 276: (45.0, 103.0, 173.0), 645 | 283: (111.0, 38.0, 149.0), 646 | 286: (184.0, 9.0, 49.0), 647 | 300: (188.0, 174.0, 67.0), 648 | 304: (53.0, 206.0, 53.0), 649 | 312: (97.0, 235.0, 252.0), 650 | 323: (66.0, 32.0, 182.0), 651 | 325: 
(236.0, 114.0, 195.0), 652 | 331: (241.0, 154.0, 83.0), 653 | 342: (133.0, 240.0, 52.0), 654 | 356: (16.0, 205.0, 144.0), 655 | 370: (75.0, 101.0, 198.0), 656 | 392: (237.0, 95.0, 251.0), 657 | 395: (191.0, 52.0, 49.0), 658 | 399: (227.0, 254.0, 54.0), 659 | 408: (49.0, 206.0, 87.0), 660 | 417: (48.0, 113.0, 150.0), 661 | 488: (125.0, 73.0, 182.0), 662 | 540: (229.0, 32.0, 114.0), 663 | 562: (158.0, 119.0, 28.0), 664 | 570: (60.0, 205.0, 27.0), 665 | 572: (18.0, 215.0, 201.0), 666 | 581: (79.0, 76.0, 153.0), 667 | 609: (134.0, 13.0, 116.0), 668 | 748: (192.0, 97.0, 63.0), 669 | 776: (108.0, 163.0, 18.0), 670 | 1156: (95.0, 220.0, 156.0), 671 | 1163: (98.0, 141.0, 208.0), 672 | 1164: (144.0, 19.0, 193.0), 673 | 1165: (166.0, 36.0, 57.0), 674 | 1166: (212.0, 202.0, 34.0), 675 | 1167: (23.0, 206.0, 34.0), 676 | 1168: (91.0, 211.0, 236.0), 677 | 1169: (79.0, 55.0, 137.0), 678 | 1170: (182.0, 19.0, 117.0), 679 | 1171: (134.0, 76.0, 14.0), 680 | 1172: (87.0, 185.0, 28.0), 681 | 1173: (82.0, 224.0, 187.0), 682 | 1174: (92.0, 110.0, 214.0), 683 | 1175: (168.0, 80.0, 171.0), 684 | 1176: (197.0, 63.0, 51.0), 685 | 1178: (175.0, 199.0, 77.0), 686 | 1179: (62.0, 180.0, 98.0), 687 | 1180: (8.0, 91.0, 150.0), 688 | 1181: (77.0, 15.0, 130.0), 689 | 1182: (154.0, 65.0, 96.0), 690 | 1183: (197.0, 152.0, 11.0), 691 | 1184: (59.0, 155.0, 45.0), 692 | 1185: (12.0, 147.0, 145.0), 693 | 1186: (54.0, 35.0, 219.0), 694 | 1187: (210.0, 73.0, 181.0), 695 | 1188: (221.0, 124.0, 77.0), 696 | 1189: (149.0, 214.0, 66.0), 697 | 1190: (72.0, 185.0, 134.0), 698 | 1191: (42.0, 94.0, 198.0), 699 | } 700 | 701 | ### For instance segmentation the non-object categories ### 702 | VALID_PANOPTIC_IDS = (1, 3) 703 | 704 | CLASS_LABELS_PANOPTIC = ("wall", "floor") 705 | -------------------------------------------------------------------------------- /data/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | 
scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | 
scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | 
scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 -------------------------------------------------------------------------------- /misc_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
numpy.typing as npt 3 | 4 | 5 | def lookat_matrix( 6 | eye: npt.NDArray, target: npt.NDArray, up: npt.NDArray = [0, 0, 1] 7 | ) -> npt.NDArray: 8 | """ 9 | build a lookat matrix (4x4): global coordiantes -> camera/eye coordiantes 10 | (might still have bugs...) 11 | """ 12 | mat = np.zeros((4, 4)) 13 | eye = np.array(eye, dtype=np.float64) 14 | target = np.array(target, dtype=np.float64) 15 | up = np.array(up, dtype=np.float64) 16 | 17 | axis_z = target - eye 18 | axis_z /= np.linalg.norm(axis_z) 19 | 20 | up /= np.linalg.norm(up) 21 | axis_x = np.cross(up, axis_z) 22 | axis_x /= np.linalg.norm(axis_x) 23 | 24 | axis_y = np.cross(axis_z, axis_x) 25 | axis_y /= np.linalg.norm(axis_y) 26 | 27 | mat[0, :3] = axis_x 28 | mat[1, :3] = axis_y 29 | mat[2, :3] = axis_z 30 | mat[:3, 3] = -mat[:3, :3] @ eye 31 | mat[3, 3] = 1 32 | 33 | return mat 34 | 35 | 36 | def check(expr: bool, msg: str = "") -> None: 37 | if not (expr): 38 | raise SystemError(msg) 39 | 40 | 41 | def check_list_of_type(x: list, ty: type, msg: str = "") -> None: 42 | check(isinstance(x, list) and all(isinstance(i, ty) for i in x), msg) 43 | 44 | 45 | def check_set_of_type(x: set, ty: type, msg: str = "") -> None: 46 | check(isinstance(x, set) and all(isinstance(i, ty) for i in x), msg) 47 | 48 | 49 | def is_list_of_type(x: list, ty: type) -> bool: 50 | return isinstance(x, list) and all(isinstance(i, ty) for i in x) 51 | 52 | 53 | def is_set_of_type(x: set, ty: type) -> bool: 54 | return isinstance(x, set) and all(isinstance(i, ty) for i in x) 55 | -------------------------------------------------------------------------------- /program_evaluator.py: -------------------------------------------------------------------------------- 1 | # run the generated programs and evalute the grounding results 2 | 3 | import argparse 4 | import functools 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import traceback 10 | from dataclasses import dataclass 11 | from multiprocessing import Pool 12 | 
13 | from tqdm import tqdm 14 | from tqdm.contrib.concurrent import process_map 15 | from typing_extensions import Self 16 | 17 | from program_functions_csp import reset_csp_solver, run_csp_solver 18 | from scannet_utils import ObjInstance, ScanNetScene 19 | from scope_env import ( 20 | TargetInfo, 21 | get_eval_scope, 22 | set_instance_map, 23 | set_room_center, 24 | ) 25 | 26 | 27 | @dataclass 28 | class EvalSingleResultScanRefer: 29 | acc05: bool = False 30 | acc05_potential: bool = False 31 | acc025: bool = False 32 | acc025_potential: bool = False 33 | is_unique: bool = True 34 | eval_result: dict = None 35 | error: bool = False 36 | llm_used: bool = False 37 | 38 | 39 | @dataclass 40 | class EvalSingleResultNr3D: 41 | eval_result: dict = None 42 | is_hard: bool = False 43 | is_view_dependent: bool = False 44 | success: bool = False 45 | could_success: bool = False 46 | error: bool = False 47 | llm_used: bool = False 48 | 49 | 50 | def eval_single_scanrefer( 51 | query_info: tuple[int, dict], 52 | mask3d_pred_path: str | None = None, 53 | maskcluster_pred_path: str | None = None, 54 | cache_root: str | None = None, 55 | solver_type: str = "default", 56 | select_solution: str = "min_dist", 57 | verbose: int = 0, 58 | ) -> EvalSingleResultScanRefer: 59 | query_id, query = query_info 60 | result = EvalSingleResultScanRefer() 61 | result.eval_result = query.copy() 62 | result.eval_result["query_id"] = query_id 63 | result.eval_result["acc05"] = False 64 | result.eval_result["acc025"] = False 65 | 66 | # todo: 67 | # loading the scene for every query is slow! 68 | # we should load all relevant scens at the beginning. 
69 | 70 | scene_id = query["scene_id"] 71 | scene = ScanNetScene( 72 | f"./data/scans/{scene_id}", 73 | mask3d_pred_path=mask3d_pred_path, 74 | maskcluster_pred_path=maskcluster_pred_path, 75 | add_room_center=True, 76 | add_room_corners=True, 77 | cache_root=cache_root, 78 | ) 79 | # scene.visualize() 80 | 81 | # used to get groundtruth bounding boxes 82 | gt_scene = ScanNetScene( 83 | f"./data/scans/{scene_id}", 84 | mask3d_pred_path=None, 85 | maskcluster_pred_path=None, 86 | add_room_center=False, 87 | add_room_corners=False, 88 | cache_root=cache_root, 89 | ) 90 | 91 | # check if the query is "unique" or "multiple" 92 | result.is_unique = gt_scene.is_unique_label( 93 | query["target_label"].replace("_", " ").lower().strip() 94 | ) 95 | 96 | if verbose >= 2: 97 | print(f"==================> vvv Query {query_id} vvv <==================") 98 | print(query["text"]) 99 | print(query["scene_id"]) 100 | print(query["target_id"]) 101 | print(query["target_label"]) 102 | print(query["program"]) 103 | print() 104 | 105 | use_min_dist_heuristic = True 106 | query_text = query["text"].lower() 107 | if any( 108 | x in query_text 109 | for x in {"far", "across", "opposite", "away", "remote", "distant"} 110 | ): 111 | use_min_dist_heuristic = False 112 | 113 | try: 114 | # this is such a mess... 
it would be better to implement an interpreter manually, 115 | # which will provide better error messages and debugging support 116 | 117 | set_instance_map(scene.get_instance_map()) 118 | set_room_center(scene.get_room_center()) 119 | reset_csp_solver() 120 | TargetInfo.reset() 121 | 122 | exec(query["program"], get_eval_scope(use_type_check_funcs=False)) 123 | run_csp_solver( 124 | query=query["text"], 125 | solver_type=solver_type, 126 | select_solution=select_solution, 127 | ) 128 | 129 | # check acc@0.25/0.5 130 | gt_bbox = gt_scene.get_instance_bbox(str(query["target_id"])) 131 | if TargetInfo.best_instance is not None: 132 | iou_best = TargetInfo.best_instance.bbox.iou(gt_bbox) 133 | iou_others = [ 134 | inst.bbox.iou(gt_bbox) for inst in TargetInfo.candidate_instances 135 | ] 136 | 137 | result.eval_result["predicted_bbox"] = { 138 | "pmin": [float(x) for x in list(TargetInfo.best_instance.bbox.pmin)], 139 | "pmax": [float(x) for x in list(TargetInfo.best_instance.bbox.pmax)], 140 | } 141 | 142 | result.eval_result["anchor_bboxes"] = { 143 | name: { 144 | "pmin": [float(x) for x in list(inst.bbox.pmin)], 145 | "pmax": [float(x) for x in list(inst.bbox.pmax)], 146 | } 147 | if isinstance(inst, ObjInstance) 148 | else [ 149 | { 150 | "pmin": [float(x) for x in list(y.bbox.pmin)], 151 | "pmax": [float(x) for x in list(y.bbox.pmax)], 152 | } 153 | for y in inst 154 | ] 155 | for name, inst in TargetInfo.anchor_instances.items() 156 | } 157 | 158 | result.eval_result["csp_desc"] = TargetInfo.csp_desc 159 | result.llm_used = TargetInfo.llm_used 160 | 161 | if iou_best >= 0.5: 162 | result.acc05 = True 163 | result.eval_result["acc05"] = True 164 | if iou_best >= 0.25: 165 | result.acc025 = True 166 | result.eval_result["acc025"] = True 167 | 168 | if (result.acc05 or result.acc025) and verbose >= 2: 169 | print() 170 | print("*** SUCCESSFUL ***") 171 | print() 172 | 173 | if any(x >= 0.5 for x in iou_others): 174 | result.acc05_potential = True 175 | # if 
result.acc05 is False: 176 | # result.eval_result["far"] = True 177 | if any(x >= 0.25 for x in iou_others): 178 | result.acc025_potential = True 179 | 180 | if (result.acc05_potential or result.acc025_potential) and verbose >= 2: 181 | print() 182 | print("*** COULD BE SUCCESSFUL ***") 183 | print() 184 | 185 | if verbose >= 2: 186 | # print(TargetInfo.label) 187 | # for inst in TargetInfo.instances: 188 | # print(inst.inst_id, inst.label) 189 | print(f"==================> ^^^ Query {query_id} ^^^ <==================") 190 | print() 191 | print() 192 | print() 193 | 194 | except Exception as e: 195 | # raise e 196 | if verbose >= 2: 197 | print() 198 | print("============> Query Failed <============") 199 | print(query["program"]) 200 | print("========================================") 201 | print() 202 | 203 | if verbose >= 1: 204 | print() 205 | print(f"QUERY {query_id} THROWS EXCEPTION!") 206 | print() 207 | print(os.linesep.join([s for s in query["program"].splitlines() if s])) 208 | print() 209 | traceback.print_exc(file=sys.stdout) 210 | print() 211 | 212 | result.acc025 = False 213 | result.acc025_potential = False 214 | result.acc05 = False 215 | result.acc05_potential = False 216 | 217 | result.eval_result = query.copy() 218 | result.eval_result["acc05"] = False 219 | result.eval_result["acc025"] = False 220 | 221 | result.error = True 222 | 223 | return result 224 | 225 | 226 | def eval_single_nr3d( 227 | query_info: tuple[int, dict], 228 | mask3d_pred_path: str | None = None, 229 | maskcluster_pred_path: str | None = None, 230 | cache_root: str | None = None, 231 | solver_type: str = "default", 232 | select_solution: str = "min_dist", 233 | verbose: int = 0, 234 | ) -> EvalSingleResultNr3D: 235 | query_id, query = query_info 236 | result = EvalSingleResultNr3D() 237 | result.eval_result = query.copy() 238 | result.eval_result["success"] = False 239 | 240 | assert "view_dependent" in query 241 | assert "hard" in query 242 | result.is_hard = query["hard"] 
243 | result.is_view_dependent = query["view_dependent"] 244 | 245 | scene_id = query["scene_id"] 246 | scene = ScanNetScene( 247 | f"./data/scans/{scene_id}", 248 | mask3d_pred_path=mask3d_pred_path, 249 | maskcluster_pred_path=maskcluster_pred_path, 250 | add_room_center=True, 251 | add_room_corners=True, 252 | cache_root=cache_root, 253 | ) 254 | # scene.visualize() 255 | 256 | # used to get groundtruth bounding boxes 257 | gt_scene = ScanNetScene( 258 | f"./data/scans/{scene_id}", 259 | mask3d_pred_path=None, 260 | maskcluster_pred_path=None, 261 | add_room_center=False, 262 | add_room_corners=False, 263 | cache_root=cache_root, 264 | ) 265 | 266 | if verbose >= 2: 267 | print(f"==================> vvv Query {query_id} vvv <==================") 268 | print(query["text"]) 269 | print(query["scene_id"]) 270 | print(query["target_id"]) 271 | print(query["target_label"]) 272 | print(query["program"]) 273 | print() 274 | 275 | use_min_dist_heuristic = True 276 | query_text = query["text"].lower() 277 | if any( 278 | x in query_text 279 | for x in {"far", "across", "opposite", "away", "remote", "distant"} 280 | ): 281 | use_min_dist_heuristic = False 282 | 283 | try: 284 | # this is such a mess... 
it would be better to implement an interpreter manually, 285 | # which will provide better error messages and debugging support 286 | 287 | set_instance_map(scene.get_instance_map()) 288 | set_room_center(scene.get_room_center()) 289 | reset_csp_solver() 290 | TargetInfo.reset() 291 | 292 | exec(query["program"], get_eval_scope(use_type_check_funcs=False)) 293 | run_csp_solver( 294 | query=query["text"], 295 | solver_type=solver_type, 296 | select_solution=select_solution, 297 | ) 298 | 299 | except Exception as e: 300 | # raise e 301 | if verbose >= 2: 302 | print() 303 | print("============> Query Failed <============") 304 | print(query["program"]) 305 | print("========================================") 306 | print() 307 | 308 | if verbose >= 1: 309 | print() 310 | print(f"QUERY {query_id} THROWS EXCEPTION!") 311 | print() 312 | print( 313 | os.linesep.join( 314 | [s for s in query["program"].splitlines() if s.rstrip()] 315 | ) 316 | ) 317 | print() 318 | traceback.print_exc(file=sys.stdout) 319 | print() 320 | 321 | result.error = True 322 | 323 | if TargetInfo.best_instance is not None: 324 | if str(TargetInfo.best_instance.inst_id) == str(query["target_id"]): 325 | result.success = True 326 | result.eval_result["success"] = True 327 | if verbose >= 2: 328 | print() 329 | print("*** SUCCESSFUL ***") 330 | print() 331 | 332 | if str(TargetInfo.best_instance.inst_id) in set( 333 | str(inst.inst_id) for inst in TargetInfo.candidate_instances 334 | ): 335 | result.could_success = True 336 | if verbose >= 2: 337 | print() 338 | print("*** COULD BE SUCCESSFUL ***") 339 | print() 340 | 341 | result.eval_result["predicted_bbox"] = { 342 | "pmin": [float(x) for x in list(TargetInfo.best_instance.bbox.pmin)], 343 | "pmax": [float(x) for x in list(TargetInfo.best_instance.bbox.pmax)], 344 | } 345 | 346 | result.eval_result["anchor_bboxes"] = { 347 | name: { 348 | "pmin": [float(x) for x in list(inst.bbox.pmin)], 349 | "pmax": [float(x) for x in list(inst.bbox.pmax)], 350 
| } 351 | if isinstance(inst, ObjInstance) 352 | else [ 353 | { 354 | "pmin": [float(x) for x in list(y.bbox.pmin)], 355 | "pmax": [float(x) for x in list(y.bbox.pmax)], 356 | } 357 | for y in inst 358 | ] 359 | for name, inst in TargetInfo.anchor_instances.items() 360 | } 361 | 362 | result.eval_result["csp_desc"] = TargetInfo.csp_desc 363 | result.llm_used = TargetInfo.llm_used 364 | 365 | if verbose >= 2: 366 | # print(TargetInfo.label) 367 | # for inst in TargetInfo.instances: 368 | # print(inst.inst_id, inst.label) 369 | print(f"==================> ^^^ Query {query_id} ^^^ <==================") 370 | print() 371 | print() 372 | print() 373 | 374 | assert isinstance(result, EvalSingleResultNr3D) 375 | return result 376 | 377 | 378 | def main(): 379 | # register predefined functions 380 | # import program_functions as _ 381 | import program_functions_csp as _ 382 | 383 | parser = argparse.ArgumentParser() 384 | parser.add_argument( 385 | "--solver", type=str, choices=["default", "non_csp"], default="default" 386 | ) 387 | parser.add_argument( 388 | "--select-solution", 389 | type=str, 390 | choices=["min_dist", "max_dist", "random", "first"], 391 | default="min_dist", 392 | ) 393 | parser.add_argument("--dataset", type=str, required=True) 394 | parser.add_argument("--experiment-name", type=str) 395 | parser.add_argument("--seg", type=str, choices=["gt", "mask3d"], required=True) 396 | parser.add_argument("--num-threads", type=int, default=30) 397 | parser.add_argument("--num-queries", type=int) 398 | parser.add_argument("--verbose", type=int, default=0) 399 | parser.add_argument("--print-if-succeed", action="store_true") 400 | args = parser.parse_args() 401 | 402 | tags = [] 403 | 404 | if args.seg == "gt": 405 | print("using gt segmentation") 406 | tags.append("gt") 407 | 408 | mask3d_pred = None 409 | maskcluster_pred = None 410 | 411 | if args.seg == "mask3d": 412 | print("using mask3d segmentation") 413 | mask3d_pred = "./data/eval_output/mask3d_val" 414 | 
tags.append("mask3d") 415 | 416 | if args.seg == "maskcluster": 417 | print("using maskclustering segmentation") 418 | maskcluster_pred = "./data/eval_output/maskcluster" 419 | tags.append("maskcluster") 420 | raise NotImplementedError() 421 | 422 | assert len(tags) == 1 423 | 424 | if args.dataset == "scanrefer": 425 | print("using scanrefer dataset") 426 | tags.append("scanrefer") 427 | 428 | if args.dataset == "nr3d": 429 | print("using nr3d dataset") 430 | tags.append("nr3d") 431 | 432 | if args.dataset == "custom": 433 | print("using custom dataset") 434 | tags.append("custom") 435 | 436 | if len(tags) < 2: 437 | assert args.dataset 438 | print(f"using custom dataset: {args.dataset}") 439 | tags.append(args.dataset) 440 | 441 | assert len(tags) == 2 442 | 443 | if args.experiment_name: 444 | print(f"using experiment name: {args.experiment_name}") 445 | tags.append(args.experiment_name) 446 | 447 | cache_root = "./data/instance_cache" 448 | 449 | eval_file_path = f"./output/eval_data_{'_'.join(tags)}.json" 450 | eval_results_file_path = f"./output/eval_results_{'_'.join(tags)}.json" 451 | 452 | print(f"loading eval file: [{eval_file_path}].") 453 | print(f"will write results to file: [{eval_results_file_path}].") 454 | 455 | if not os.path.isfile(eval_file_path): 456 | print("eval file not found.") 457 | return 458 | 459 | print(f"using csp solver: {args.solver}.") 460 | 461 | with open(eval_file_path) as f: 462 | eval_data = json.load(f) 463 | 464 | if args.num_queries is not None: 465 | random.seed() 466 | random.shuffle(eval_data) 467 | eval_data = eval_data[0 : args.num_queries] 468 | 469 | num_threads = args.num_threads 470 | num_errors = 0 471 | num_llm_used = 0 472 | num_llm_correct = 0 473 | verbose = args.verbose 474 | 475 | print() 476 | 477 | if args.dataset != "nr3d": 478 | # accumulate some statistics 479 | num_acc05 = 0 480 | num_acc025 = 0 481 | num_acc05_potential = 0 482 | num_acc025_potential = 0 483 | num_evals = 0 484 | 485 | 
num_acc05_unique = 0 486 | num_acc025_unique = 0 487 | num_acc05_potential_unique = 0 488 | num_acc025_potential_unique = 0 489 | num_evals_unique = 0 490 | 491 | num_acc05_multiple = 0 492 | num_acc025_multiple = 0 493 | num_acc05_potential_multiple = 0 494 | num_acc025_potential_multiple = 0 495 | num_evals_multiple = 0 496 | 497 | good_query_indices = [] 498 | 499 | # start evaluation 500 | eval_results = [] 501 | eval_data = eval_data[:] 502 | 503 | eval_single_func = functools.partial( 504 | eval_single_scanrefer, 505 | mask3d_pred_path=mask3d_pred, 506 | maskcluster_pred_path=maskcluster_pred, 507 | cache_root=cache_root, 508 | solver_type=args.solver, 509 | select_solution=args.select_solution, 510 | verbose=verbose, 511 | ) 512 | 513 | with Pool(num_threads) as pool: 514 | for result in tqdm( 515 | pool.imap_unordered( 516 | eval_single_func, 517 | zip(range(len(eval_data)), eval_data), 518 | ), 519 | total=len(eval_data), 520 | ): 521 | if args.print_if_succeed: 522 | print() 523 | print(f"query: {result.eval_result['text']}") 524 | print(f"program:\n{result.eval_result['program']}") 525 | print("success" if result.acc05 else "failure") 526 | print() 527 | 528 | num_errors += int(result.error) 529 | if result.llm_used: 530 | num_llm_used += 1 531 | num_llm_correct += int(result.acc05) 532 | 533 | num_acc05 += int(result.acc05) 534 | num_acc025 += int(result.acc025) 535 | num_acc05_potential += int(result.acc05_potential) 536 | num_acc025_potential += int(result.acc025_potential) 537 | num_evals += 1 538 | 539 | if result.is_unique: 540 | num_acc05_unique += int(result.acc05) 541 | num_acc025_unique += int(result.acc025) 542 | num_acc05_potential_unique += int(result.acc05_potential) 543 | num_acc025_potential_unique += int(result.acc025_potential) 544 | num_evals_unique += 1 545 | else: 546 | num_acc05_multiple += int(result.acc05) 547 | num_acc025_multiple += int(result.acc025) 548 | num_acc05_potential_multiple += int(result.acc05_potential) 549 | 
num_acc025_potential_multiple += int(result.acc025_potential) 550 | num_evals_multiple += 1 551 | 552 | assert result.eval_result 553 | eval_results.append(result.eval_result) 554 | 555 | if result.acc05: 556 | good_query_indices.append(result.eval_result["query_id"]) 557 | 558 | print() 559 | print("========================= evalutation results =========================") 560 | print() 561 | 562 | # fmt: off 563 | print(f"acc@0.5: {num_acc05 / num_evals:.4f} ({num_acc05} / {num_evals})") 564 | print(f"acc@0.25: {num_acc025 / num_evals:.4f} ({num_acc025} / {num_evals})") 565 | print(f"acc@0.5 (?): {num_acc05_potential / num_evals:.4f} ({num_acc05_potential} / {num_evals})") 566 | print(f"acc@0.25 (?): {num_acc025_potential / num_evals:.4f} ({num_acc025_potential} / {num_evals})") 567 | 568 | if num_evals_unique > 0: 569 | print(f"acc@0.5 (u): {num_acc05_unique / num_evals_unique:.4f} ({num_acc05_unique} / {num_evals_unique})") 570 | print(f"acc@0.25 (u): {num_acc025_unique / num_evals_unique:.4f} ({num_acc025_unique} / {num_evals_unique})") 571 | 572 | if num_evals_multiple > 0: 573 | print(f"acc@0.5 (m): {num_acc05_multiple / num_evals_multiple:.4f} ({num_acc05_multiple} / {num_evals_multiple})") 574 | print(f"acc@0.25 (m): {num_acc025_multiple / num_evals_multiple:.4f} ({num_acc025_multiple} / {num_evals_multiple})") 575 | 576 | print(f"errors: {num_errors}") 577 | print(f"llm: {num_llm_correct} / {num_llm_used}") 578 | # fmt: on 579 | 580 | # print(sorted(good_query_indices)) 581 | 582 | print() 583 | print("========================= evalutation results =========================") 584 | print() 585 | 586 | if args.dataset == "nr3d": 587 | num_overall = 0 588 | num_overall_success = 0 589 | 590 | num_easy = 0 591 | num_easy_success = 0 592 | 593 | num_hard = 0 594 | num_hard_success = 0 595 | 596 | num_view_dep = 0 597 | num_view_dep_success = 0 598 | 599 | num_view_indep = 0 600 | num_view_indep_success = 0 601 | 602 | # start evaluation 603 | eval_results = [] 
604 | eval_data = eval_data[:] 605 | 606 | eval_single_func = functools.partial( 607 | eval_single_nr3d, 608 | mask3d_pred_path=mask3d_pred, 609 | maskcluster_pred_path=maskcluster_pred, 610 | cache_root=cache_root, 611 | solver_type=args.solver, 612 | select_solution=args.select_solution, 613 | verbose=verbose, 614 | ) 615 | 616 | with Pool(num_threads) as pool: 617 | for result in tqdm( 618 | pool.imap_unordered( 619 | eval_single_func, 620 | zip(range(len(eval_data)), eval_data), 621 | ), 622 | total=len(eval_data), 623 | ): 624 | if args.print_if_succeed: 625 | print() 626 | print(f"query: {result.eval_result['text']}") 627 | print("success" if result.acc05 else "failure") 628 | print() 629 | 630 | num_errors += int(result.error) 631 | if result.llm_used: 632 | num_llm_used += 1 633 | num_llm_correct += int(result.acc05) 634 | 635 | num_overall += 1 636 | num_overall_success += int(result.success) 637 | 638 | if result.is_hard: 639 | num_hard += 1 640 | num_hard_success += int(result.success) 641 | else: 642 | num_easy += 1 643 | num_easy_success += int(result.success) 644 | 645 | if result.is_view_dependent: 646 | num_view_dep += 1 647 | num_view_dep_success += int(result.success) 648 | else: 649 | num_view_indep += 1 650 | num_view_indep_success += int(result.success) 651 | 652 | assert result.eval_result 653 | eval_results.append(result.eval_result) 654 | 655 | print() 656 | print("========================= evalutation results =========================") 657 | print() 658 | 659 | # fmt: off 660 | print(f"overall: {num_overall_success / num_overall:.4f} ({num_overall_success} / {num_overall})") 661 | print(f"easy: {num_easy_success / num_easy:.4f} ({num_easy_success} / {num_easy})") 662 | print(f"hard: {num_hard_success / num_hard:.4f} ({num_hard_success} / {num_hard})") 663 | print(f"view dep: {num_view_dep_success / num_view_dep:.4f} ({num_view_dep_success} / {num_view_dep})") 664 | print(f"view indep: {num_view_indep_success / num_view_indep:.4f} 
({num_view_indep_success} / {num_view_indep})") 665 | print(f"errors: {num_errors}") 666 | print(f"llm: {num_llm_correct} / {num_llm_used}") 667 | # fmt: on 668 | 669 | print() 670 | print("========================= evalutation results =========================") 671 | print() 672 | 673 | with open(eval_results_file_path, "w") as f: 674 | json.dump(eval_results, f) 675 | 676 | 677 | if __name__ == "__main__": 678 | # random.seed(0) 679 | main() 680 | -------------------------------------------------------------------------------- /program_validator.py: -------------------------------------------------------------------------------- 1 | import re 2 | import textwrap 3 | 4 | 5 | def validate_program(prog: str): 6 | """ 7 | validate the input program (@prog). 8 | if the program is valid or can be corrected, return a clean and corrected version of it. 9 | if the program contains fatal errors, return all errors. 10 | """ 11 | # extract python code if the llm output is in markdown format (which happens from time to time...) 12 | if "```" in prog: 13 | if prog.count("```") != 2: 14 | print() 15 | print("=>> ↓↓↓ PROGRAM VALIDATION FAILED! ↓↓↓ <<=") 16 | print() 17 | print(prog) 18 | print() 19 | print("=>> ↑↑↑ PROGRAM VALIDATION FAILED! ↑↑↑ <<=") 20 | print() 21 | return textwrap.dedent(prog).strip() 22 | 23 | is_code = False 24 | code_lines = [] 25 | 26 | for line in prog.split("\n"): 27 | match = re.match( 28 | r"(?:```\s*python|```\s*Python|```\s*PYTHON|```)(.*?)$", 29 | line.strip(), 30 | ) 31 | if match: 32 | is_code = not is_code 33 | continue # simply skip this line 34 | 35 | if is_code: 36 | # assuming no indent in the code 37 | code_lines.append(line.strip()) 38 | 39 | prog = "\n".join(code_lines) 40 | 41 | prog = textwrap.dedent(prog).strip() 42 | 43 | # print("cleaned program:") 44 | # print() 45 | # print(prog) 46 | # print() 47 | 48 | # maybe we don't need a validator for now... 
49 | 50 | return prog 51 | -------------------------------------------------------------------------------- /prompts/filter_relevant_objects.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of available objects that actually exist in the 3D scene. Your task is to filter the object list and find only those objects relevant to the input query. 4 | 5 | Objects that are directly mentioned in the input query should be deemed as being relevant. Basically you should check all nouns in the input query and output those with corresponding entries in the list of available objects. 6 | 7 | Sometimes there are multiple object labels that actually refer to the same thing. In this case you should keep only one of them. 8 | 9 | Please only output the relevant objects. You SHOULD NOT output any explainations or comments. If you are unsure of if an object is relevant or not, just add it to the output. Please follow the format of your previous outputs strictly. 10 | 11 | If something is in the middle or the room, or is described as being "farthest", you should include the room center object as relevant. There are many similar phrases, and you should detect them and include the room center object is applicable. 12 | 13 | You should pay attention to similar labels. For example, if "desk" is mentioned in the query, "table" could also be a relevant object, and vice versa. "Bookshelf" and "shelf" might also be related. You should generalize this idea to other labels. 14 | 15 | Again, Please ***DO NOT*** output ANY EXTRA COMMENTS! Please include ALL objects that are possibly relevant! 16 | 17 | Please DO NOT output empty results! There must be at least one relevant object! 18 | 19 | <[USER]> 20 | 21 | QUERY: 22 | this is a grey chair. 
its at the end of the shorter coffee table closest to the window. 23 | 24 | OBJECTS IN 3D SCENE: 25 | [0] room center 26 | [1] vase 27 | [2] table 28 | [3] desk 29 | [4] chair 30 | [5] window 31 | [6] floor 32 | [7] refrigerator 33 | [8] trash can 34 | [9] coffee table 35 | 36 | <[ASSISTANT]> 37 | 38 | RELEVANT OBJECTS: 39 | @obj [4] chair 40 | @obj [5] window 41 | @obj [9] coffee table 42 | 43 | <[USER]> 44 | 45 | QUERY: 46 | it is the tallest chair at the center of the room. 47 | 48 | OBJECTS IN 3D SCENE: 49 | [0] room center 50 | [1] vase 51 | [2] table 52 | [3] desk 53 | [4] chair 54 | [5] window 55 | [6] floor 56 | [7] refrigerator 57 | [8] trash can 58 | [9] coffee table 59 | 60 | <[ASSISTANT]> 61 | 62 | RELEVANT OBJECTS: 63 | @obj [0] room center 64 | @obj [4] chair 65 | 66 | <[USER]> 67 | 68 | QUERY: 69 | it is the white ceramic vase on the window sill. 70 | 71 | OBJECTS IN 3D SCENE: 72 | [0] room center 73 | [1] vase 74 | [2] table 75 | [3] desk 76 | [4] chair 77 | [5] window 78 | [6] floor 79 | [7] refrigerator 80 | [8] trash can 81 | [9] coffee table 82 | [10] window sill 83 | [11] window ledge 84 | [12] windowsills 85 | 86 | <[ASSISTANT]> 87 | 88 | RELEVANT OBJECTS: 89 | @obj [1] vase 90 | @obj [10] window sill 91 | 92 | -------------------------------------------------------------------------------- /prompts/generate_program_csp.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of object labels that are relevant to this query. Any other objects should not be considered. 4 | 5 | Your task is to generate a Python program that locate the target object specified by the query. You don't have to use all the relevant object labels given to you. Some of them may actually be irrelevant. 
What's important is, you should always use the labels of the relevant objects given to you. 6 | 7 | The Python program you generate should solve a CSP (Constraint Satisfication Problem). Solving the CSP will result in the correct grounding of the target objects. The variables are objects and constraints are spatial relationships among objects or appearance information of objects. 8 | 9 | You should first try to do some spatial reasoning of the input query and simplify the logic. Some relations between objects can be converted into simpler ones. You should always first try to reduce the complexity of the input query. E.g. an object that has one object on the left and another on the right is between those two objects. So you should use a single relation "between", instead of two relations "left" and "right". There are many similar cases. 10 | 11 | There are two types of variables: normal variables (DEFINE_VARIABLE), which represent a single object, and negative variables (DEFINE_NEGATIVE_VARIABLE), which indicate the abscence of an object satisfying some spatial relations. You should only use negative variables when the query mentions some negative conditions with e.g., 'do not' or 'without'. 12 | 13 | Your should only use the following predefined functions in your program. No other functions are allowed. 14 | 15 | <[REGISTERED_FUNCTIONS_PLACEHOLDER]> 16 | 17 | Some of the predefined functions above have a "score_func" parameter, which specifies how to compare objects. The following score functions are available. 18 | 19 | <[REGISTERED_SCORE_FUNCTIONS_PLACEHOLDER]> 20 | 21 | Your output should be a valid Python program, which means that any additional explanation should be comments (start with #). You only have to output the python code. Please DO NOT follow markdown conventions. You DO NOT have to enclose the code with ```. 22 | 23 | Some extra tips: 24 | - "North", "south", "east" and "west" should always be translated into "near". 
25 | - If something is at the side of a room, it should be far from the room center. 26 | - If an object you need is not in the relevant object list, please ignore information about that object, and do not use it. 27 | - You should not ignore relations that can be represented with the given relevant object labels and available functions. 28 | - Relations like "farthest", "northwestern-most" in reference to the room should be translated to being farthest to the room center. 29 | - You should identify only ONE target, to which the given 3D visual grounding query is referring to. You should always set the target at the very end. No function calls are allowed after setting the target. 30 | - Please DO NOT define variables that you will not use later. Every variables defined must be set in relation with at least one other variable. 31 | - Sometimes multiple labels refer to similar objects. For example, different queries may use "desk" or "table" to refer to the same thing. So when you need a "desk", and only "table" are available, you should use table, and vice versa. 32 | - For each object, you either define a normal variable or a negative variable. Please do not define both. 33 | 34 | <[USER]> 35 | 36 | QUERY: 37 | this is a grey chair. it is at the end of the shorter coffee table closest to the window. 
38 | 39 | RELEVANT OBJECT LABELS: 40 | [0] chair 41 | [1] window 42 | [2] coffee table 43 | 44 | <[ASSISTANT]> 45 | 46 | # define variables for the CSP 47 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 48 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 49 | COFFEE_TABLE_0 = DEFINE_VARIABLE(labels=["coffee table"]) 50 | 51 | # add a constraint: the grey chair is at the end of the shorter coffee table 52 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=COFFEE_TABLE_0) 53 | 54 | # add a constraint: the coffee table is the one closest to the window 55 | CONSTRAINT_MIN_OF(target=COFFEE_TABLE_0, anchor=WINDOW_0, score_func="distance") 56 | 57 | # finally set the CSP variable representing the target object 58 | SET_TARGET(CHAIR_0) 59 | 60 | <[USER]> 61 | 62 | QUERY: 63 | the desk is one of three in the row. there is a desk to the right of it, and a desk to the left of it. 64 | 65 | RELEVANT OBJECT LABELS: 66 | [0] desk 67 | 68 | <[ASSISTANT]> 69 | 70 | # since there are three desks in total, I'll create an object to represent each of them 71 | DESK_MIDDLE = DEFINE_VARIABLE(labels=["desk"]) 72 | DESK_LEFT = DEFINE_VARIABLE(labels=["desk"]) 73 | DESK_RIGHT = DEFINE_VARIABLE(labels=["desk"]) 74 | 75 | # the target desk has one desk to the right, one to the left, therefore by a simple reasoning I know that it is the desk between two other desks 76 | CONSTRAINT_BETWEEN(target=DESK_MIDDLE, anchors={DESK_LEFT, DESK_RIGHT}) 77 | 78 | # set the variable for the target object 79 | SET_TARGET(DESK_MIDDLE) 80 | 81 | <[USER]> 82 | 83 | QUERY: 84 | it is the tallest chair at the center of the room. 
85 | 86 | RELEVANT OBJECT LABELS: 87 | [0] room center 88 | [1] chair 89 | 90 | <[ASSISTANT]> 91 | 92 | # define a variable for the target chair object 93 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 94 | 95 | # define the room center variable 96 | ROOM_CENTER = DEFINE_VARIABLE(labels=["room center"]) 97 | 98 | # add a constraint: the chair is near the room center 99 | # we pass the set of all chairs in order to keep only chairs at the room center 100 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=ROOM_CENTER) 101 | 102 | # add a constraint: the chair is the tallest one, i.e. the one with largest height 103 | # we find the chair with the largest extent along z-axis in the set of chairs at the room center 104 | CONSTRAINT_MAX_OF(target=CHAIR_0, score_func="size-z") 105 | 106 | # set the final target 107 | SET_TARGET(CHAIR_0) 108 | 109 | <[USER]> 110 | 111 | QUERY: 112 | this is a blue trash can that is beside a orange one. it is not the trash can far from other trash cans. 113 | 114 | RELEVANT OBJECT LABELS: 115 | [0] trash can 116 | 117 | <[ASSISTANT]> 118 | 119 | # define a variable for the blue trash can (target) 120 | TRASH_CAN_BLUE = DEFINE_VARIABLE(labels=["trash can"]) 121 | 122 | # define a variable for the orange trash can 123 | TRASH_CAN_ORANGE = DEFINE_VARIABLE(labels=["trash can"]) 124 | 125 | # define a variable for the far away trash can 126 | TRASH_CAN_FAR = DEFINE_VARIABLE(labels=["trash can"]) 127 | 128 | # the blue trash can should be beside the orange one 129 | CONSTRAINT_NEAR(target=TRASH_CAN_BLUE, anchor=TRASH_CAN_ORANGE) 130 | 131 | # the far away trash can should be far from the other trash cans 132 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_BLUE) 133 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_ORANGE) 134 | 135 | # finally set the variable for the target object 136 | SET_TARGET(TRASH_CAN_BLUE) 137 | 138 | <[USER]> 139 | 140 | QUERY: 141 | it is the smallest cup on the largest table. 
142 | 143 | RELEVANT OBJECT LABELS: 144 | [0] cup 145 | [1] table 146 | 147 | <[ASSISTANT]> 148 | 149 | # define a variable for the largest table 150 | TABLE_0 = DEFINE_VARIABLE(labels=["table"]) 151 | 152 | # define a variable for the smallest cup on the largest table 153 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 154 | 155 | # constraint: the largest table 156 | CONSTRAINT_MAX_OF(target=TABLE_0, score_func="size") 157 | 158 | # constraint: the cup is on the table 159 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_0) 160 | 161 | # constraint: the smallest cup 162 | CONSTRAINT_MIN_OF(target=CUP_0, score_func="size") 163 | 164 | # set the variable for the target object 165 | SET_TARGET(CUP_0) 166 | 167 | <[USER]> 168 | 169 | QUERY: 170 | this is the third chair from the left. 171 | 172 | RELEVANT OBJECT LABELS: 173 | [0] chair 174 | 175 | <[ASSISTANT]> 176 | 177 | # this is an example of how to handle counting-based query 178 | 179 | # since the target is the third chair from the left, there must be at least three chairs in total 180 | # so we define a variable for each of them 181 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 182 | CHAIR_1 = DEFINE_VARIABLE(labels=["chair"]) 183 | CHAIR_2 = DEFINE_VARIABLE(labels=["chair"]) 184 | 185 | # constraint: the target chair has a chair on the left, which also has another chair to its left 186 | CONSTRAINT_LEFT(target=CHAIR_1, anchor=CHAIR_0) 187 | CONSTRAINT_LEFT(target=CHAIR_2, anchor=CHAIR_1) 188 | 189 | # set the variable for the target object 190 | SET_TARGET(CHAIR_0) 191 | 192 | <[USER]> 193 | 194 | QUERY: 195 | this is the chair third closest to the corner of the room. 
196 | 197 | RELEVANT OBJECT LABELS: 198 | [0] chair 199 | [1] room corner 200 | 201 | <[ASSISTANT]> 202 | 203 | # this is another example of how to handle counting-based query 204 | # it is also a comparison-based query 205 | 206 | # since the target is the third chair from the room corner, there must be at least three chairs in total 207 | # so we define a variable for each of them 208 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 209 | CHAIR_1 = DEFINE_VARIABLE(labels=["chair"]) 210 | CHAIR_2 = DEFINE_VARIABLE(labels=["chair"]) 211 | 212 | # define a variable for the corner of the room 213 | ROOM_CORNER = DEFINE_VARIABLE(labels=["room corner"]) 214 | 215 | # constraint: the target chair (CHAIR_2) is farther from the room corner than another chair (CHAIR_1), which (CHAIR_1) again is farther from the room corner than yet a third chair (CHAIR_0) 216 | CONSTRAINT_MORE(target=CHAIR_2, reference=CHAIR_1, anchor=ROOM_CORNER, score_func="distance") 217 | CONSTRAINT_MORE(target=CHAIR_1, reference=CHAIR_0, anchor=ROOM_CORNER, score_func="distance") 218 | 219 | # set the variable for the target object 220 | SET_TARGET(CHAIR_2) 221 | 222 | <[USER]> 223 | 224 | QUERY: 225 | there are some windows on the wall. this is the window in the middle. 226 | 227 | RELEVANT OBJECT LABELS: 228 | [0] wall 229 | [1] window 230 | 231 | <[ASSISTANT]> 232 | 233 | # define variables for the window and the wall 234 | WALL_0 = DEFINE_VARIABLE(labels=["wall"]) 235 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 236 | 237 | # the windows should be near the wall 238 | CONSTRAINT_NEAR(target=WINDOW_0, anchor=WALL_0) 239 | 240 | # the target window should be in the middle, i.e., the "distance-to-center" score must be minimized 241 | CONSTRAINT_MIN_OF(target=WINDOW_0, score_func="distance-to-center") 242 | 243 | # set the variable for the target object 244 | SET_TARGET(WINDOW_0) 245 | 246 | <[USER]> 247 | 248 | QUERY: 249 | this is a chair. it does not have a teddy bear sitting on it. 
250 | 251 | RELEVANT OBJECT LABELS: 252 | [0] chair 253 | [1] teddy bear 254 | 255 | <[ASSISTANT]> 256 | 257 | # this is an example of using negative variables 258 | 259 | # define variables for the chair and the teddy bear 260 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 261 | 262 | # define a negative variable to express that the chair has no teddy bear on it 263 | TEDDY_BEAR_NEG = DEFINE_NEGATIVE_VARIABLE(labels=["teddy bear"]) 264 | 265 | # first create a normal constraint: the teddy bear is on the chair 266 | CONSTRAINT_ON(target=TEDDY_BEAR_NEG, anchor=CHAIR_0) 267 | 268 | # set the variable for the target object 269 | SET_TARGET(CHAIR_0) 270 | 271 | <[USER]> 272 | 273 | QUERY: 274 | it is the cup not on the table. 275 | 276 | RELEVANT OBJECT LABELS: 277 | [0] cup 278 | [1] table 279 | 280 | <[ASSISTANT]> 281 | 282 | # this is another example of using negative variables 283 | 284 | # define the cup variable, which is the target 285 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 286 | 287 | # define the table variable as negative, since the cup variable is the target and we want to locate it 288 | TABLE_NEG_0 = DEFINE_NEGATIVE_VARIABLE(labels=["table"]) 289 | 290 | # the cup is not on the table 291 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_NEG_0) 292 | 293 | # set the variable for the target object 294 | SET_TARGET(CUP_0) 295 | 296 | <[USER]> 297 | 298 | QUERY: 299 | this is a recycling trash can. its blue and white in color and is on the left of a box. it is also below a wooden counter. 300 | 301 | RELEVANT OBJECT LABELS: 302 | [0] trash can 303 | [1] box 304 | [2] counter 305 | [3] plastic trash bin 306 | [4] wastebin 307 | 308 | <[ASSISTANT]> 309 | 310 | # define variables for the trash can. since we have multiple labels denoting trash cans (with synonyms), we include all of them in the label. 
311 | TRASH_CAN_0 = DEFINE_VARIABLE(labels=["trash can", "plastic trash bin", "wastebin"]) 312 | 313 | # define variables for the box and the counter 314 | BOX_0 = DEFINE_VARIABLE(labels=["box"]) 315 | COUNTER_0 = DEFINE_VARIABLE(labels=["counter"]) 316 | 317 | # add a constraint: the trash can is on the left of a box 318 | CONSTRAINT_LEFT(target=TRASH_CAN_0, anchor=BOX_0) 319 | 320 | # add a constraint: the trash can is also below a wooden counter 321 | CONSTRAINT_UNDER(target=TRASH_CAN_1, anchor=COUNTER_0) 322 | 323 | # set the CSP variable for the target object 324 | SET_TARGET(TRASH_CAN_0) -------------------------------------------------------------------------------- /prompts/generate_program_csp_no_counting.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of object labels that are relevant to this query. Any other objects should not be considered. 4 | 5 | Your task is to generate a Python program that locate the target object specified by the query. You don't have to use all the relevant object labels given to you. Some of them may actually be irrelevant. What's important is, you should always use the labels of the relevant objects given to you. 6 | 7 | The Python program you generate should solve a CSP (Constraint Satisfication Problem). Solving the CSP will result in the correct grounding of the target objects. The variables are objects and constraints are spatial relationships among objects or appearance information of objects. 8 | 9 | You should first try to do some spatial reasoning of the input query and simplify the logic. Some relations between objects can be converted into simpler ones. You should always first try to reduce the complexity of the input query. E.g. 
an object that has one object on the left and another on the right is between those two objects. So you should use a single relation "between", instead of two relations "left" and "right". There are many similar cases. 10 | 11 | There are two types of variables: normal variables (DEFINE_VARIABLE), which represent a single object, and negative variables (DEFINE_NEGATIVE_VARIABLE), which indicate the abscence of an object satisfying some spatial relations. You should only use negative variables when the query mentions some negative conditions with e.g., 'do not' or 'without'. 12 | 13 | Your should only use the following predefined functions in your program. No other functions are allowed. 14 | 15 | <[REGISTERED_FUNCTIONS_PLACEHOLDER]> 16 | 17 | Some of the predefined functions above have a "score_func" parameter, which specifies how to compare objects. The following score functions are available. 18 | 19 | <[REGISTERED_SCORE_FUNCTIONS_PLACEHOLDER]> 20 | 21 | Your output should be a valid Python program, which means that any additional explanation should be comments (start with #). You only have to output the python code. Please DO NOT follow markdown conventions. You DO NOT have to enclose the code with ```. 22 | 23 | Some extra tips: 24 | - "North", "south", "east" and "west" should always be translated into "near". 25 | - If something is at the side of a room, it should be far from the room center. 26 | - If an object you need is not in the relevant object list, please ignore information about that object, and do not use it. 27 | - You should not ignore relations that can be represented with the given relevant object labels and available functions. 28 | - Relations like "farthest", "northwestern-most" in reference to the room should be translated to being farthest to the room center. 29 | - You should identify only ONE target, to which the given 3D visual grounding query is referring to. You should always set the target at the very end. 
No function calls are allowed after setting the target. 30 | - Please DO NOT define variables that you will not use later. Every variables defined must be set in relation with at least one other variable. 31 | - Sometimes multiple labels refer to similar objects. For example, different queries may use "desk" or "table" to refer to the same thing. So when you need a "desk", and only "table" are available, you should use table, and vice versa. 32 | - For each object, you either define a normal variable or a negative variable. Please do not define both. 33 | 34 | <[USER]> 35 | 36 | QUERY: 37 | this is a grey chair. it is at the end of the shorter coffee table closest to the window. 38 | 39 | RELEVANT OBJECT LABELS: 40 | [0] chair 41 | [1] window 42 | [2] coffee table 43 | 44 | <[ASSISTANT]> 45 | 46 | # define variables for the CSP 47 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 48 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 49 | COFFEE_TABLE_0 = DEFINE_VARIABLE(labels=["coffee table"]) 50 | 51 | # add a constraint: the grey chair is at the end of the shorter coffee table 52 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=COFFEE_TABLE_0) 53 | 54 | # add a constraint: the coffee table is the one closest to the window 55 | CONSTRAINT_MIN_OF(target=COFFEE_TABLE_0, anchor=WINDOW_0, score_func="distance") 56 | 57 | # finally set the CSP variable representing the target object 58 | SET_TARGET(CHAIR_0) 59 | 60 | <[USER]> 61 | 62 | QUERY: 63 | the desk is one of three in the row. there is a desk to the right of it, and a desk to the left of it. 
64 | 65 | RELEVANT OBJECT LABELS: 66 | [0] desk 67 | 68 | <[ASSISTANT]> 69 | 70 | # since there are three desks in total, I'll create an object to represent each of them 71 | DESK_MIDDLE = DEFINE_VARIABLE(labels=["desk"]) 72 | DESK_LEFT = DEFINE_VARIABLE(labels=["desk"]) 73 | DESK_RIGHT = DEFINE_VARIABLE(labels=["desk"]) 74 | 75 | # the target desk has one desk to the right, one to the left, therefore by a simple reasoning I know that it is the desk between two other desks 76 | CONSTRAINT_BETWEEN(target=DESK_MIDDLE, anchors={DESK_LEFT, DESK_RIGHT}) 77 | 78 | # set the variable for the target object 79 | SET_TARGET(DESK_MIDDLE) 80 | 81 | <[USER]> 82 | 83 | QUERY: 84 | it is the tallest chair at the center of the room. 85 | 86 | RELEVANT OBJECT LABELS: 87 | [0] room center 88 | [1] chair 89 | 90 | <[ASSISTANT]> 91 | 92 | # define a variable for the target chair object 93 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 94 | 95 | # define the room center variable 96 | ROOM_CENTER = DEFINE_VARIABLE(labels=["room center"]) 97 | 98 | # add a constraint: the chair is near the room center 99 | # we pass the set of all chairs in order to keep only chairs at the room center 100 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=ROOM_CENTER) 101 | 102 | # add a constraint: the chair is the tallest one, i.e. the one with largest height 103 | # we find the chair with the largest extent along z-axis in the set of chairs at the room center 104 | CONSTRAINT_MAX_OF(target=CHAIR_0, score_func="size-z") 105 | 106 | # set the final target 107 | SET_TARGET(CHAIR_0) 108 | 109 | <[USER]> 110 | 111 | QUERY: 112 | this is a blue trash can that is beside a orange one. it is not the trash can far from other trash cans. 
113 | 114 | RELEVANT OBJECT LABELS: 115 | [0] trash can 116 | 117 | <[ASSISTANT]> 118 | 119 | # define a variable for the blue trash can (target) 120 | TRASH_CAN_BLUE = DEFINE_VARIABLE(labels=["trash can"]) 121 | 122 | # define a variable for the orange trash can 123 | TRASH_CAN_ORANGE = DEFINE_VARIABLE(labels=["trash can"]) 124 | 125 | # define a variable for the far away trash can 126 | TRASH_CAN_FAR = DEFINE_VARIABLE(labels=["trash can"]) 127 | 128 | # the blue trash can should be beside the orange one 129 | CONSTRAINT_NEAR(target=TRASH_CAN_BLUE, anchor=TRASH_CAN_ORANGE) 130 | 131 | # the far away trash can should be far from the other trash cans 132 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_BLUE) 133 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_ORANGE) 134 | 135 | # finally set the variable for the target object 136 | SET_TARGET(TRASH_CAN_BLUE) 137 | 138 | <[USER]> 139 | 140 | QUERY: 141 | it is the smallest cup on the largest table. 142 | 143 | RELEVANT OBJECT LABELS: 144 | [0] cup 145 | [1] table 146 | 147 | <[ASSISTANT]> 148 | 149 | # define a variable for the largest table 150 | TABLE_0 = DEFINE_VARIABLE(labels=["table"]) 151 | 152 | # define a variable for the smallest cup on the largest table 153 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 154 | 155 | # constraint: the largest table 156 | CONSTRAINT_MAX_OF(target=TABLE_0, score_func="size") 157 | 158 | # constraint: the cup is on the table 159 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_0) 160 | 161 | # constraint: the smallest cup 162 | CONSTRAINT_MIN_OF(target=CUP_0, score_func="size") 163 | 164 | # set the variable for the target object 165 | SET_TARGET(CUP_0) 166 | 167 | <[USER]> 168 | 169 | QUERY: 170 | there are some windows on the wall. this is the window in the middle. 
171 | 172 | RELEVANT OBJECT LABELS: 173 | [0] wall 174 | [1] window 175 | 176 | <[ASSISTANT]> 177 | 178 | # define variables for the window and the wall 179 | WALL_0 = DEFINE_VARIABLE(labels=["wall"]) 180 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 181 | 182 | # the windows should be near the wall 183 | CONSTRAINT_NEAR(target=WINDOW_0, anchor=WALL_0) 184 | 185 | # the target window should be in the middle, i.e., the "distance-to-center" score must be minimized 186 | CONSTRAINT_MIN_OF(target=WINDOW_0, score_func="distance-to-center") 187 | 188 | # set the variable for the target object 189 | SET_TARGET(WINDOW_0) 190 | 191 | <[USER]> 192 | 193 | QUERY: 194 | this is a chair. it does not have a teddy bear sitting on it. 195 | 196 | RELEVANT OBJECT LABELS: 197 | [0] chair 198 | [1] teddy bear 199 | 200 | <[ASSISTANT]> 201 | 202 | # this is an example of using negative variables 203 | 204 | # define variables for the chair and the teddy bear 205 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 206 | 207 | # define a negative variable to express that the chair has no teddy bear on it 208 | TEDDY_BEAR_NEG = DEFINE_NEGATIVE_VARIABLE(labels=["teddy bear"]) 209 | 210 | # first create a normal constraint: the teddy bear is on the chair 211 | CONSTRAINT_ON(target=TEDDY_BEAR_NEG, anchor=CHAIR_0) 212 | 213 | # set the variable for the target object 214 | SET_TARGET(CHAIR_0) 215 | 216 | <[USER]> 217 | 218 | QUERY: 219 | it is the cup not on the table. 
220 | 221 | RELEVANT OBJECT LABELS: 222 | [0] cup 223 | [1] table 224 | 225 | <[ASSISTANT]> 226 | 227 | # this is another example of using negative variables 228 | 229 | # define the cup variable, which is the target 230 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 231 | 232 | # define the table variable as negative, since the cup variable is the target and we want to locate it 233 | TABLE_NEG_0 = DEFINE_NEGATIVE_VARIABLE(labels=["table"]) 234 | 235 | # the cup is not on the table 236 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_NEG_0) 237 | 238 | # set the variable for the target object 239 | SET_TARGET(CUP_0) 240 | 241 | <[USER]> 242 | 243 | QUERY: 244 | this is a recycling trash can. its blue and white in color and is on the left of a box. it is also below a wooden counter. 245 | 246 | RELEVANT OBJECT LABELS: 247 | [0] trash can 248 | [1] box 249 | [2] counter 250 | [3] plastic trash bin 251 | [4] wastebin 252 | 253 | <[ASSISTANT]> 254 | 255 | # define variables for the trash can. since we have multiple labels denoting trash cans (with synonyms), we include all of them in the label. 256 | TRASH_CAN_0 = DEFINE_VARIABLE(labels=["trash can", "plastic trash bin", "wastebin"]) 257 | 258 | # define variables for the box and the counter 259 | BOX_0 = DEFINE_VARIABLE(labels=["box"]) 260 | COUNTER_0 = DEFINE_VARIABLE(labels=["counter"]) 261 | 262 | # add a constraint: the trash can is on the left of a box 263 | CONSTRAINT_LEFT(target=TRASH_CAN_0, anchor=BOX_0) 264 | 265 | # add a constraint: the trash can is also below a wooden counter 266 | CONSTRAINT_UNDER(target=TRASH_CAN_0, anchor=COUNTER_0) 267 | 268 | # set the CSP variable for the target object 269 | SET_TARGET(TRASH_CAN_0) -------------------------------------------------------------------------------- /prompts/generate_program_csp_no_counting_negation.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. 
The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of object labels that are relevant to this query. Any other objects should not be considered. 4 | 5 | Your task is to generate a Python program that locates the target object specified by the query. You don't have to use all the relevant object labels given to you. Some of them may actually be irrelevant. What's important is, you should always use the labels of the relevant objects given to you. 6 | 7 | The Python program you generate should solve a CSP (Constraint Satisfaction Problem). Solving the CSP will result in the correct grounding of the target objects. The variables are objects and constraints are spatial relationships among objects or appearance information of objects. 8 | 9 | You should first try to do some spatial reasoning of the input query and simplify the logic. Some relations between objects can be converted into simpler ones. You should always first try to reduce the complexity of the input query. E.g. an object that has one object on the left and another on the right is between those two objects. So you should use a single relation "between", instead of two relations "left" and "right". There are many similar cases. 10 | 11 | You should only use the following predefined functions in your program. No other functions are allowed. 12 | 13 | <[REGISTERED_FUNCTIONS_PLACEHOLDER]> 14 | 15 | Some of the predefined functions above have a "score_func" parameter, which specifies how to compare objects. The following score functions are available. 16 | 17 | <[REGISTERED_SCORE_FUNCTIONS_PLACEHOLDER]> 18 | 19 | Your output should be a valid Python program, which means that any additional explanation should be comments (start with #). You only have to output the python code. Please DO NOT follow markdown conventions. You DO NOT have to enclose the code with ```. 
20 | 21 | Some extra tips: 22 | - "North", "south", "east" and "west" should always be translated into "near". 23 | - If something is at the side of a room, it should be far from the room center. 24 | - If an object you need is not in the relevant object list, please ignore information about that object, and do not use it. 25 | - You should not ignore relations that can be represented with the given relevant object labels and available functions. 26 | - Relations like "farthest", "northwestern-most" in reference to the room should be translated to being farthest to the room center. 27 | - You should identify only ONE target, to which the given 3D visual grounding query is referring to. You should always set the target at the very end. No function calls are allowed after setting the target. 28 | - Please DO NOT define variables that you will not use later. Every variable defined must be set in relation with at least one other variable. 29 | - Sometimes multiple labels refer to similar objects. For example, different queries may use "desk" or "table" to refer to the same thing. So when you need a "desk", and only "table" are available, you should use table, and vice versa. 30 | - For each object, you either define a normal variable or a negative variable. Please do not define both. 31 | 32 | <[USER]> 33 | 34 | QUERY: 35 | this is a grey chair. it is at the end of the shorter coffee table closest to the window. 
36 | 37 | RELEVANT OBJECT LABELS: 38 | [0] chair 39 | [1] window 40 | [2] coffee table 41 | 42 | <[ASSISTANT]> 43 | 44 | # define variables for the CSP 45 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 46 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 47 | COFFEE_TABLE_0 = DEFINE_VARIABLE(labels=["coffee table"]) 48 | 49 | # add a constraint: the grey chair is at the end of the shorter coffee table 50 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=COFFEE_TABLE_0) 51 | 52 | # add a constraint: the coffee table is the one closest to the window 53 | CONSTRAINT_MIN_OF(target=COFFEE_TABLE_0, anchor=WINDOW_0, score_func="distance") 54 | 55 | # finally set the CSP variable representing the target object 56 | SET_TARGET(CHAIR_0) 57 | 58 | <[USER]> 59 | 60 | QUERY: 61 | the desk is one of three in the row. there is a desk to the right of it, and a desk to the left of it. 62 | 63 | RELEVANT OBJECT LABELS: 64 | [0] desk 65 | 66 | <[ASSISTANT]> 67 | 68 | # since there are three desks in total, I'll create an object to represent each of them 69 | DESK_MIDDLE = DEFINE_VARIABLE(labels=["desk"]) 70 | DESK_LEFT = DEFINE_VARIABLE(labels=["desk"]) 71 | DESK_RIGHT = DEFINE_VARIABLE(labels=["desk"]) 72 | 73 | # the target desk has one desk to the right, one to the left, therefore by a simple reasoning I know that it is the desk between two other desks 74 | CONSTRAINT_BETWEEN(target=DESK_MIDDLE, anchors={DESK_LEFT, DESK_RIGHT}) 75 | 76 | # set the variable for the target object 77 | SET_TARGET(DESK_MIDDLE) 78 | 79 | <[USER]> 80 | 81 | QUERY: 82 | it is the tallest chair at the center of the room. 
83 | 84 | RELEVANT OBJECT LABELS: 85 | [0] room center 86 | [1] chair 87 | 88 | <[ASSISTANT]> 89 | 90 | # define a variable for the target chair object 91 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 92 | 93 | # define the room center variable 94 | ROOM_CENTER = DEFINE_VARIABLE(labels=["room center"]) 95 | 96 | # add a constraint: the chair is near the room center 97 | # we pass the set of all chairs in order to keep only chairs at the room center 98 | CONSTRAINT_NEAR(target=CHAIR_0, anchor=ROOM_CENTER) 99 | 100 | # add a constraint: the chair is the tallest one, i.e. the one with largest height 101 | # we find the chair with the largest extent along z-axis in the set of chairs at the room center 102 | CONSTRAINT_MAX_OF(target=CHAIR_0, score_func="size-z") 103 | 104 | # set the final target 105 | SET_TARGET(CHAIR_0) 106 | 107 | <[USER]> 108 | 109 | QUERY: 110 | this is a blue trash can that is beside a orange one. it is not the trash can far from other trash cans. 111 | 112 | RELEVANT OBJECT LABELS: 113 | [0] trash can 114 | 115 | <[ASSISTANT]> 116 | 117 | # define a variable for the blue trash can (target) 118 | TRASH_CAN_BLUE = DEFINE_VARIABLE(labels=["trash can"]) 119 | 120 | # define a variable for the orange trash can 121 | TRASH_CAN_ORANGE = DEFINE_VARIABLE(labels=["trash can"]) 122 | 123 | # define a variable for the far away trash can 124 | TRASH_CAN_FAR = DEFINE_VARIABLE(labels=["trash can"]) 125 | 126 | # the blue trash can should be beside the orange one 127 | CONSTRAINT_NEAR(target=TRASH_CAN_BLUE, anchor=TRASH_CAN_ORANGE) 128 | 129 | # the far away trash can should be far from the other trash cans 130 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_BLUE) 131 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_ORANGE) 132 | 133 | # finally set the variable for the target object 134 | SET_TARGET(TRASH_CAN_BLUE) 135 | 136 | <[USER]> 137 | 138 | QUERY: 139 | it is the smallest cup on the largest table. 
140 | 141 | RELEVANT OBJECT LABELS: 142 | [0] cup 143 | [1] table 144 | 145 | <[ASSISTANT]> 146 | 147 | # define a variable for the largest table 148 | TABLE_0 = DEFINE_VARIABLE(labels=["table"]) 149 | 150 | # define a variable for the smallest cup on the largest table 151 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 152 | 153 | # constraint: the largest table 154 | CONSTRAINT_MAX_OF(target=TABLE_0, score_func="size") 155 | 156 | # constraint: the cup is on the table 157 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_0) 158 | 159 | # constraint: the smallest cup 160 | CONSTRAINT_MIN_OF(target=CUP_0, score_func="size") 161 | 162 | # set the variable for the target object 163 | SET_TARGET(CUP_0) 164 | 165 | <[USER]> 166 | 167 | QUERY: 168 | there are some windows on the wall. this is the window in the middle. 169 | 170 | RELEVANT OBJECT LABELS: 171 | [0] wall 172 | [1] window 173 | 174 | <[ASSISTANT]> 175 | 176 | # define variables for the window and the wall 177 | WALL_0 = DEFINE_VARIABLE(labels=["wall"]) 178 | WINDOW_0 = DEFINE_VARIABLE(labels=["window"]) 179 | 180 | # the windows should be near the wall 181 | CONSTRAINT_NEAR(target=WINDOW_0, anchor=WALL_0) 182 | 183 | # the target window should be in the middle, i.e., the "distance-to-center" score must be minimized 184 | CONSTRAINT_MIN_OF(target=WINDOW_0, score_func="distance-to-center") 185 | 186 | # set the variable for the target object 187 | SET_TARGET(WINDOW_0) 188 | 189 | <[USER]> 190 | 191 | QUERY: 192 | this is a recycling trash can. its blue and white in color and is on the left of a box. it is also below a wooden counter. 193 | 194 | RELEVANT OBJECT LABELS: 195 | [0] trash can 196 | [1] box 197 | [2] counter 198 | [3] plastic trash bin 199 | [4] wastebin 200 | 201 | <[ASSISTANT]> 202 | 203 | # define variables for the trash can. since we have multiple labels denoting trash cans (with synonyms), we include all of them in the label. 
204 | TRASH_CAN_0 = DEFINE_VARIABLE(labels=["trash can", "plastic trash bin", "wastebin"]) 205 | 206 | # define variables for the box and the counter 207 | BOX_0 = DEFINE_VARIABLE(labels=["box"]) 208 | COUNTER_0 = DEFINE_VARIABLE(labels=["counter"]) 209 | 210 | # add a constraint: the trash can is on the left of a box 211 | CONSTRAINT_LEFT(target=TRASH_CAN_0, anchor=BOX_0) 212 | 213 | # add a constraint: the trash can is also below a wooden counter 214 | CONSTRAINT_UNDER(target=TRASH_CAN_0, anchor=COUNTER_0) 215 | 216 | # set the CSP variable for the target object 217 | SET_TARGET(TRASH_CAN_0) -------------------------------------------------------------------------------- /prompts/generate_program_csp_no_counting_negation_minmax.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of object labels that are relevant to this query. Any other objects should not be considered. 4 | 5 | Your task is to generate a Python program that locates the target object specified by the query. You don't have to use all the relevant object labels given to you. Some of them may actually be irrelevant. What's important is, you should always use the labels of the relevant objects given to you. 6 | 7 | The Python program you generate should solve a CSP (Constraint Satisfaction Problem). Solving the CSP will result in the correct grounding of the target objects. The variables are objects and constraints are spatial relationships among objects or appearance information of objects. 8 | 9 | You should first try to do some spatial reasoning of the input query and simplify the logic. Some relations between objects can be converted into simpler ones. You should always first try to reduce the complexity of the input query. E.g. 
an object that has one object on the left and another on the right is between those two objects. So you should use a single relation "between", instead of two relations "left" and "right". There are many similar cases. 10 | 11 | You should only use the following predefined functions in your program. No other functions are allowed. 12 | 13 | <[REGISTERED_FUNCTIONS_PLACEHOLDER]> 14 | 15 | Some of the predefined functions above have a "score_func" parameter, which specifies how to compare objects. The following score functions are available. 16 | 17 | <[REGISTERED_SCORE_FUNCTIONS_PLACEHOLDER]> 18 | 19 | Your output should be a valid Python program, which means that any additional explanation should be comments (start with #). You only have to output the python code. Please DO NOT follow markdown conventions. You DO NOT have to enclose the code with ```. 20 | 21 | Some extra tips: 22 | - "North", "south", "east" and "west" should always be translated into "near". 23 | - If something is at the side of a room, it should be far from the room center. 24 | - If an object you need is not in the relevant object list, please ignore information about that object, and do not use it. 25 | - You should not ignore relations that can be represented with the given relevant object labels and available functions. 26 | - Relations like "farthest", "northwestern-most" in reference to the room should be translated to being farthest to the room center. 27 | - You should identify only ONE target, to which the given 3D visual grounding query is referring to. You should always set the target at the very end. No function calls are allowed after setting the target. 28 | - Please DO NOT define variables that you will not use later. Every variable defined must be set in relation with at least one other variable. 29 | - Sometimes multiple labels refer to similar objects. For example, different queries may use "desk" or "table" to refer to the same thing. 
So when you need a "desk", and only "table" are available, you should use table, and vice versa. 30 | - For each object, you either define a normal variable or a negative variable. Please do not define both. 31 | 32 | <[USER]> 33 | 34 | QUERY: 35 | the desk is one of three in the row. there is a desk to the right of it, and a desk to the left of it. 36 | 37 | RELEVANT OBJECT LABELS: 38 | [0] desk 39 | 40 | <[ASSISTANT]> 41 | 42 | # since there are three desks in total, I'll create an object to represent each of them 43 | DESK_MIDDLE = DEFINE_VARIABLE(labels=["desk"]) 44 | DESK_LEFT = DEFINE_VARIABLE(labels=["desk"]) 45 | DESK_RIGHT = DEFINE_VARIABLE(labels=["desk"]) 46 | 47 | # the target desk has one desk to the right, one to the left, therefore by a simple reasoning I know that it is the desk between two other desks 48 | CONSTRAINT_BETWEEN(target=DESK_MIDDLE, anchors={DESK_LEFT, DESK_RIGHT}) 49 | 50 | # set the variable for the target object 51 | SET_TARGET(DESK_MIDDLE) 52 | 53 | <[USER]> 54 | 55 | QUERY: 56 | this is a blue trash can that is beside a orange one. it is not the trash can far from other trash cans. 
57 | 58 | RELEVANT OBJECT LABELS: 59 | [0] trash can 60 | 61 | <[ASSISTANT]> 62 | 63 | # define a variable for the blue trash can (target) 64 | TRASH_CAN_BLUE = DEFINE_VARIABLE(labels=["trash can"]) 65 | 66 | # define a variable for the orange trash can 67 | TRASH_CAN_ORANGE = DEFINE_VARIABLE(labels=["trash can"]) 68 | 69 | # define a variable for the far away trash can 70 | TRASH_CAN_FAR = DEFINE_VARIABLE(labels=["trash can"]) 71 | 72 | # the blue trash can should be beside the orange one 73 | CONSTRAINT_NEAR(target=TRASH_CAN_BLUE, anchor=TRASH_CAN_ORANGE) 74 | 75 | # the far away trash can should be far from the other trash cans 76 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_BLUE) 77 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_ORANGE) 78 | 79 | # finally set the variable for the target object 80 | SET_TARGET(TRASH_CAN_BLUE) 81 | 82 | <[USER]> 83 | 84 | QUERY: 85 | this is a recycling trash can. its blue and white in color and is on the left of a box. it is also below a wooden counter. 86 | 87 | RELEVANT OBJECT LABELS: 88 | [0] trash can 89 | [1] box 90 | [2] counter 91 | [3] plastic trash bin 92 | [4] wastebin 93 | 94 | <[ASSISTANT]> 95 | 96 | # define variables for the trash can. since we have multiple labels denoting trash cans (with synonyms), we include all of them in the label. 
97 | TRASH_CAN_0 = DEFINE_VARIABLE(labels=["trash can", "plastic trash bin", "wastebin"]) 98 | 99 | # define variables for the box and the counter 100 | BOX_0 = DEFINE_VARIABLE(labels=["box"]) 101 | COUNTER_0 = DEFINE_VARIABLE(labels=["counter"]) 102 | 103 | # add a constraint: the trash can is on the left of a box 104 | CONSTRAINT_LEFT(target=TRASH_CAN_0, anchor=BOX_0) 105 | 106 | # add a constraint: the trash can is also below a wooden counter 107 | CONSTRAINT_UNDER(target=TRASH_CAN_0, anchor=COUNTER_0) 108 | 109 | # set the CSP variable for the target object 110 | SET_TARGET(TRASH_CAN_0) -------------------------------------------------------------------------------- /prompts/generate_program_csp_no_minmax.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are given the 3D visual grounding task. The input is a language query describing some object(s) in the 3D scene. Besides that, you are also given a list of object labels that are relevant to this query. Any other objects should not be considered. 4 | 5 | Your task is to generate a Python program that locates the target object specified by the query. You don't have to use all the relevant object labels given to you. Some of them may actually be irrelevant. What's important is, you should always use the labels of the relevant objects given to you. 6 | 7 | The Python program you generate should solve a CSP (Constraint Satisfaction Problem). Solving the CSP will result in the correct grounding of the target objects. The variables are objects and constraints are spatial relationships among objects or appearance information of objects. 8 | 9 | You should first try to do some spatial reasoning of the input query and simplify the logic. Some relations between objects can be converted into simpler ones. You should always first try to reduce the complexity of the input query. E.g. 
an object that has one object on the left and another on the right is between those two objects. So you should use a single relation "between", instead of two relations "left" and "right". There are many similar cases. 10 | 11 | There are two types of variables: normal variables (DEFINE_VARIABLE), which represent a single object, and negative variables (DEFINE_NEGATIVE_VARIABLE), which indicate the absence of an object satisfying some spatial relations. You should only use negative variables when the query mentions some negative conditions with e.g., 'do not' or 'without'. 12 | 13 | You should only use the following predefined functions in your program. No other functions are allowed. 14 | 15 | <[REGISTERED_FUNCTIONS_PLACEHOLDER]> 16 | 17 | Some of the predefined functions above have a "score_func" parameter, which specifies how to compare objects. The following score functions are available. 18 | 19 | <[REGISTERED_SCORE_FUNCTIONS_PLACEHOLDER]> 20 | 21 | Your output should be a valid Python program, which means that any additional explanation should be comments (start with #). You only have to output the python code. Please DO NOT follow markdown conventions. You DO NOT have to enclose the code with ```. 22 | 23 | Some extra tips: 24 | - "North", "south", "east" and "west" should always be translated into "near". 25 | - If something is at the side of a room, it should be far from the room center. 26 | - If an object you need is not in the relevant object list, please ignore information about that object, and do not use it. 27 | - You should not ignore relations that can be represented with the given relevant object labels and available functions. 28 | - Relations like "farthest", "northwestern-most" in reference to the room should be translated to being farthest to the room center. 29 | - You should identify only ONE target, to which the given 3D visual grounding query is referring to. You should always set the target at the very end. 
No function calls are allowed after setting the target. 30 | - Please DO NOT define variables that you will not use later. Every variable defined must be set in relation with at least one other variable. 31 | - Sometimes multiple labels refer to similar objects. For example, different queries may use "desk" or "table" to refer to the same thing. So when you need a "desk", and only "table" are available, you should use table, and vice versa. 32 | - For each object, you either define a normal variable or a negative variable. Please do not define both. 33 | 34 | <[USER]> 35 | 36 | QUERY: 37 | the desk is one of three in the row. there is a desk to the right of it, and a desk to the left of it. 38 | 39 | RELEVANT OBJECT LABELS: 40 | [0] desk 41 | 42 | <[ASSISTANT]> 43 | 44 | # since there are three desks in total, I'll create an object to represent each of them 45 | DESK_MIDDLE = DEFINE_VARIABLE(labels=["desk"]) 46 | DESK_LEFT = DEFINE_VARIABLE(labels=["desk"]) 47 | DESK_RIGHT = DEFINE_VARIABLE(labels=["desk"]) 48 | 49 | # the target desk has one desk to the right, one to the left, therefore by a simple reasoning I know that it is the desk between two other desks 50 | CONSTRAINT_BETWEEN(target=DESK_MIDDLE, anchors={DESK_LEFT, DESK_RIGHT}) 51 | 52 | # set the variable for the target object 53 | SET_TARGET(DESK_MIDDLE) 54 | 55 | <[USER]> 56 | 57 | QUERY: 58 | this is a blue trash can that is beside a orange one. it is not the trash can far from other trash cans. 
59 | 60 | RELEVANT OBJECT LABELS: 61 | [0] trash can 62 | 63 | <[ASSISTANT]> 64 | 65 | # define a variable for the blue trash can (target) 66 | TRASH_CAN_BLUE = DEFINE_VARIABLE(labels=["trash can"]) 67 | 68 | # define a variable for the orange trash can 69 | TRASH_CAN_ORANGE = DEFINE_VARIABLE(labels=["trash can"]) 70 | 71 | # define a variable for the far away trash can 72 | TRASH_CAN_FAR = DEFINE_VARIABLE(labels=["trash can"]) 73 | 74 | # the blue trash can should be beside the orange one 75 | CONSTRAINT_NEAR(target=TRASH_CAN_BLUE, anchor=TRASH_CAN_ORANGE) 76 | 77 | # the far away trash can should be far from the other trash cans 78 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_BLUE) 79 | CONSTRAINT_FAR(target=TRASH_CAN_FAR, anchor=TRASH_CAN_ORANGE) 80 | 81 | # finally set the variable for the target object 82 | SET_TARGET(TRASH_CAN_BLUE) 83 | 84 | <[USER]> 85 | 86 | QUERY: 87 | this is the third chair from the left. 88 | 89 | RELEVANT OBJECT LABELS: 90 | [0] chair 91 | 92 | <[ASSISTANT]> 93 | 94 | # this is an example of how to handle counting-based query 95 | 96 | # since the target is the third chair from the left, there must be at least three chairs in total 97 | # so we define a variable for each of them 98 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 99 | CHAIR_1 = DEFINE_VARIABLE(labels=["chair"]) 100 | CHAIR_2 = DEFINE_VARIABLE(labels=["chair"]) 101 | 102 | # constraint: the target chair has a chair on the left, which also has another chair to its left 103 | CONSTRAINT_LEFT(target=CHAIR_1, anchor=CHAIR_0) 104 | CONSTRAINT_LEFT(target=CHAIR_2, anchor=CHAIR_1) 105 | 106 | # set the variable for the target object 107 | SET_TARGET(CHAIR_0) 108 | 109 | <[USER]> 110 | 111 | QUERY: 112 | this is the chair third closest to the corner of the room. 
113 | 114 | RELEVANT OBJECT LABELS: 115 | [0] chair 116 | [1] room corner 117 | 118 | <[ASSISTANT]> 119 | 120 | # this is another example of how to handle counting-based query 121 | # it is also a comparison-based query 122 | 123 | # since the target is the third chair from the room corner, there must be at least three chairs in total 124 | # so we define a variable for each of them 125 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 126 | CHAIR_1 = DEFINE_VARIABLE(labels=["chair"]) 127 | CHAIR_2 = DEFINE_VARIABLE(labels=["chair"]) 128 | 129 | # define a variable for the corner of the room 130 | ROOM_CORNER = DEFINE_VARIABLE(labels=["room corner"]) 131 | 132 | # constraint: the target chair (CHAIR_2) is farther from the room corner than another chair (CHAIR_1), which (CHAIR_1) again is farther from the room corner than yet a third chair (CHAIR_0) 133 | CONSTRAINT_MORE(target=CHAIR_2, reference=CHAIR_1, anchor=ROOM_CORNER, score_func="distance") 134 | CONSTRAINT_MORE(target=CHAIR_1, reference=CHAIR_0, anchor=ROOM_CORNER, score_func="distance") 135 | 136 | # set the variable for the target object 137 | SET_TARGET(CHAIR_2) 138 | 139 | <[USER]> 140 | 141 | QUERY: 142 | this is a chair. it does not have a teddy bear sitting on it. 143 | 144 | RELEVANT OBJECT LABELS: 145 | [0] chair 146 | [1] teddy bear 147 | 148 | <[ASSISTANT]> 149 | 150 | # this is an example of using negative variables 151 | 152 | # define variables for the chair and the teddy bear 153 | CHAIR_0 = DEFINE_VARIABLE(labels=["chair"]) 154 | 155 | # define a negative variable to express that the chair has no teddy bear on it 156 | TEDDY_BEAR_NEG = DEFINE_NEGATIVE_VARIABLE(labels=["teddy bear"]) 157 | 158 | # first create a normal constraint: the teddy bear is on the chair 159 | CONSTRAINT_ON(target=TEDDY_BEAR_NEG, anchor=CHAIR_0) 160 | 161 | # set the variable for the target object 162 | SET_TARGET(CHAIR_0) 163 | 164 | <[USER]> 165 | 166 | QUERY: 167 | it is the cup not on the table. 
168 | 169 | RELEVANT OBJECT LABELS: 170 | [0] cup 171 | [1] table 172 | 173 | <[ASSISTANT]> 174 | 175 | # this is another example of using negative variables 176 | 177 | # define the cup variable, which is the target 178 | CUP_0 = DEFINE_VARIABLE(labels=["cup"]) 179 | 180 | # define the table variable as negative, since the cup variable is the target and we want to locate it 181 | TABLE_NEG_0 = DEFINE_NEGATIVE_VARIABLE(labels=["table"]) 182 | 183 | # the cup is not on the table 184 | CONSTRAINT_ON(target=CUP_0, anchor=TABLE_NEG_0) 185 | 186 | # set the variable for the target object 187 | SET_TARGET(CUP_0) 188 | 189 | <[USER]> 190 | 191 | QUERY: 192 | this is a recycling trash can. its blue and white in color and is on the left of a box. it is also below a wooden counter. 193 | 194 | RELEVANT OBJECT LABELS: 195 | [0] trash can 196 | [1] box 197 | [2] counter 198 | [3] plastic trash bin 199 | [4] wastebin 200 | 201 | <[ASSISTANT]> 202 | 203 | # define variables for the trash can. since we have multiple labels denoting trash cans (with synonyms), we include all of them in the label. 204 | TRASH_CAN_0 = DEFINE_VARIABLE(labels=["trash can", "plastic trash bin", "wastebin"]) 205 | 206 | # define variables for the box and the counter 207 | BOX_0 = DEFINE_VARIABLE(labels=["box"]) 208 | COUNTER_0 = DEFINE_VARIABLE(labels=["counter"]) 209 | 210 | # add a constraint: the trash can is on the left of a box 211 | CONSTRAINT_LEFT(target=TRASH_CAN_0, anchor=BOX_0) 212 | 213 | # add a constraint: the trash can is also below a wooden counter 214 | CONSTRAINT_UNDER(target=TRASH_CAN_0, anchor=COUNTER_0) 215 | 216 | # set the CSP variable for the target object 217 | SET_TARGET(TRASH_CAN_0) -------------------------------------------------------------------------------- /prompts/guess_thresholds.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | ??? 
4 | 5 | <[USER]> 6 | 7 | Here is a 3D Visual Grounding Query: 8 | it is the largest cup on the table close to window. 9 | 10 | To locate the target, you have previously generated a Python program, which identifies these objects: 11 | cup_3, table_4, window_6 12 | 13 | For each of the objects above, we can assign one of the candidates to it: 14 | cup_3: 15 | candidate_1: center=[5.913, -2.261, 1.161]; size=[0.323, 0.323, 0.323] 16 | candidate_2: center=[5.913, -1.783, 1.09]; size=[0.180, 0.180, 0.180] 17 | candidate_3: center=[2.201, -6.001, 1.354]; size=[0.323, 0.323, 0.709] 18 | 19 | table_4: 20 | candidate_1: center=[5.913, -2.261, 0.000]; size=[2.000, 2.000, 2.000] 21 | candidate_2: center=[2.206, -5.929, 0.000]; size=[2.000, 2.000, 2.000] 22 | 23 | window_6: 24 | candidate_1: center=[5.891, -0.758, 2.327]; size=[2.000, 0.300, 2.000] 25 | 26 | In the generated Python program, we have to check the relations between object candidates in order to find the most appropriate one for each object. To check the spatial relations, we define some thresholds: 27 | near_distance: when the distance between two candidates assigned to two objects is below this threshold, they are considered to be "near" each other. 28 | on_distance: when the center of object A is above the center of object B, and the distance between the centers of A and B is below this threshold, object A should be considered as "on" object B. 29 | 30 | It is difficult to determine a good threshold for all queries. So your job is to determine for this specific query which values for the thresholds are the most appropriate ones. Please output the threshold name and its value on each line. 31 | 32 | We assume that the query has only a single target, i.e., only a single candidate can be assigned to the target object. This means that the threshold values you give should exclude all irrelevant objects, except those needed to ground the target. Please choose your threshold values so that this is the case. 
33 | 34 | <[ASSISTANT]> 35 | 36 | To determine the appropriate threshold values for the given query, we need to analyze the spatial relationships between the candidates for the objects "cup_3," "table_4," and "window_6." 37 | 38 | Let's break down the query: 39 | - The target is the largest cup. 40 | - The cup is on the table. 41 | - The table is near the window. 42 | 43 | First, we need to identify the largest cup. From the given sizes: 44 | - candidate_1: size=[0.323, 0.323, 0.323] 45 | - candidate_2: size=[0.180, 0.180, 0.180] 46 | - candidate_3: size=[0.323, 0.323, 0.709] 47 | 48 | Candidate_3 is the largest cup. 49 | 50 | Next, we need to determine which table the cup is on. We check the vertical distance (z-coordinate) between the cup and the tables: 51 | - Cup candidate_3: center=[2.201, -6.001, 1.354] 52 | - Table candidate_1: center=[5.913, -2.261, 0.000] 53 | - Table candidate_2: center=[2.206, -5.929, 0.000] 54 | 55 | The vertical distance between cup candidate_3 and table candidate_2 is 1.354, which is reasonable for the cup to be on the table. 56 | 57 | Finally, we need to determine which table is near the window: 58 | - Window candidate_1: center=[5.891, -0.758, 2.327] 59 | 60 | We calculate the horizontal distance (x and y coordinates) between the tables and the window: 61 | - Distance between table candidate_1 and window candidate_1: sqrt((5.913 - 5.891)^2 + (-2.261 + 0.758)^2) ≈ 2.36 62 | - Distance between table candidate_2 and window candidate_1: sqrt((2.206 - 5.891)^2 + (-5.929 + 0.758)^2) ≈ 7.37 63 | 64 | Table candidate_1 is closer to the window. 65 | 66 | Based on this analysis, we need to set the thresholds such that: 67 | - The largest cup (candidate_3) is on table candidate_2. 68 | - Table candidate_2 is not considered near the window to exclude it. 69 | - Table candidate_1 is considered near the window. 
70 | 71 | Thus, the appropriate threshold values are: 72 | ``` 73 | on_distance: 2.0 74 | near_distance: 3.0 75 | ``` 76 | 77 | These threshold values ensure that only the relevant objects are considered for grounding the target cup. -------------------------------------------------------------------------------- /prompts/select_solution.txt: -------------------------------------------------------------------------------- 1 | <[SYSTEM]> 2 | 3 | You are an intelligent agent for the 3D Visual Grounding task. You are given a query and many candidate solutions, each of which assigns a center and a size for each object mentioned in the query (also given to you explicitly). Your task is to choose only one solution that is most appropriate, based on the content of the query. 4 | 5 | Here are some tips: 6 | - When dealing with complex spatial arrangements, it may be helpful to visualize the scenario and positions of objects based on their given centers and sizes. 7 | - Always carefully compute and compare distances between objects, as 'close to' or 'far from' are often relative concepts dependent on the overall layout. 8 | - When checking view-dependent spatial relations like 'front', 'back', 'left', and 'right', remember to convert those into the concept of 'spatially close', or view from the center of all objects to get more accurate judgments. 9 | - While the aim is to find the solution that fits all requirements, keep in mind that in certain situations, a solution might need to be chosen based on the majority of conditions being met, or some conditions might need to be loosened based on commonsense if none of the solutions are perfect. The goal is to choose the solution that most logically aligns with the given query. 10 | - Carefully examining each solution: A quick overview may overlook critical differences between object placements or sizes. Each element in the query and solutions must be reviewed. 
11 | - Using mathematical calculations: Use computations to measure distances and sizes when necessary. Precise measurements will greatly aid in tasks with spatial arrangements and opting for the most accurate solution. 12 | - Making judicious assumptions: In certain scenarios, we may need to make assumptions when the query or solutions are ambiguous. These assumptions should be logical and based on the given context, and must always be justified in your reasoning. 13 | - Deploying basic geometry and physics: Simple concepts about distance, area, volume, trajectory, etc can play a pivotal role in generating a robust judgement. 14 | - Considering multiple interpretations: Queries may have multiple potential interpretations. Try to identify alternative understandings, and see which solutions could fit. 15 | - Contemplating real-world feasibility: In addition to spatial calculations and geometry, consider if the solution makes sense in a real-world scenario. 16 | 17 | Example Input: 18 | 19 | @QUERY 20 | it is the largest cup on the table close to window. 
21 | 22 | @OBJECTS 23 | cup_3, table_4, window_6 24 | 25 | @SOLUTIONS 26 | solution 0 27 | cup_3: center=[5.913, -2.261, 1.161]; size=[0.323, 0.323, 0.323] 28 | table_4: center=[5.913, -2.261, 0.000]; size=[2.000, 2.000, 2.000] 29 | window_6: center=[5.891, -0.758, 2.327]; size=[2.000, 0.300, 2.000] 30 | 31 | solution 1 32 | cup_3: center=[5.913, -1.783, 1.09]; size=[0.180, 0.180, 0.180] 33 | table_4: center=[5.913, -2.261, 0.000]; size=[2.000, 2.000, 2.000] 34 | window_6: center=[5.891, -0.758, 2.327]; size=[2.000, 0.300, 2.000] 35 | 36 | solution 2 37 | cup_3: center=[2.201, -6.001, 1.354]; size=[0.323, 0.323, 0.709] 38 | table_4: center=[2.206, -5.929, 0.000]; size=[2.000, 2.000, 2.000] 39 | window_6: center=[5.891, -0.758, 2.327]; size=[2.000, 0.300, 2.000] 40 | 41 | Example Output: 42 | 43 | The task involves finding the most appropriate solution that aligns with the query - "it is the largest cup on the table close to window." Thus, we need to find a solution wherein the cup is the largest object among the cups and is close to the window while being on the table. 44 | 45 | Let's evaluate each solution: 46 | 47 | Solution 0: This solution has the cup at the center [5.913, -2.261, 1.161] with size [0.323, 0.323, 0.323]. While it is greater in size than the cup in Solution 1, it is comparatively smaller in height than the one in Solution 2. Moreover, it is relatively closer to the window, although not as close as the one in Solution 1. 48 | 49 | Solution 1: The cup in this solution is smaller in size than the cups in both Solution 0 and 2. Furthermore, while it is closer to the window than the cup in Solution 0, it is still not the largest object among the cups as required by the query. 50 | 51 | Solution 2: This solution presents the cup of the largest size among all the cups, with a size of [0.323, 0.323, 0.709]. 
However, it is farther from the window than both Solutions 0 and 1, which makes it not an ideal selection as per the query requirement of the cup being close to the window. 52 | 53 | Considering all these details, Solution 0 emerges as the most appropriate choice as it presents the cup of the largest size close to the window. This conforms to the query as tightly as possible, as the cup in Solution 1 is smaller and the one in Solution 2 is farther away from the window. 54 | 55 | Here is the answer in a format facilitating parsing: 56 | @BEGIN{ANSWER} 57 | correct solution index: [0] 58 | @END{ANSWER} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Solving Zero-Shot 3D Visual Grounding as Constraint Satisfaction Problems 2 | 3 |
9 | 10 | We propose [Constraint Satisfaction Visual Grounder (CSVG)](https://sunsleaf.github.io/CSVG/) for 3D visual grounding in a point cloud referred to by natural language descriptions. 11 | 12 | 13 | ## Updates 14 | 15 | - [x] 11/2024, CSVG code released. 16 | - [x] 11/2024, CSVG arxiv paper released. 17 | 18 | ## Setup 19 | 20 | Here are the instructions to run the code. 21 | 22 | #### Install the environment. 23 | 24 | ``` 25 | pip install transformers 26 | pip install pyviz3d 27 | pip install plyfile 28 | ``` 29 | 30 | If you encounter missing package errors, simply installing the package should solve the problem. 31 | 32 | #### Prepare the data. 33 | 34 | Our system can be used with the ground truth segmentation from [ScanNet](https://github.com/ScanNet/ScanNet) or predictions from [Mask3D](https://github.com/JonasSchult/Mask3D). 35 | 36 | For both cases, you need to download the ScanNet dataset to ``data/scans``. You may only download the validation set to save disk space. 37 | 38 | If you want to use instance segmentations from Mask3D, you can run the model by following the instructions from the Mask3D github repo. In the end, the prediction results will be stored in folders named like `instance_evaluation_scannet200_val_query_150_topk_750_dbscan_0.95_0/decoder_-1/`. After this folder is filled with contents, please make a soft link of it to `data/eval_output/mask3d_val`, i.e., the `decoder_-1` folder should be linked as `mask3d_val`. 39 | 40 | Download the [ScanRefer](https://github.com/daveredrum/ScanRefer) dataset and put the json files into `data/scanrefer`. 41 | 42 | If you also want to evaluate on the [Nr3D](https://referit3d.github.io/) dataset, download the `nr3d.csv` into the `data` folder directly. 43 | 44 | Now you should have the data required to run the code. 
45 | 46 | #### Program generation 47 | 48 | You can use the `run_scripts` bash file to generate Python programs for a dataset: 49 | 50 | `./run_scripts gen --dataset scanrefer --mask3d` 51 | 52 | For a complete list of available options, please run: 53 | 54 | `./run_scripts gen --help` 55 | 56 | The most important option is `--dataset`, which can be `scanrefer`, `nr3d`, or any custom dataset names. For customized dataset names, a file called `{dataset_name}_queries.json` in the `data` folder will be loaded. There are already some examples there, the format of which you can follow. 57 | 58 | If you add the `--mask3d` option, segmentations from Mask3D will be used. If none is added, the ground truth segmentation will be used. 59 | 60 | Another important thing is the LLM server address. You can either deploy a local LLM with OpenAI-compatible APIs, or use OpenAI models directly. The configuration of a local API server is at line 782 in `program_generator.py`; here, the `api_key` argument usually does not matter, but the `base_url` should be modified according to your server configurations. For OpenAI API, you should modify your key at line 799 in `program_generator.py`. 61 | 62 | After running the generation script, an `eval_data_*` file will be generated in the `output` folder. 63 | 64 | #### Evaluation 65 | 66 | You can also use the `run_scripts` file to run the generated program and get the grounding results, e.g.: 67 | 68 | `./run_scripts eval --dataset scanrefer --seg mask3d --num-threads 10` 69 | 70 | The `--dataset` option can be `scanrefer`, `nr3d` or a custom dataset name. The `--seg` option can be `gt` or `mask3d`. 71 | 72 | After running the evaluation script, an `eval_results_*` file will be generated in the `output` folder. 73 | 74 | #### Visualization 75 | 76 | The `visualize_eval_results.py` can be used to visualize the evaluation results. It reads the `eval_results_*` files. We include some examples from our experiments. 
77 | 78 | For example, you can run the following command to visualize results on the ScanRefer validation set using Mask3D segmentations: 79 | 80 | `python ./visualize_eval_results.py --file ./output/eval_results_mask3d_scanrefer.json` 81 | 82 | Use the `--help` option to see a full list of available options, e.g., the `--distractor-required` option will plot the distractors (objects with the same label as the target). 83 | 84 | The visualization script will start a PyViz3D server, and you can view the visualization in your browser at `0.0.0.0:8889` (the default port). 85 | 86 | ### Contributor 87 | 88 | Qihao Yuan (Email: qihao.yuan@rug.nl) 89 | 90 | ### License 91 | The source code is released under the [GPLv3](https://www.gnu.org/licenses/) license. 92 | 93 | ## Citation 94 | 95 | If you are interested in this work, please cite as below: 96 | 97 | ```text 98 | @article{yuan2024solving, 99 | title={Solving Zero-Shot 3D Visual Grounding as Constraint Satisfaction Problems}, 100 | author={Yuan, Qihao and Zhang, Jiaming and Li, Kailai and Stiefelhagen, Rainer}, 101 | journal={arXiv preprint arXiv:2411.14594}, 102 | year={2024} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /run_scripts: -------------------------------------------------------------------------------- 1 | if [ $1 = "gen" ]; then 2 | # clear 3 | echo "generating programs..." 4 | echo 5 | python -u program_generator.py "${@:2}" 2>&1 | tee out_gen.txt 6 | elif [ $1 = "eval" ]; then 7 | # clear 8 | echo "evaluating programs..." 9 | echo 10 | python -u program_evaluator.py "${@:2}" 2>&1 | tee out_eval.txt 11 | else 12 | echo "invalid input!" 
13 | fi -------------------------------------------------------------------------------- /scannet_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import Counter 4 | from dataclasses import dataclass 5 | from random import random, seed 6 | from typing import TypedDict 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | import pyviz3d 11 | import pyviz3d.visualizer 12 | from plyfile import PlyData 13 | from sklearn.cluster import DBSCAN 14 | from typing_extensions import Self 15 | 16 | from data.scannet200_constants import CLASS_LABELS_200, VALID_CLASS_IDS_200 17 | 18 | 19 | def idx_2_label_200(idx): 20 | """copied from https://github.com/ripl/Transcrib3D/blob/main/preprocessing/gen_obj_list.py""" 21 | return CLASS_LABELS_200[VALID_CLASS_IDS_200.index(idx)] 22 | 23 | 24 | def read_mesh_vertices(mesh_file: str, load_color: bool) -> npt.NDArray: 25 | """read XYZ (and RGB) for each vertex.""" 26 | assert os.path.isfile(mesh_file), mesh_file 27 | with open(mesh_file, "rb") as f: 28 | plydata = PlyData.read(f) 29 | num_verts = plydata["vertex"].count 30 | num_cols = 6 if load_color else 3 31 | vertices = np.zeros(shape=[num_verts, num_cols], dtype=np.float32) 32 | vertices[:, 0] = plydata["vertex"].data["x"] 33 | vertices[:, 1] = plydata["vertex"].data["y"] 34 | vertices[:, 2] = plydata["vertex"].data["z"] 35 | if load_color: 36 | vertices[:, 3] = plydata["vertex"].data["red"] / 255.0 37 | vertices[:, 4] = plydata["vertex"].data["green"] / 255.0 38 | vertices[:, 5] = plydata["vertex"].data["blue"] / 255.0 39 | return vertices 40 | 41 | 42 | def transform_vertices( 43 | meta_file: str, 44 | mesh_vertices: npt.NDArray, 45 | ) -> tuple[npt.NDArray, npt.NDArray]: 46 | """read alignment matrix and transform mesh vertices""" 47 | assert os.path.isfile(meta_file) 48 | lines = open(meta_file).readlines() 49 | 50 | axis_align_matrix = None 51 | for line in lines: 52 | if 
"axisAlignment" in line: 53 | axis_align_matrix = [ 54 | float(x) for x in line.rstrip().strip("axisAlignment = ").split(" ") 55 | ] 56 | 57 | if axis_align_matrix is not None: 58 | axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4)) 59 | # print(axis_align_matrix) 60 | 61 | pts = np.ones((mesh_vertices.shape[0], 4)) 62 | pts[:, 0:3] = mesh_vertices[:, 0:3] 63 | pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4 64 | 65 | aligned_vertices = mesh_vertices.copy() 66 | aligned_vertices[:, 0:3] = pts[:, 0:3] 67 | return aligned_vertices, axis_align_matrix 68 | else: 69 | print() 70 | print("no axis alignment matrix!") 71 | print() 72 | return mesh_vertices, np.eye(4) 73 | 74 | 75 | class RawInstance(TypedDict): 76 | inst_id: str 77 | label: str 78 | vertices: npt.NDArray 79 | score: float 80 | 81 | 82 | class BoundingBox3D: 83 | def __init__( 84 | self, 85 | pmin0: npt.NDArray = None, 86 | pmax0: npt.NDArray = None, 87 | pcenter0: npt.NDArray = None, 88 | psize0: npt.NDArray = None, 89 | ): 90 | if pmin0 is not None and pmax0 is not None: 91 | assert pcenter0 is None and psize0 is None 92 | pmin0 = np.array(pmin0) 93 | pmax0 = np.array(pmax0) 94 | elif pcenter0 is not None and psize0 is not None: 95 | assert pmin0 is None and pmax0 is None 96 | pcenter0 = np.array(pcenter0) 97 | psize0 = np.array(psize0) 98 | pmin0 = pcenter0 - psize0 * 0.5 99 | pmax0 = pcenter0 + psize0 * 0.5 100 | 101 | assert pmin0.shape == (3,) 102 | assert pmax0.shape == (3,) 103 | pmin = np.minimum(pmin0, pmax0) 104 | pmax = np.maximum(pmin0, pmax0) 105 | self.pmin = pmin 106 | self.pmax = pmax 107 | self.center = 0.5 * (pmin + pmax) 108 | self.size = pmax - pmin 109 | self.max_extent = np.max(pmax - pmin) 110 | self.extents = { 111 | "x": pmax[0] - pmin[0], 112 | "y": pmax[1] - pmin[1], 113 | "z": pmax[2] - pmin[2], 114 | } 115 | 116 | def contains(self, p: npt.NDArray) -> bool: 117 | return np.all(p > self.pmin) and np.all(p < self.pmax) 118 | 119 | def intersect(self, other: 
Self) -> Self: 120 | return BoundingBox3D( 121 | pmin0=np.maximum(self.pmin, other.pmin), 122 | pmax0=np.minimum(self.pmax, other.pmax), 123 | ) 124 | 125 | def union(self, other: Self) -> Self: 126 | return BoundingBox3D( 127 | pmin0=np.minimum(self.pmin, other.pmin), 128 | pmax0=np.maximum(self.pmax, other.pmax), 129 | ) 130 | 131 | def volume(self) -> float: 132 | return np.prod(self.pmax - self.pmin) 133 | 134 | def iou(self, other: Self) -> float: 135 | if self.volume() == 0 or other.volume() == 0: 136 | return 0.0 137 | return self.intersect(other).volume() / self.union(other).volume() 138 | 139 | 140 | def filter_raw_instances(inst_map: dict[str, RawInstance]) -> dict[str, RawInstance]: 141 | def filter_pointcould(points): 142 | """copied from the Transcrib3D repo""" 143 | # use dbscan to filter out outlier points 144 | dbscan = DBSCAN(eps=0.1, min_samples=20) 145 | if points.shape[1] == 3: 146 | dbscan.fit(points) 147 | else: 148 | dbscan.fit(points[:, 0:3]) 149 | counter = Counter(dbscan.labels_) 150 | main_idx = counter.most_common(2)[0][0] 151 | if main_idx == -1: 152 | main_idx = counter.most_common(2)[-1][0] 153 | # print("counter:",counter) 154 | # print("main_idx:",main_idx) 155 | points_filtered = points[dbscan.labels_ == main_idx] 156 | return points_filtered 157 | 158 | def calc_iou(inst_a: RawInstance, inst_b: RawInstance) -> float: 159 | bbox_a = BoundingBox3D( 160 | pmin0=inst_a["vertices"][:, :3].min(axis=0), 161 | pmax0=inst_a["vertices"][:, :3].max(axis=0), 162 | ) 163 | bbox_b = BoundingBox3D( 164 | pmin0=inst_b["vertices"][:, :3].min(axis=0), 165 | pmax0=inst_b["vertices"][:, :3].max(axis=0), 166 | ) 167 | return bbox_a.iou(bbox_b) 168 | 169 | # filter the point cloud of each instance 170 | filtered_insts_1: dict[str, RawInstance] = {} 171 | for k, inst_0 in inst_map.items(): 172 | inst_1 = inst_0.copy() 173 | inst_1["vertices"] = filter_pointcould(inst_0["vertices"]) 174 | filtered_insts_1[k] = inst_1 175 | 176 | # filter out overlapped 
instances with lower scores 177 | filtered_insts_2: dict[str, RawInstance] = {} 178 | for k, inst_1 in filtered_insts_1.items(): 179 | for inst in filtered_insts_1.values(): 180 | iou = calc_iou(inst_1, inst) 181 | if iou >= 0.7 and inst_1["score"] < inst["score"]: 182 | break 183 | else: 184 | filtered_insts_2[k] = inst_1 185 | 186 | return filtered_insts_2 187 | 188 | 189 | def read_instances( 190 | agg_file_path: str, 191 | seg_file_path: str, 192 | vertex_buffer: npt.NDArray, 193 | ) -> dict[str, RawInstance]: 194 | """return a dict: instance id -> instance info""" 195 | 196 | # read segments 197 | seg_to_verts = {} 198 | with open(seg_file_path) as f: 199 | data = json.load(f) 200 | num_verts = len(data["segIndices"]) 201 | for i in range(num_verts): 202 | seg_id = data["segIndices"][i] 203 | if seg_id in seg_to_verts: 204 | seg_to_verts[seg_id].append(i) 205 | else: 206 | seg_to_verts[seg_id] = [i] 207 | 208 | # read instances 209 | inst_id_to_insts: dict[str, RawInstance] = {} 210 | with open(agg_file_path) as f: 211 | data = json.load(f) 212 | num_objects = len(data["segGroups"]) 213 | for i in range(num_objects): 214 | inst = RawInstance() 215 | inst["id"] = str(data["segGroups"][i]["objectId"]) 216 | inst["label"] = str(data["segGroups"][i]["label"]) 217 | 218 | # assign vertices to the instance 219 | for seg in data["segGroups"][i]["segments"]: 220 | verts = seg_to_verts[seg] 221 | if "vertices" not in inst: 222 | inst["vertices"] = vertex_buffer[verts, :] 223 | else: 224 | inst["vertices"] = np.vstack( 225 | [inst["vertices"], vertex_buffer[verts, :]] 226 | ) 227 | 228 | inst_id_to_insts[inst["id"]] = inst 229 | 230 | return inst_id_to_insts 231 | 232 | 233 | def read_instances_mask3d( 234 | scene_id: str, pred_path: str, vertex_buffer: npt.NDArray 235 | ) -> dict[str, RawInstance]: 236 | assert os.path.isdir(pred_path), pred_path 237 | assert vertex_buffer.shape[1] in (3, 6), vertex_buffer.shape 238 | 239 | scene_pred_file = os.path.join(pred_path, 
scene_id + ".txt") 240 | assert os.path.isfile(scene_pred_file), scene_pred_file 241 | 242 | inst_id_to_insts: dict[str, RawInstance] = {} 243 | with open(scene_pred_file) as f: 244 | for i, line in enumerate(f): 245 | line = line.strip() 246 | if not line: 247 | continue 248 | 249 | mask_rel_path, label_id, score = line.split() 250 | 251 | score = float(score) 252 | if score < 0.5: 253 | continue 254 | 255 | label_str = idx_2_label_200(int(label_id)) 256 | mask_file = os.path.join(pred_path, mask_rel_path) 257 | assert os.path.isfile(mask_file), mask_file 258 | 259 | masks = np.loadtxt(mask_file, dtype=bool) 260 | assert masks.shape[0] == vertex_buffer.shape[0] 261 | inst_id_to_insts[str(i)] = RawInstance( 262 | id=str(i), 263 | label=label_str, 264 | vertices=vertex_buffer[masks, :], 265 | score=score, 266 | ) 267 | 268 | return filter_raw_instances(inst_id_to_insts) 269 | 270 | 271 | def read_instances_maskcluster( 272 | scene_id: str, pred_path: str, vertex_buffer: npt.NDArray 273 | ) -> dict[str, RawInstance]: 274 | assert os.path.isdir(pred_path), pred_path 275 | assert vertex_buffer.shape[1] in (3, 6), vertex_buffer.shape 276 | 277 | scene_pred_file = os.path.join(pred_path, scene_id + ".npz") 278 | pred = np.load(scene_pred_file) 279 | 280 | pred_masks = pred["pred_masks"] 281 | pred_scores = pred["pred_score"] 282 | pred_classes = pred["pred_classes"] 283 | 284 | num_instances = pred["pred_masks"].shape[1] 285 | assert pred_masks.shape[0] == vertex_buffer.shape[0] 286 | assert pred_scores.shape[0] == num_instances 287 | assert pred_classes.shape[0] == num_instances 288 | 289 | return filter_raw_instances( 290 | { 291 | str(i): RawInstance( 292 | id=str(i), 293 | label=idx_2_label_200(pred_classes[i]), 294 | vertices=vertex_buffer[pred_masks[:, i], :], 295 | score=pred_scores[i], 296 | ) 297 | for i in range(num_instances) 298 | } 299 | ) 300 | 301 | 302 | class ObjInstance: 303 | """instance of an object of a certain label/category""" 304 | 305 | def 
__init__(self, instance_id: str, label: str, vertices: npt.NDArray): 306 | self.inst_id = str(instance_id) 307 | self.label = label 308 | self.vertices = vertices 309 | self.bbox = BoundingBox3D( 310 | pmin0=np.min(vertices[:, :3], axis=0), 311 | pmax0=np.max(vertices[:, :3], axis=0), 312 | ) 313 | 314 | def __hash__(self) -> int: 315 | return hash(self.inst_id) 316 | 317 | def __eq__(self, other: Self) -> bool: 318 | return self.inst_id == other.inst_id 319 | 320 | 321 | class ScanNetScene: 322 | def __init__( 323 | self, 324 | scene_path: str, 325 | mask3d_pred_path: str | None = None, 326 | maskcluster_pred_path: str | None = None, 327 | cache_root: str | None = None, 328 | add_room_center: bool = True, 329 | add_room_corners: bool = True, 330 | ): 331 | scene_path = os.path.normpath(scene_path) 332 | # assert os.path.isdir(scene_path) 333 | 334 | scene_id = os.path.basename(scene_path).strip() 335 | self.scene_id = scene_id 336 | self.viz_suffix = "gt" 337 | # print(f"loading {scene_id}.") 338 | 339 | scene_prefix = f"{scene_path}/{scene_id}" 340 | ply_file = f"{scene_prefix}_vh_clean_2.ply" 341 | agg_file = f"{scene_prefix}.aggregation.json" 342 | seg_file = f"{scene_prefix}_vh_clean_2.0.010000.segs.json" 343 | meta_file = f"{scene_prefix}.txt" 344 | # assert os.path.isfile(ply_file) 345 | # assert os.path.isfile(agg_file) 346 | # assert os.path.isfile(seg_file) 347 | # assert os.path.isfile(meta_file) 348 | 349 | instance_map: dict[str, RawInstance] = {} 350 | if mask3d_pred_path is not None: 351 | # print("loading mask3d.") 352 | assert maskcluster_pred_path is None 353 | self.viz_suffix = "mask3d" 354 | 355 | should_load_data = True 356 | cache_file = None 357 | 358 | if cache_root: 359 | cache_dir = os.path.join(cache_root, "instances_mask3d") 360 | cache_file = os.path.join(cache_dir, f"{scene_id}.npy") 361 | 362 | # create the cache folder if it does not exist 363 | if not os.path.isdir(cache_dir): 364 | os.system(f"mkdir -p {cache_dir}") 365 | 366 | # 
if the cache file exists, load it directly 367 | elif os.path.isfile(cache_file): 368 | should_load_data = False 369 | instance_map = np.load(cache_file, allow_pickle=True).item() 370 | 371 | if should_load_data: 372 | # load scene data 373 | vert_buf = read_mesh_vertices(mesh_file=ply_file, load_color=True) 374 | vert_buf, _ = transform_vertices( 375 | meta_file=meta_file, mesh_vertices=vert_buf 376 | ) 377 | 378 | instance_map = read_instances_mask3d( 379 | scene_id=scene_id, 380 | pred_path=mask3d_pred_path, 381 | vertex_buffer=vert_buf, 382 | ) 383 | 384 | if cache_file: 385 | np.save(cache_file, instance_map) 386 | 387 | elif maskcluster_pred_path is not None: 388 | # load scene data 389 | vert_buf = read_mesh_vertices(mesh_file=ply_file, load_color=True) 390 | vert_buf, _ = transform_vertices( 391 | meta_file=meta_file, mesh_vertices=vert_buf 392 | ) 393 | 394 | # print("loading maskclustering.") 395 | assert mask3d_pred_path is None 396 | self.viz_suffix = "maskcluster" 397 | instance_map = read_instances_maskcluster( 398 | scene_id=scene_id, 399 | pred_path=maskcluster_pred_path, 400 | vertex_buffer=vert_buf, 401 | ) 402 | 403 | else: 404 | self.viz_suffix = "gt" 405 | 406 | should_load_data = True 407 | cache_file = None 408 | 409 | if cache_root: 410 | cache_dir = os.path.join(cache_root, "instances_gt") 411 | cache_file = os.path.join(cache_dir, f"{scene_id}.npy") 412 | 413 | # create the cache folder if it does not exist 414 | if not os.path.isdir(cache_dir): 415 | os.system(f"mkdir -p {cache_dir}") 416 | 417 | # if the cache file exists, load it directly 418 | elif os.path.isfile(cache_file): 419 | should_load_data = False 420 | instance_map = np.load(cache_file, allow_pickle=True).item() 421 | 422 | if should_load_data: 423 | # load scene data 424 | vert_buf = read_mesh_vertices(mesh_file=ply_file, load_color=True) 425 | vert_buf, _ = transform_vertices( 426 | meta_file=meta_file, mesh_vertices=vert_buf 427 | ) 428 | 429 | # print("loading 
groundtruth.") 430 | instance_map = read_instances( 431 | agg_file_path=agg_file, 432 | seg_file_path=seg_file, 433 | vertex_buffer=vert_buf, 434 | ) 435 | 436 | if cache_file: 437 | np.save(cache_file, instance_map) 438 | 439 | # remove instances with too few vertices 440 | assert instance_map 441 | instance_map = { 442 | k: v for k, v in instance_map.items() if v["vertices"].shape[0] >= 10 443 | } 444 | # print(f"{len(instance_map)} instances loaded.") 445 | 446 | self.raw_instance_map: dict[str, ObjInstance] = {} 447 | self.instance_map: dict[str, list[ObjInstance]] = {} 448 | bboxes = [] 449 | 450 | # build instance map from raw instance map... 451 | for inst in instance_map.values(): 452 | obj_inst = ObjInstance( 453 | instance_id=inst["id"], 454 | label=inst["label"], 455 | vertices=inst["vertices"], 456 | ) 457 | bboxes.append(obj_inst.bbox) 458 | assert obj_inst.bbox.volume() > 0 459 | 460 | assert obj_inst.inst_id not in self.raw_instance_map 461 | self.raw_instance_map[obj_inst.inst_id] = obj_inst 462 | 463 | if inst["label"] in self.instance_map: 464 | self.instance_map[inst["label"]].append(obj_inst) 465 | else: 466 | self.instance_map[inst["label"]] = [obj_inst] 467 | 468 | self.bbox = BoundingBox3D( 469 | pmin0=np.min([bbox.pmin for bbox in bboxes], axis=0), 470 | pmax0=np.max([bbox.pmax for bbox in bboxes], axis=0), 471 | ) 472 | self.room_center = self.bbox.center 473 | 474 | if add_room_center: 475 | self.instance_map["room center"] = [ 476 | ObjInstance( 477 | instance_id=-1, 478 | label="room center", 479 | vertices=np.array( 480 | [self.room_center - 1e-5, self.room_center + 1e-5] 481 | ), 482 | ) 483 | ] 484 | 485 | if add_room_corners: 486 | self.room_corners = [ 487 | self.bbox.pmin, 488 | np.array([self.bbox.pmax[0], self.bbox.pmin[1], self.bbox.pmin[2]]), 489 | np.array([self.bbox.pmin[0], self.bbox.pmax[1], self.bbox.pmin[2]]), 490 | np.array([self.bbox.pmax[0], self.bbox.pmax[1], self.bbox.pmin[2]]), 491 | ] 492 | self.instance_map["room 
corner"] = [ 493 | ObjInstance( 494 | instance_id=-2, 495 | label="room corner", 496 | vertices=np.array([corner - 1e-5, corner + 1e-5]), 497 | ) 498 | for corner in self.room_corners 499 | ] 500 | 501 | # TODO: add "room front", "room back", "room left" and "room right" 502 | 503 | def get_instance_map(self) -> dict[str, list[ObjInstance]]: 504 | return self.instance_map 505 | 506 | def get_raw_instance_map(self): 507 | return self.raw_instance_map 508 | 509 | def get_distractors(self, inst_id: str) -> list[ObjInstance]: 510 | inst_0 = self.raw_instance_map[str(inst_id).strip()] 511 | return [inst for inst in self.instance_map[inst_0.label] if inst != inst_0] 512 | 513 | def get_room_center(self) -> npt.NDArray: 514 | return self.room_center 515 | 516 | def get_room_corners(self) -> list[npt.NDArray]: 517 | return self.room_corners 518 | 519 | def get_instance_bbox(self, inst_id: str) -> BoundingBox3D: 520 | return self.raw_instance_map[str(inst_id)].bbox 521 | 522 | def is_unique_label(self, label: str) -> bool: 523 | return len(self.instance_map[label]) == 1 524 | 525 | @dataclass 526 | class BBoxInfo: 527 | pmin: npt.NDArray 528 | pmax: npt.NDArray 529 | color: npt.NDArray 530 | name: str 531 | 532 | def __post_init__(self): 533 | self.pmin = np.array(self.pmin) 534 | self.pmax = np.array(self.pmax) 535 | self.color = np.array(self.color) 536 | 537 | assert self.pmin.shape == (3,) 538 | assert self.pmax.shape == (3,) 539 | assert self.color.shape == (3,) 540 | 541 | assert np.all(self.pmax >= self.pmin) 542 | assert np.all((self.color >= 0) & (self.color <= 1)) 543 | 544 | self.color = (self.color * 255).astype(np.uint8) 545 | 546 | def visualize_pyviz3d( 547 | self, 548 | viz_root_dir: str, 549 | target_id: int | str | None = None, 550 | target_color: tuple[float, float, float] | None = None, 551 | pred_bbox: BoundingBox3D | None = None, 552 | anchor_bboxes: dict[str, BoundingBox3D | list[BoundingBox3D]] | None = None, 553 | segments: bool = False, 554 | 
seg_colors: dict[str, tuple[float, float, float]] = {}, 555 | bbox_highlights: dict[str, tuple[float, float, float]] = {}, 556 | seg_highlights: dict[str, tuple[float, float, float]] = {}, 557 | extra_bboxes: list[BBoxInfo] = [], 558 | ) -> str: 559 | assert not bbox_highlights or not seg_highlights 560 | viz = pyviz3d.visualizer.Visualizer() 561 | 562 | bbox_line_width = 0.02 563 | id_counter = 0 564 | 565 | if target_id is not None: 566 | target_id = str(target_id).strip() 567 | assert target_id.isdigit() 568 | 569 | if target_color is not None: 570 | target_color = np.array(target_color) 571 | assert target_color.shape == (3,) 572 | assert np.all((target_color >= 0) & (target_color <= 1)) 573 | target_color = (target_color * 255).astype(np.uint8) 574 | 575 | for insts in self.instance_map.values(): 576 | for inst in insts: 577 | # if inst.inst_id not in {"2", "7", "35", "37"}: 578 | # continue 579 | 580 | if inst.vertices.shape[1] != 6: 581 | continue 582 | 583 | point_positions = inst.vertices[:, :3] 584 | 585 | if seg_highlights: 586 | if inst.inst_id in seg_highlights: 587 | point_colors = ( 588 | np.array( 589 | [seg_highlights[inst.inst_id]] * inst.vertices.shape[0] 590 | ) 591 | * 255 592 | ).astype(np.uint8) 593 | else: 594 | if segments: 595 | point_colors = ( 596 | np.array([[225, 225, 225]] * inst.vertices.shape[0]) 597 | ).astype(np.uint8) 598 | else: 599 | point_colors = (inst.vertices[:, 3:] * 0.2 * 255).astype( 600 | np.uint8 601 | ) 602 | 603 | else: 604 | if segments: 605 | if inst.inst_id in seg_colors: 606 | color = seg_colors[inst.inst_id] 607 | color[0] *= 255 608 | color[1] *= 255 609 | color[2] *= 255 610 | else: 611 | # seed(inst.inst_id) 612 | color = [random() * 255, random() * 255, random() * 255] 613 | point_colors = ( 614 | np.array([color] * inst.vertices.shape[0]) 615 | ).astype(np.uint8) 616 | else: 617 | point_colors = (inst.vertices[:, 3:] * 255).astype(np.uint8) 618 | 619 | viz.add_points( 620 | f"inst-{inst.label}-{(id_counter 
    def visualize_open3d(
        self,
        target_id: int | str | None = None,
        pred_bbox: BoundingBox3D | None = None,
        anchor_bboxes: dict[str, BoundingBox3D | list[BoundingBox3D]] | None = None,
        segments: bool = False,
    ):
        """Render every instance's point cloud in an interactive Open3D window.

        target_id: optional instance id; normalized and validated (digits only).
            NOTE(review): target_id is never used after validation in this body,
            and pred_bbox / anchor_bboxes are accepted but never drawn — confirm
            whether bounding-box rendering was intended here.
        segments: if True, paint each instance one random solid color instead of
            its stored per-vertex colors.
        """
        # Imported lazily so the module can be used without open3d installed.
        import open3d as o3d

        if target_id is not None:
            target_id = str(target_id).strip()
            assert target_id.isdigit()

        geometries = []

        for insts in self.instance_map.values():
            for inst in insts:
                # if inst.inst_id not in {"2", "7", "35", "37"}:
                #     continue

                # Skip instances without per-vertex colors; vertices are expected
                # as (N, 6) = xyz + rgb (rgb presumably in [0, 1], given the *255
                # scaling below — TODO confirm).
                if inst.vertices.shape[1] != 6:
                    continue

                point_positions = inst.vertices[:, :3]
                if segments:
                    # One random color per instance, repeated for all its points.
                    # NOTE(review): `random()` is called as a bare function —
                    # assumes `from random import random` at module top; verify.
                    point_colors = (
                        np.array(
                            [[random() * 255, random() * 255, random() * 255]]
                            * inst.vertices.shape[0]
                        )
                    ).astype(np.uint8)
                else:
                    point_colors = (inst.vertices[:, 3:] * 255).astype(np.uint8)

                point_cloud = o3d.t.geometry.PointCloud(point_positions)
                point_cloud.point.colors = point_colors
                geometries.append(point_cloud.to_legacy())

        o3d.visualization.draw_geometries(geometries)
5 | 6 | import abc 7 | import inspect 8 | from typing import Callable 9 | 10 | import numpy.typing as npt 11 | 12 | from misc_utils import is_list_of_type 13 | from scannet_utils import ObjInstance 14 | 15 | # registry for predefined functions 16 | AVAILABLE_HANDLERS: dict[str, object] = {} 17 | 18 | 19 | class TargetInfo: 20 | best_instance: ObjInstance | None = None 21 | candidate_instances: set[ObjInstance] = set() 22 | anchor_instances: dict[str, ObjInstance | list[ObjInstance]] = {} 23 | csp_desc: str | None = None 24 | llm_used: bool = False 25 | 26 | @staticmethod 27 | def reset(): 28 | __class__.best_instance = None 29 | __class__.candidate_instances = set() 30 | __class__.anchor_instances = {} 31 | __class__.csp_desc = None 32 | __class__.llm_used = False 33 | 34 | 35 | class GlobalState: 36 | relevant_obj_map: dict[int, str] = {} 37 | relevant_obj_set: set[str] = set() 38 | relevant_obj_instances: dict[str, list[ObjInstance]] = {} 39 | room_center: npt.NDArray | None = (None,) # array of shape (3,) 40 | room_corners: tuple[ 41 | npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray 42 | ] # four corners of the room (scene) 43 | 44 | @staticmethod 45 | def get_cand_insts(label: str | list[str]) -> list[ObjInstance]: 46 | if isinstance(label, str): 47 | return sorted( 48 | __class__.relevant_obj_instances[label], 49 | key=lambda x: x.inst_id, 50 | ) 51 | elif is_list_of_type(label, str): 52 | cand_insts = [set(__class__.relevant_obj_instances[lbl]) for lbl in label] 53 | return sorted( 54 | list(set.union(*cand_insts)), 55 | key=lambda x: x.inst_id, 56 | ) 57 | 58 | raise SystemError(f"invalid argument: {label}") 59 | 60 | 61 | def register_handler(disabled: bool = False): 62 | def register_handler_empty(handler_class): 63 | return handler_class 64 | 65 | def register_handler_func(handler_class): 66 | assert hasattr(handler_class, "FUNC_NAME") 67 | assert hasattr(handler_class, "SIG_STR") 68 | assert hasattr(handler_class, "call_type_check") 69 | assert 
hasattr(handler_class, "call") 70 | 71 | func_name = handler_class.FUNC_NAME 72 | 73 | if isinstance(func_name, str): 74 | AVAILABLE_HANDLERS[func_name] = handler_class() 75 | elif isinstance(func_name, list): 76 | assert all([isinstance(x, str) for x in func_name]) 77 | for name in func_name: 78 | AVAILABLE_HANDLERS[name] = handler_class() 79 | else: 80 | raise SystemError(f"invalid handler class: {handler_class}!") 81 | 82 | return handler_class 83 | 84 | if disabled: 85 | return register_handler_empty 86 | return register_handler_func 87 | 88 | 89 | def get_predef_func_sigs(): 90 | funcs = [] 91 | for func_name, handler_class in AVAILABLE_HANDLERS.items(): 92 | sig_str = handler_class.SIG_STR 93 | doc_str = None 94 | if hasattr(handler_class, "DOC_STR"): 95 | doc_str = handler_class.DOC_STR 96 | 97 | funcs.append((func_name, sig_str, doc_str)) 98 | return sorted(funcs, key=lambda x: x[0]) 99 | 100 | 101 | def get_eval_scope(use_type_check_funcs) -> dict[str, Callable]: 102 | """get a dict with predefined functions. 
this can be used as the scope for exec()""" 103 | 104 | def build_type_check_func(instance, name): 105 | def f(*args, **kwargs): 106 | ret = instance.call_type_check(*args, **kwargs) 107 | if hasattr(ret, "apparent_name"): 108 | ret.set_apparent_name(name) 109 | return ret 110 | 111 | return f 112 | 113 | def build_func(instance, name): 114 | def f(*args, **kwargs): 115 | ret = instance.call(*args, **kwargs) 116 | if hasattr(ret, "apparent_name"): 117 | ret.set_apparent_name(name) 118 | return ret 119 | 120 | return f 121 | 122 | func_dict = {} 123 | for func_name, handler_instance in AVAILABLE_HANDLERS.items(): 124 | if use_type_check_funcs: 125 | func_dict[func_name] = build_type_check_func(handler_instance, func_name) 126 | else: 127 | func_dict[func_name] = build_func(handler_instance, func_name) 128 | return func_dict 129 | 130 | 131 | def set_relevant_obj_map(relevant_obj_map): 132 | """set the object dict used in type checking phase""" 133 | GlobalState.relevant_obj_map = relevant_obj_map 134 | GlobalState.relevant_obj_set = set(relevant_obj_map.values()) 135 | 136 | 137 | def set_instance_map(instance_map): 138 | """set the instance map used in grounding phase""" 139 | GlobalState.relevant_obj_instances = instance_map 140 | 141 | 142 | def set_room_center(room_center): 143 | """set the room center used in grounding phase""" 144 | GlobalState.room_center = room_center 145 | 146 | 147 | def set_target_info( 148 | best_instance: ObjInstance, 149 | candidate_instances: set[ObjInstance], 150 | anchor_instances: dict[str, ObjInstance | list[ObjInstance]], 151 | csp_desc: str, 152 | llm_used: bool, 153 | ): 154 | assert isinstance(best_instance, ObjInstance) 155 | assert isinstance(candidate_instances, (set, list)) 156 | assert isinstance(anchor_instances, dict) 157 | assert isinstance(csp_desc, str) 158 | assert isinstance(llm_used, bool) 159 | 160 | TargetInfo.best_instance = best_instance 161 | TargetInfo.candidate_instances = set(candidate_instances) 162 | 
TargetInfo.anchor_instances = anchor_instances 163 | TargetInfo.csp_desc = csp_desc 164 | TargetInfo.llm_used = llm_used 165 | -------------------------------------------------------------------------------- /score_funcs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import numpy as np 6 | 7 | from misc_utils import check, is_list_of_type, lookat_matrix 8 | from scannet_utils import ObjInstance 9 | from scope_env import GlobalState 10 | 11 | SCORE_FUNCTIONS: dict[str, type[ScoreFuncBase]] = {} 12 | 13 | 14 | def register_score_func(score_func_class: type[ScoreFuncBase]): 15 | assert hasattr(score_func_class, "NEED_ANCHOR") 16 | assert isinstance(score_func_class.NEED_ANCHOR, bool) 17 | 18 | assert hasattr(score_func_class, "NAME") 19 | 20 | if isinstance(score_func_class.NAME, str): 21 | assert score_func_class.NAME not in SCORE_FUNCTIONS 22 | SCORE_FUNCTIONS[score_func_class.NAME] = score_func_class 23 | 24 | elif is_list_of_type(score_func_class.NAME, str): 25 | for name in score_func_class.NAME: 26 | assert name not in SCORE_FUNCTIONS 27 | SCORE_FUNCTIONS[name] = score_func_class 28 | 29 | else: 30 | raise SystemError(f"invalid score_func NAME: {score_func_class.NAME}") 31 | 32 | 33 | class ScoreFuncBase(ABC): 34 | NAME: str | None = None 35 | NEED_ANCHOR: bool | None = None 36 | 37 | @staticmethod 38 | @abstractmethod 39 | def get_scores( 40 | candidate_instances: list[ObjInstance], 41 | anchor: ObjInstance | None = None, 42 | ) -> list[float]: 43 | """compute the score for each instance. 
anchor is optional.""" 44 | 45 | 46 | @register_score_func 47 | class ScoreDistance(ScoreFuncBase): 48 | NAME = "distance" 49 | NEED_ANCHOR = True 50 | 51 | @staticmethod 52 | def get_scores( 53 | candidate_instances: list[ObjInstance], 54 | anchor: ObjInstance | None = None, 55 | ) -> list[float]: 56 | return [ 57 | np.linalg.norm(x.bbox.center - anchor.bbox.center) 58 | for x in candidate_instances 59 | ] 60 | 61 | 62 | @register_score_func 63 | class ScoreSizeX(ScoreFuncBase): 64 | NAME = "size-x" 65 | NEED_ANCHOR = False 66 | 67 | @staticmethod 68 | def get_scores( 69 | candidate_instances: list[ObjInstance], 70 | anchor: ObjInstance | None = None, 71 | ) -> list[float]: 72 | return [x.bbox.size[0] for x in candidate_instances] 73 | 74 | 75 | @register_score_func 76 | class ScoreSizeY(ScoreFuncBase): 77 | NAME = "size-y" 78 | NEED_ANCHOR = False 79 | 80 | @staticmethod 81 | def get_scores( 82 | candidate_instances: list[ObjInstance], 83 | anchor: ObjInstance | None = None, 84 | ) -> list[float]: 85 | return [x.bbox.size[1] for x in candidate_instances] 86 | 87 | 88 | @register_score_func 89 | class ScoreSizeZ(ScoreFuncBase): 90 | NAME = "size-z" 91 | NEED_ANCHOR = False 92 | 93 | @staticmethod 94 | def get_scores( 95 | candidate_instances: list[ObjInstance], 96 | anchor: ObjInstance | None = None, 97 | ) -> list[float]: 98 | return [x.bbox.size[2] for x in candidate_instances] 99 | 100 | 101 | @register_score_func 102 | class ScoreMaxSize(ScoreFuncBase): 103 | NAME = "size" 104 | NEED_ANCHOR = False 105 | 106 | @staticmethod 107 | def get_scores( 108 | candidate_instances: list[ObjInstance], 109 | anchor: ObjInstance | None = None, 110 | ) -> list[float]: 111 | return [x.bbox.max_extent for x in candidate_instances] 112 | 113 | 114 | @register_score_func 115 | class ScorePositionZ(ScoreFuncBase): 116 | NAME = "position-z" 117 | NEED_ANCHOR = False 118 | 119 | @staticmethod 120 | def get_scores( 121 | candidate_instances: list[ObjInstance], 122 | anchor: 
ObjInstance | None = None, 123 | ) -> list[float]: 124 | return [x.bbox.center[2] for x in candidate_instances] 125 | 126 | 127 | @register_score_func 128 | class ScoreLeft(ScoreFuncBase): 129 | NAME = "left" 130 | NEED_ANCHOR = False 131 | 132 | @staticmethod 133 | def get_scores( 134 | candidate_instances: list[ObjInstance], 135 | anchor: ObjInstance | None = None, 136 | ) -> list[float]: 137 | cand_center = np.mean([x.bbox.center for x in candidate_instances], axis=0) 138 | # look at the center of all candidate instances from the room center 139 | world_to_local = lookat_matrix(eye=cand_center, target=GlobalState.room_center) 140 | 141 | return [ 142 | -(world_to_local @ np.hstack([x.bbox.center, 1]))[0] 143 | for x in candidate_instances 144 | ] 145 | 146 | 147 | @register_score_func 148 | class ScoreRight(ScoreFuncBase): 149 | NAME = "right" 150 | NEED_ANCHOR = False 151 | 152 | @staticmethod 153 | def get_scores( 154 | candidate_instances: list[ObjInstance], 155 | anchor: ObjInstance | None = None, 156 | ) -> list[float]: 157 | cand_center = np.mean([x.bbox.center for x in candidate_instances], axis=0) 158 | # look at the center of all candidate instances from the room center 159 | world_to_local = lookat_matrix(eye=cand_center, target=GlobalState.room_center) 160 | 161 | return [ 162 | (world_to_local @ np.hstack([x.bbox.center, 1]))[0] 163 | for x in candidate_instances 164 | ] 165 | 166 | 167 | @register_score_func 168 | class ScoreFront(ScoreFuncBase): 169 | NAME = "front" 170 | NEED_ANCHOR = False 171 | 172 | @staticmethod 173 | def get_scores( 174 | candidate_instances: list[ObjInstance], 175 | anchor: ObjInstance | None = None, 176 | ) -> list[float]: 177 | cand_center = np.mean([x.bbox.center for x in candidate_instances], axis=0) 178 | # look at the center of all candidate instances from the room center 179 | world_to_local = lookat_matrix(eye=cand_center, target=GlobalState.room_center) 180 | 181 | # the larger the z-coord value, the nearer the 
instance is to the room center, i.e. "to the front" 182 | return [ 183 | (world_to_local @ np.hstack([x.bbox.center, 1]))[2] 184 | for x in candidate_instances 185 | ] 186 | 187 | 188 | @register_score_func 189 | class ScoreCenter(ScoreFuncBase): 190 | NAME = ["distance-to-center", "distance-to-middle"] 191 | NEED_ANCHOR = False 192 | 193 | @staticmethod 194 | def get_scores( 195 | candidate_instances: list[ObjInstance], 196 | anchor: ObjInstance | None = None, 197 | ) -> list[float]: 198 | center = np.mean([x.bbox.center for x in candidate_instances], axis=0) 199 | return [np.linalg.norm(x.bbox.center - center) for x in candidate_instances] 200 | -------------------------------------------------------------------------------- /visualize_eval_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import subprocess 6 | from contextlib import redirect_stdout 7 | 8 | import numpy as np 9 | 10 | from scannet_utils import BoundingBox3D, ScanNetScene 11 | 12 | 13 | def parse_bbox(data: dict | list) -> BoundingBox3D | list[BoundingBox3D]: 14 | if isinstance(data, list): 15 | result = [] 16 | for item in data: 17 | assert isinstance(item, dict) 18 | assert "pmin" in item 19 | assert "pmax" in item 20 | result.append( 21 | BoundingBox3D( 22 | pmin0=np.array(item["pmin"]), 23 | pmax0=np.array(item["pmax"]), 24 | ) 25 | ) 26 | return result 27 | 28 | else: 29 | assert isinstance(data, dict) 30 | assert "pmin" in data 31 | assert "pmax" in data 32 | return BoundingBox3D(pmin0=np.array(data["pmin"]), pmax0=np.array(data["pmax"])) 33 | 34 | 35 | def format_query(query: str) -> str: 36 | query = re.sub(r"[,.:;?!]", "", query) 37 | query = re.sub(r" +", " ", query) 38 | query = query.lower().strip() 39 | return query 40 | 41 | 42 | def get_gt_query(formatted_query: str, gt_eval_results: dict) -> str | None: 43 | for eval_res in gt_eval_results: 44 | gt_query = 
format_query(eval_res["text"]) 45 | if gt_query == formatted_query: 46 | return eval_res 47 | 48 | 49 | def get_zsvg_query(formatted_query: str, zsvg_eval_results: dict) -> str | None: 50 | for eval_res in zsvg_eval_results: 51 | zsvg_query = format_query(eval_res["caption"]) 52 | if zsvg_query == formatted_query: 53 | return eval_res 54 | 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--file", type=str, required=True) 58 | parser.add_argument("--gt-file", type=str) 59 | parser.add_argument("--zsvg3d-file", type=str) 60 | parser.add_argument("--no-server", action="store_true") 61 | parser.add_argument("--viz-port", type=int, default=8889) 62 | parser.add_argument("--zsvg-fail", action="store_true") 63 | parser.add_argument("--csvg-fail", action="store_true") 64 | parser.add_argument("--anchor-not-required", action="store_true") 65 | parser.add_argument("--experiment-name", type=str) 66 | parser.add_argument("--distractor-required", action="store_true") 67 | parser.add_argument("--func", type=str, action="append") 68 | parser.add_argument("--no-func", type=str, action="append") 69 | args = parser.parse_args() 70 | 71 | eval_results_file = args.file 72 | assert os.path.isfile(eval_results_file) 73 | with open(eval_results_file) as f: 74 | eval_results = json.load(f) 75 | 76 | if args.gt_file: 77 | gt_eval_results_file = args.gt_file 78 | assert os.path.isfile(gt_eval_results_file) 79 | with open(gt_eval_results_file) as f: 80 | gt_eval_results = json.load(f) 81 | 82 | if args.zsvg3d_file: 83 | zsvg_eval_results_file = args.zsvg3d_file 84 | assert os.path.isfile(zsvg_eval_results_file) 85 | with open(zsvg_eval_results_file) as f: 86 | zsvg_eval_results = json.load(f) 87 | 88 | scannet_root = "./data" 89 | viz_dir_root = "./data/visualize_eval" 90 | assert os.path.isdir(scannet_root) 91 | if not os.path.isdir(viz_dir_root): 92 | os.system(f"mkdir -p {viz_dir_root}") 93 | 94 | zsvg_should_work = not args.zsvg_fail 95 | csvg_mask3d_should_work = 
not args.csvg_fail 96 | csvg_gt_should_work = True 97 | 98 | conf_str = f"csvg_{0 if args.csvg_fail else 1}" 99 | if args.zsvg3d_file: 100 | conf_str += f"_zsvg_{0 if args.zsvg_fail else 1}" 101 | if args.distractor_required: 102 | conf_str += "_with_distractor" 103 | if args.experiment_name: 104 | conf_str += f"_{args.experiment_name}" 105 | 106 | print() 107 | print(f"config str: {conf_str}") 108 | input("press enter to continue...") 109 | print() 110 | 111 | for i, eval_res in enumerate(eval_results): 112 | print(f"progress: {i} / {len(eval_results)}.") 113 | # if ( 114 | # eval_res["acc05"] is False 115 | # and eval_res["acc025"] is False 116 | # and "predicted_bbox" in eval_res 117 | # ): 118 | 119 | if "anchor_bboxes" not in eval_res or not eval_res["anchor_bboxes"]: 120 | if not args.anchor_not_required: 121 | continue 122 | 123 | if eval_res["acc05"] == csvg_mask3d_should_work: 124 | # assert eval_res["acc025"] is True 125 | 126 | query = format_query(eval_res["text"]) 127 | 128 | if args.gt_file: 129 | gt_eval_res = get_gt_query(query, gt_eval_results) 130 | assert gt_eval_res is not None 131 | 132 | if gt_eval_res["acc05"] != csvg_gt_should_work: 133 | continue 134 | 135 | if args.zsvg3d_file: 136 | zsvg_eval_res = get_zsvg_query(query, zsvg_eval_results) 137 | if zsvg_eval_res is None: 138 | continue 139 | 140 | if ( 141 | "pred_box" not in zsvg_eval_res 142 | or zsvg_eval_res["acc05"] != zsvg_should_work 143 | ): 144 | continue 145 | 146 | scene_id = eval_res["scene_id"] 147 | scene_path = os.path.join(scannet_root, "scans", scene_id) 148 | 149 | # print(json.dumps(gt_eval_res, sort_keys=True, indent=4)) 150 | # print(json.dumps(zsvg_eval_res, sort_keys=True, indent=4)) 151 | 152 | scene = ScanNetScene( 153 | scene_path, 154 | add_room_center=False, 155 | add_room_corners=False, 156 | cache_root="./data/instance_cache", 157 | ) 158 | 159 | if args.distractor_required: 160 | target_id = eval_res["target_id"] 161 | if not scene.get_distractors(target_id): 
162 | continue 163 | 164 | if args.func: 165 | if not any(func in eval_res["program"] for func in args.func): 166 | continue 167 | 168 | if args.no_func: 169 | if any(func in eval_res["program"] for func in args.no_func): 170 | continue 171 | 172 | print() 173 | print("=" * 30) 174 | print(f"scene id : {scene_id}") 175 | print(f"target id : {eval_res['target_id']}") 176 | print(f"query : {eval_res['text']}") 177 | print(f"target label: {eval_res['target_label']}") 178 | print() 179 | print("program :") 180 | print(">" * 10) 181 | print(eval_res["program"]) 182 | print("<" * 10) 183 | print() 184 | print("csp :") 185 | print(">" * 10) 186 | print(eval_res["csp_desc"]) 187 | print("<" * 10) 188 | if args.gt_file: 189 | print() 190 | print("csvg gt program") 191 | print(">" * 10) 192 | print(gt_eval_res["program"]) 193 | print("<" * 10) 194 | if args.zsvg3d_file: 195 | print() 196 | print("zsvg3d program") 197 | print(">" * 10) 198 | print(zsvg_eval_res["program"]) 199 | print("<" * 10) 200 | print() 201 | print(f"query : {eval_res['text']}") 202 | print() 203 | print("=" * 30) 204 | print() 205 | 206 | if args.zsvg3d_file: 207 | zsvg_bbox = BoundingBox3D( 208 | pcenter0=zsvg_eval_res["pred_box"][:3], 209 | psize0=zsvg_eval_res["pred_box"][3:], 210 | ) 211 | 212 | gt_color = [0.4, 0.7, 1.0] 213 | pred_color_correct = [0.3, 1.0, 0.3] 214 | pred_color_wrong = [1.0, 0.2, 0.2] 215 | anchor_color = [1.0, 0.7, 0.0] 216 | distractor_color = [1.0, 0.3, 0.7] 217 | 218 | extra_bboxes = [] 219 | 220 | if args.gt_file: 221 | csvg_gt_bbox = parse_bbox(gt_eval_res["predicted_bbox"]) 222 | extra_bboxes.append( 223 | ScanNetScene.BBoxInfo( 224 | pmin=csvg_gt_bbox.pmin, 225 | pmax=csvg_gt_bbox.pmax, 226 | color=pred_color_correct 227 | if csvg_gt_should_work 228 | else pred_color_wrong, 229 | name="csvg-gt-pred", 230 | ), 231 | ) 232 | 233 | if args.zsvg3d_file: 234 | extra_bboxes.append( 235 | ScanNetScene.BBoxInfo( 236 | pmin=zsvg_bbox.pmin, 237 | pmax=zsvg_bbox.pmax, 238 | 
color=pred_color_correct if zsvg_should_work else pred_color_wrong, 239 | name="zsvg-pred", 240 | ), 241 | ) 242 | 243 | if args.gt_file: 244 | for name, bbox in gt_eval_res["anchor_bboxes"].items(): 245 | bbox = parse_bbox(bbox) 246 | extra_bboxes.append( 247 | ScanNetScene.BBoxInfo( 248 | pmin=bbox.pmin, 249 | pmax=bbox.pmax, 250 | color=anchor_color, 251 | name=f"csvg-gt-anchor-{name}", 252 | ) 253 | ) 254 | 255 | pred_bbox = parse_bbox(eval_res["predicted_bbox"]) 256 | extra_bboxes.append( 257 | ScanNetScene.BBoxInfo( 258 | pmin=pred_bbox.pmin, 259 | pmax=pred_bbox.pmax, 260 | color=pred_color_correct 261 | if csvg_mask3d_should_work 262 | else pred_color_wrong, 263 | name="csvg-mask3d-pred", 264 | ) 265 | ) 266 | 267 | anchor_bboxes = { 268 | name: parse_bbox(bbox) for name, bbox in eval_res["anchor_bboxes"].items() 269 | } 270 | for name, bbox in anchor_bboxes.items(): 271 | assert isinstance(bbox, BoundingBox3D) 272 | extra_bboxes.append( 273 | ScanNetScene.BBoxInfo( 274 | pmin=bbox.pmin, 275 | pmax=bbox.pmax, 276 | color=anchor_color, 277 | name="csvg-mask3d-anchor", 278 | ) 279 | ) 280 | 281 | if args.distractor_required: 282 | target_id = eval_res["target_id"] 283 | distractors = scene.get_distractors(target_id) 284 | for inst in distractors: 285 | extra_bboxes.append( 286 | ScanNetScene.BBoxInfo( 287 | pmin=inst.bbox.pmin, 288 | pmax=inst.bbox.pmax, 289 | color=distractor_color, 290 | name="distractor", 291 | ) 292 | ) 293 | 294 | viz_dir = scene.visualize_pyviz3d( 295 | viz_dir_root, 296 | target_id=eval_res["target_id"], 297 | target_color=gt_color, 298 | extra_bboxes=extra_bboxes, 299 | ) 300 | 301 | if not args.no_server: 302 | try: 303 | proc = subprocess.Popen( 304 | ["python", "-m", "http.server", str(args.viz_port), "-d", viz_dir], 305 | stdout=subprocess.DEVNULL, 306 | stderr=subprocess.DEVNULL, 307 | ) 308 | print(f"serving visualization at http://0.0.0.0:{args.viz_port}/") 309 | print() 310 | cmd = input("press enter to continue...\n") 311 | 
print() 312 | 313 | if cmd.lower().strip() == "store": 314 | file_path = f"./figure_data_{conf_str}.txt" 315 | with open(file_path, "a") as f: 316 | with redirect_stdout(f): 317 | print() 318 | print("=" * 30) 319 | print(f"scene id : {scene_id}") 320 | print(f"target id : {eval_res['target_id']}") 321 | print(f"query : {eval_res['text']}") 322 | print(f"target label: {eval_res['target_label']}") 323 | print() 324 | print("program :") 325 | print(">" * 10) 326 | print(eval_res["program"]) 327 | print("<" * 10) 328 | print() 329 | print("csp :") 330 | print(">" * 10) 331 | print(eval_res["csp_desc"]) 332 | print("<" * 10) 333 | if args.gt_file: 334 | print() 335 | print("csvg gt program") 336 | print(">" * 10) 337 | print(gt_eval_res["program"]) 338 | print("<" * 10) 339 | if args.zsvg3d_file: 340 | print() 341 | print("zsvg3d program") 342 | print(">" * 10) 343 | print(zsvg_eval_res["program"]) 344 | print("<" * 10) 345 | print() 346 | print("=" * 30) 347 | print() 348 | print() 349 | print(f"query info stored to {file_path}") 350 | input("press enter to continue...") 351 | print() 352 | 353 | except Exception: 354 | pass 355 | finally: 356 | proc.kill() 357 | -------------------------------------------------------------------------------- /visualize_segs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from scannet_utils import ScanNetScene 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--scene", type=str, required=True) 8 | parser.add_argument("--no-server", action="store_true") 9 | parser.add_argument("--viz-port", type=int, default=8889) 10 | parser.add_argument("--segs", action="store_true") 11 | 12 | 13 | def parse_color(arg): 14 | if arg.count(":") != 3: 15 | raise argparse.ArgumentTypeError( 16 | "invalid bbox color mapping: missing component." 
17 | ) 18 | 19 | inst_id, r, g, b = arg.split(":") 20 | inst_id, r, g, b = inst_id.strip(), r.strip(), g.strip(), b.strip() 21 | 22 | if not inst_id or not r or not g or not b: 23 | raise argparse.ArgumentTypeError("invalid bbox color mapping: empty component.") 24 | 25 | if ( 26 | (not r.replace(".", "", 1).isdigit()) 27 | or (not g.replace(".", "", 1).isdigit()) 28 | or (not b.replace(".", "", 1).isdigit()) 29 | ): 30 | raise argparse.ArgumentTypeError("invalid bbox color mapping: rgb not float.") 31 | 32 | try: 33 | result = [inst_id, [float(r), float(g), float(b)]] 34 | except Exception: 35 | raise argparse.ArgumentTypeError("invalid bbox color mapping: i don't know.") 36 | 37 | if ( 38 | result[1][0] < 0 39 | or result[1][0] > 1 40 | or result[1][1] < 0 41 | or result[1][1] > 1 42 | or result[1][2] < 0 43 | or result[1][2] > 1 44 | ): 45 | raise argparse.ArgumentTypeError( 46 | "invalid bbox color mapping: invalid rgb value." 47 | ) 48 | 49 | return result 50 | 51 | 52 | parser.add_argument( 53 | "-bh", 54 | "--bbox-highlight", 55 | metavar="inst_id:r,g,b", 56 | type=parse_color, 57 | action="append", 58 | ) 59 | parser.add_argument( 60 | "-sh", 61 | "--seg-highlight", 62 | metavar="inst_id:r,g,b", 63 | type=parse_color, 64 | action="append", 65 | ) 66 | parser.add_argument( 67 | "-c", 68 | "--color", 69 | metavar="inst_id:r,g,b", 70 | type=parse_color, 71 | action="append", 72 | ) 73 | 74 | group = parser.add_mutually_exclusive_group() 75 | group.add_argument("--mask3d", action="store_true") 76 | group.add_argument("--maskcluster", action="store_true") 77 | 78 | args = parser.parse_args() 79 | 80 | scene_id = args.scene if "scene" in args.scene else "scene" + args.scene 81 | scannet_root = "./data" 82 | viz_dir_root = "./visualize" 83 | mask3d_pred_root = "./data/eval_output/mask3d_val" if args.mask3d else None 84 | maskcluster_pred_root = "./data/eval_output/maskcluster" if args.maskcluster else None 85 | 86 | seg_colors = {} if args.color is None else 
{c[0]: c[1] for c in args.color} 87 | bbox_highlights = ( 88 | {} if args.bbox_highlight is None else {bh[0]: bh[1] for bh in args.bbox_highlight} 89 | ) 90 | seg_highlights = ( 91 | {} if args.seg_highlight is None else {sh[0]: sh[1] for sh in args.seg_highlight} 92 | ) 93 | 94 | scene_path = os.path.join(scannet_root, "scans", scene_id) 95 | scene = ScanNetScene( 96 | scene_path, 97 | mask3d_pred_path=mask3d_pred_root, 98 | maskcluster_pred_path=maskcluster_pred_root, 99 | add_room_center=False, 100 | add_room_corners=False, 101 | ) 102 | viz_dir = scene.visualize_pyviz3d( 103 | viz_dir_root, 104 | segments=args.segs, 105 | seg_colors=seg_colors, 106 | bbox_highlights=bbox_highlights, 107 | seg_highlights=seg_highlights, 108 | ) 109 | 110 | if not args.no_server: 111 | os.system(f"python -m http.server {args.viz_port} -d {viz_dir}") 112 | --------------------------------------------------------------------------------