├── arc_agi ├── __init__.py ├── utils.py ├── solve.py ├── scoring.py ├── config.py ├── types.py ├── io.py ├── sandbox.py ├── solve_parallel_coding.py ├── llm.py ├── solve_coding.py └── prompts.py ├── .gitignore ├── arcagi1.png ├── arcagi2.png ├── arc2captured.png ├── officialtable_boxed.png ├── requirements.txt ├── LICENSE.txt ├── README.md ├── main.py └── data ├── arc-prize-2024 └── sample_submission.json └── arc-prize-2025 └── sample_submission.json /arc_agi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .env 3 | .DS_Store 4 | output/ -------------------------------------------------------------------------------- /arcagi1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poetiq-ai/poetiq-arc-agi-solver/HEAD/arcagi1.png -------------------------------------------------------------------------------- /arcagi2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poetiq-ai/poetiq-arc-agi-solver/HEAD/arcagi2.png -------------------------------------------------------------------------------- /arc2captured.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poetiq-ai/poetiq-arc-agi-solver/HEAD/arc2captured.png -------------------------------------------------------------------------------- /officialtable_boxed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/poetiq-ai/poetiq-arc-agi-solver/HEAD/officialtable_boxed.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asynciolimiter==1.2.0 2 | litellm==1.78.2 3 | numpy==2.3.4 4 | python-dotenv==1.1.1 5 | scipy==1.16.3 6 | -------------------------------------------------------------------------------- /arc_agi/utils.py: -------------------------------------------------------------------------------- 1 | from arc_agi.types import RunResult 2 | 3 | 4 | def canonical_test_key(results: list[RunResult]) -> str: 5 | return str([r["output"] for r in results]) 6 | -------------------------------------------------------------------------------- /arc_agi/solve.py: -------------------------------------------------------------------------------- 1 | from arc_agi.config import CONFIG_LIST 2 | from arc_agi.solve_parallel_coding import solve_parallel_coding 3 | from arc_agi.types import ARCAGIResult 4 | 5 | 6 | async def solve( 7 | train_in: list[list[list[int]]], 8 | train_out: list[list[list[int]]], 9 | test_in: list[list[list[int]]], 10 | problem_id: str | None = None, 11 | ) -> list[ARCAGIResult]: 12 | result = await solve_parallel_coding( 13 | train_in=train_in, 14 | train_out=train_out, 15 | test_in=test_in, 16 | expert_configs=[cfg.copy() for cfg in CONFIG_LIST], 17 | problem_id=problem_id, 18 | ) 19 | 20 | return result 21 | -------------------------------------------------------------------------------- /arc_agi/scoring.py: -------------------------------------------------------------------------------- 1 | def grids_equal(a, b) -> bool: 2 | """Strict structural equality for ARC grids (list[list[int]]).""" 3 | 
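    # Nested-list `==` compares element by element and recurses into sublists,
    # so both the grid shape and every cell value must match exactly.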
return a == b 4 | 5 | 6 | def score_task(kaggle_preds: list[dict], gt_outputs: list) -> float: 7 | """ 8 | Fraction of test inputs correct for a task. 9 | Correct if attempt_1 == GT or attempt_2 == GT for each test input. 10 | """ 11 | if not gt_outputs: 12 | return 0.0 13 | correct = 0 14 | for i, gt in enumerate(gt_outputs): 15 | if i >= len(kaggle_preds): 16 | continue 17 | pack = kaggle_preds[i] or {} 18 | a1 = pack.get("attempt_1") 19 | a2 = pack.get("attempt_2") 20 | if (a1 is not None and grids_equal(a1, gt)) or ( 21 | a2 is not None and grids_equal(a2, gt) 22 | ): 23 | correct += 1 24 | return correct / max(len(gt_outputs), 1) 25 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2025 Poetiq, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /arc_agi/config.py: -------------------------------------------------------------------------------- 1 | from arc_agi.prompts import FEEDBACK_PROMPT, SOLVER_PROMPT_1, SOLVER_PROMPT_2, SOLVER_PROMPT_3 2 | from arc_agi.types import ExpertConfig 3 | 4 | # To run Poetiq(Gemini-3-a): 5 | NUM_EXPERTS = 1 6 | # To run Poetiq(Gemini-3-b): 7 | # NUM_EXPERTS = 2 8 | # To run Poetiq(Gemini-3-c): 9 | # NUM_EXPERTS = 8 10 | 11 | CONFIG_LIST: list[ExpertConfig] = [ 12 | { 13 | # Prompts 14 | 'solver_prompt': SOLVER_PROMPT_1, 15 | 'feedback_prompt': FEEDBACK_PROMPT, 16 | # LLM parameters 17 | 'llm_id': 'gemini/gemini-3-pro-preview', 18 | 'solver_temperature': 1.0, 19 | 'request_timeout': 60 * 60, # in seconds 20 | 'max_total_timeouts': 15, # per problem per solver 21 | 'max_total_time': None, # per problem per solver 22 | 'per_iteration_retries': 2, 23 | # Solver parameters 24 | 'num_experts': 1, 25 | 'max_iterations': 10, 26 | 'max_solutions': 5, 27 | 'selection_probability': 1.0, 28 | 'seed': 0, 29 | 'shuffle_examples': True, 30 | 'improving_order': True, 31 | 'return_best_result': True, 32 | # Voting parameters 33 | 'use_new_voting': True, 34 | 'count_failed_matches': True, 35 | 'iters_tiebreak': False, 36 | 'low_to_high_iters': False, 37 | }, 38 | ] * NUM_EXPERTS 39 | -------------------------------------------------------------------------------- /arc_agi/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, TypedDict 2 | 3 | Models = Literal[ 4 | "groq/openai/gpt-oss-120b", 5 | "openai/gpt-5", 6 | "openai/gpt-5.1", 7 | "xai/grok-4-fast", 8 | "xai/grok-4", 9 | "anthropic/claude-sonnet-4-5", 10 | "anthropic/claude-haiku-4-5", 11 | "gemini/gemini-2.5-pro", 12 | "gemini/gemini-3-pro-preview", 13 | ] 14 | 15 | 16 | class ExpertConfig(TypedDict): 17 | use_new_voting: bool 18 | count_failed_matches: bool 19 | iters_tiebreak: bool 20 | low_to_high_iters: bool 21 | solver_prompt: str 22 | feedback_prompt: str 23 | llm_id: Models 24 | max_iterations: int 25 | solver_temperature: float 26 | max_solutions: int 27 | selection_probability: float 28 | seed: int 29 | shuffle_examples: bool 30 | improving_order: bool 31 | return_best_result: bool 32 | request_timeout: Optional[int] 33 | max_total_timeouts: Optional[int] 34 | max_total_time: Optional[int] 35 | num_experts: int 36 | per_iteration_retries: int 37 | 38 | 39 | MessageRole = Literal["user", "assistant", "system"] 40 | 41 | 42 | class Message(TypedDict): 43 | role: MessageRole 44 | content: str 45 | 46 | 47 | class RunResult(TypedDict): 48 | success: bool 49 | output: str 50 | soft_score: float 51 | error: Optional[str] 52 | code: str 53 | 54 | 55 | class ARCAGIResult(TypedDict): 56 | train_results: list[RunResult] 57 | results: list[RunResult] # test results 58 | iteration: int 59 | prompt_tokens: Optional[int] 60 | completion_tokens: Optional[int] 61 | 62 | 63 | class ARCAGISolution(TypedDict): 64 | code: str 65 | feedback: str 66 | score: float 67 | -------------------------------------------------------------------------------- /arc_agi/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, List 3 | 4 | from arc_agi.types import ARCAGIResult 5 | 6 | 7 | def _coerce_grid(x: Any) -> list: 8 | # numpy -> list 9 | try: 10 | import numpy as _np 11 | 12 | if isinstance(x, _np.ndarray): 13 | return x.tolist() 14 | except 
Exception: 15 | pass 16 | # stringified JSON -> list 17 | if isinstance(x, str): 18 | s = x.strip() 19 | if s and (s[0] == "[" or s[0] == "{"): 20 | try: 21 | parsed = json.loads(s) 22 | return parsed 23 | except Exception: 24 | # not JSON; fall through 25 | return [] 26 | else: 27 | return [] 28 | # already list-like? 29 | if isinstance(x, list): 30 | return x 31 | return [] 32 | 33 | 34 | def build_kaggle_two_attempts(results: list[ARCAGIResult], test_in: List[List[List[int]]]): 35 | """ 36 | Returns: List[{"attempt_1": grid, "attempt_2": grid}] with len == len(test_in). 37 | """ 38 | num_tests = len(test_in) 39 | out = [] 40 | 41 | for j in range(num_tests): 42 | attempts: List[list] = [] 43 | 44 | # Sweep iterations in order; collect up to 2 successful outputs for test j 45 | for ar in results: 46 | tr = ar.get("results", []) 47 | if j < len(tr): 48 | rr = tr[j] 49 | grid = _coerce_grid(rr.get("output", [])) 50 | if grid != []: 51 | attempts.append(grid) 52 | if len(attempts) == 2: 53 | break 54 | 55 | # Pad with empty arrays if fewer than two attempts available 56 | while len(attempts) < 2: 57 | attempts.append([]) 58 | 59 | out.append({"attempt_1": attempts[0], "attempt_2": attempts[1]}) 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /arc_agi/sandbox.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import sys 5 | import tempfile 6 | import textwrap 7 | 8 | async def run( 9 | code: str, input_grid: list[list[int]], timeout_s: float = 1.5 10 | ) -> tuple[bool, str]: 11 | """Run user code in a subprocess asynchronously, returning (ok, result or error).""" 12 | script = _build_script(code) 13 | 14 | with tempfile.TemporaryDirectory() as td: 15 | path = os.path.join(td, "u.py") 16 | with open(path, "w", encoding="utf-8") as f: 17 | f.write(textwrap.dedent(script)) 18 | 19 | proc = await asyncio.create_subprocess_exec( 20 | sys.executable, 21 | path, 22 | stdin=asyncio.subprocess.PIPE, 23 | stdout=asyncio.subprocess.PIPE, 24 | stderr=asyncio.subprocess.PIPE, 25 | cwd=td, 26 | env={"PYTHONHASHSEED": "0"}, 27 | ) 28 | 29 | try: 30 | stdout, stderr = await asyncio.wait_for( 31 | proc.communicate(input=json.dumps({"input": input_grid}).encode()), 32 | timeout=timeout_s, 33 | ) 34 | except asyncio.TimeoutError: 35 | try: 36 | proc.kill() 37 | except ProcessLookupError: 38 | pass 39 | return False, "timeout" 40 | 41 | if proc.returncode != 0: 42 | return False, (stderr.decode() or stdout.decode()).strip() 43 | 44 | try: 45 | payload = json.loads(stdout.decode()) 46 | return bool(payload.get("ok")), json.dumps(payload.get("result")) 47 | except Exception as e: 48 | return False, f"bad-json: {e}" 49 | 50 | 51 | def _build_script(code: str) -> str: 52 | return f""" 53 | # generated file 54 | {code} 55 | if __name__ == '__main__': 56 | import json 57 | import numpy as np 58 | import scipy 59 | from sys import stdin 60 | data = json.load(stdin) 61 | res = transform(np.array(data['input'])) 62 | print(json.dumps({{"ok": True, 'result': res.tolist()}})) 63 | """ 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Poetiq](https://poetiq.ai): SOTA Reasoning on ARC-AGI 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![Python 
3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) 5 | [![ARC-AGI](https://img.shields.io/badge/Task-ARC--AGI-red)](https://arcprize.org/) 6 | 7 | This repository allows reproduction of **Poetiq's** record-breaking submission to the ARC-AGI-1 and ARC-AGI-2 benchmarks. 8 | 9 | Full analysis is available in our launch post, **[Traversing the Frontier of Superintelligence](https://poetiq.ai/posts/arcagi_announcement/)**. 10 | 11 | Our method is now on top of the official leaderboard. More information is in our follow-up post, **[Poetiq Shatters ARC-AGI-2 State of the Art at Half the Cost](https://poetiq.ai/posts/arcagi_verified/)**. 12 | 13 | --- 14 | 15 | ## 📊 Public Eval Results 16 | You can recreate the Gemini 3 points from these charts using this repo. 17 | 18 |
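The Gemini 3 variants on these charts map to the `NUM_EXPERTS` constant in `arc_agi/config.py` (see the comments in that file). For example, a one-line change reproduces the two-expert variant:

```python
# arc_agi/config.py (excerpt)
NUM_EXPERTS = 2  # Poetiq(Gemini-3-b); use 1 for Gemini-3-a or 8 for Gemini-3-c
```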

19 | 20 | 21 |

22 | 23 | ## 📊 Official Private Eval Results 24 | These are our results on the official leaderboard from ARC Prize, but those problems are kept private. 25 | 26 |

27 | 28 |

29 |

30 | 31 |

32 | 33 | ## 🛠️ Usage 34 | 35 | ### Prerequisites 36 | - Python 3.11+ 37 | - API Keys for the models you wish to test (Gemini, OpenAI, etc.) 38 | 39 | ### Quick Start 40 | 41 | 1. Setup the environment: 42 | ```bash 43 | python -m venv .venv 44 | source .venv/bin/activate 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | 2. Create a .env file in the root directory. You must include keys for the models you intend to run. 49 | 50 | ```bash 51 | GEMINI_API_KEY=... 52 | OPENAI_API_KEY=... 53 | ``` 54 | 55 | 3. Modify the constants in main.py to set the problem set, number of problems, etc. Then run the script: 56 | 57 | ```bash 58 | python main.py 59 | ``` 60 | 61 | 4. By default, the code runs the Poetiq 3 config described in the blog post. You can uncomment other ones or modify the config in config.py 62 | 63 | ## 📄 Contact 64 | If you use this code or these results in your research, please cite our blog post: 65 | 66 | Poetiq Team. (2025). *Traversing the Frontier of Superintelligence*. Poetiq AI. [https://poetiq.ai/posts/arcagi_announcement/](https://poetiq.ai/posts/arcagi_announcement/) 67 | 68 | For questions or to discuss the future of reasoning, reach out to us at poetiq@poetiq.ai. 69 | 70 | [![X (formerly Twitter)](https://img.shields.io/badge/X-000000?style=for-the-badge&logo=x&logoColor=white)](https://x.com/poetiq_ai) 71 | [![LinkedIn](https://img.shields.io/badge/LinkedIn-0077B5?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/company/poetiq/) 72 | [![Bluesky](https://img.shields.io/badge/Bluesky-0285FF?style=for-the-badge&logo=Bluesky&logoColor=white)](https://bsky.app/profile/poetiq-ai.bsky.social) 73 | -------------------------------------------------------------------------------- /arc_agi/solve_parallel_coding.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import numpy as np 4 | 5 | from arc_agi.solve_coding import solve_coding 6 | from arc_agi.types import ARCAGIResult, ExpertConfig 7 | from arc_agi.utils import canonical_test_key 8 | 9 | 10 | async def solve_parallel_coding( 11 | *, 12 | train_in: list[list[list[int]]], 13 | train_out: list[list[list[int]]], 14 | test_in: list[list[list[int]]], 15 | expert_configs: list[ExpertConfig], 16 | problem_id: str | None = None, 17 | ) -> list[ARCAGIResult]: 18 | """ 19 | Run multiple coding experts in parallel, group by identical test outputs, then rank. 20 | """ 21 | #assert len(expert_configs) > 1, "Need at least two expert configs." 22 | 23 | use_new_voting = expert_configs[0]["use_new_voting"] 24 | count_failed_matches = expert_configs[0]["count_failed_matches"] 25 | iters_tiebreak = expert_configs[0]["iters_tiebreak"] 26 | low_to_high_iters = expert_configs[0]["low_to_high_iters"] 27 | 28 | for it, cfg in enumerate(expert_configs): 29 | # Ensure each config gets a separate sequence of seeds. The code_solver 30 | # adds the current iteration to the seed at each iteration, so this 31 | # guarantees that each iteration of each code_solver gets a different 32 | # seed, assuming the configs all start with an identical seed. 
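        # For example, with the config.py defaults (seed=0, max_iterations=10),
        # expert 0 uses iteration seeds 0-9, expert 1 uses 10-19, and so on.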
33 | cfg["seed"] += it * cfg["max_iterations"] 34 | 35 | # Solve concurrently 36 | tasks = [ 37 | asyncio.create_task( 38 | solve_coding( 39 | train_in=train_in, 40 | train_out=train_out, 41 | test_in=test_in, 42 | config=cfg, 43 | problem_id=problem_id, 44 | ) 45 | ) 46 | for cfg in expert_configs 47 | ] 48 | results: list[ARCAGIResult] = await asyncio.gather(*tasks) 49 | 50 | # Buckets 51 | candidate_buckets: dict[str, list[ARCAGIResult]] = {} 52 | failure_buckets: dict[str, list[ARCAGIResult]] = {} 53 | 54 | for res in results: 55 | is_passer = all(rr.get("success", False) for rr in res.get("train_results", [])) 56 | key = canonical_test_key(res.get("results", [])) 57 | if is_passer: 58 | candidate_buckets.setdefault(key, []).append(res) 59 | else: 60 | failure_buckets.setdefault(key, []).append(res) 61 | 62 | if use_new_voting: 63 | # Optionally merge failures into passers if outputs match 64 | if count_failed_matches: 65 | for k in list(failure_buckets.keys()): 66 | if k in candidate_buckets: 67 | candidate_buckets[k].extend(failure_buckets[k]) 68 | del failure_buckets[k] 69 | 70 | # ---- Passers: sort by vote count desc; diversity-first ---- 71 | passer_groups: list[list[ARCAGIResult]] = list(candidate_buckets.values()) 72 | 73 | if iters_tiebreak: 74 | # Put the lowest (if low_to_high_iters) iterations in position 0 of each sublist. 75 | passer_groups = [ 76 | sorted(ps, key=lambda x: x['iteration'], reverse=not low_to_high_iters) for ps in passer_groups 77 | ] 78 | # Sort the list by min iterations, highest to lowest, so after the last sort below it is lowest to highest. 79 | passer_groups = sorted(passer_groups, key=lambda x: x[0]['iteration'], reverse=low_to_high_iters) 80 | 81 | # Sort passers by how many votes they have. 82 | passer_groups = sorted(passer_groups, key=len, reverse=True) 83 | 84 | ordered: list[ARCAGIResult] = [] 85 | # one per group for diversity 86 | ordered.extend([grp[0] for grp in passer_groups if grp]) 87 | 88 | # ---- Failures: grouped + ranked ---- 89 | # within each failure group, best first by mean soft_score desc 90 | for fs in failure_buckets.values(): 91 | fs.sort(key=_mean_soft, reverse=True) 92 | 93 | failure_groups: list[list[ARCAGIResult]] = list(failure_buckets.values()) 94 | # Sort groups: votes (desc), tie-break by best member's mean soft_score (desc) 95 | failure_groups.sort( 96 | key=lambda fs: (len(fs), _mean_soft(fs[0]) if fs else 0.0), 97 | reverse=True, 98 | ) 99 | 100 | # diversity-first over failure groups 101 | ordered.extend([fs[0] for fs in failure_groups if fs]) 102 | # remaining passer members 103 | ordered.extend([m for grp in passer_groups for m in grp[1:]]) 104 | # remaining failure members 105 | ordered.extend([m for fs in failure_groups for m in fs[1:]]) 106 | 107 | return ordered 108 | 109 | else: 110 | # ---- Old mode ---- 111 | # Passers by vote desc 112 | passer_groups: list[list[ARCAGIResult]] = sorted( 113 | candidate_buckets.values(), key=len, reverse=True 114 | ) 115 | 116 | firsts = [grp[0] for grp in passer_groups if grp] 117 | 118 | # Failures are flat, sorted by mean soft_score desc 119 | failed_flat: list[ARCAGIResult] = [ 120 | r for fs in failure_buckets.values() for r in fs 121 | ] 122 | failed_sorted = sorted(failed_flat, key=_mean_soft, reverse=True) 123 | 124 | rest = [m for grp in passer_groups for m in grp[1:]] 125 | 126 | return firsts + failed_sorted + rest 127 | 128 | 129 | def _mean_soft(res: ARCAGIResult) -> float: 130 | trs = res.get("train_results", []) 131 | if not trs: 132 | return 0.0 133 | 
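    # Mean of the per-example soft scores (fraction of matching cells) across the train results.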
return float(np.mean([rr.get("soft_score", 0.0) for rr in trs])) 134 | -------------------------------------------------------------------------------- /arc_agi/llm.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any 3 | 4 | import litellm 5 | from asynciolimiter import Limiter 6 | from litellm import acompletion 7 | from litellm import exceptions as litellm_exceptions 8 | 9 | from arc_agi.types import Models 10 | 11 | # Silence unnecessary litellm logs. 12 | litellm.suppress_debug_info = True 13 | 14 | RETRIES = 3 15 | RETRY_DELAY_SEC = 5 16 | 17 | limiters: dict[Models, Limiter] = { 18 | "groq/openai/gpt-oss-120b": Limiter(1.0), 19 | "openai/gpt-5": Limiter(1.0), 20 | "openai/gpt-5.1": Limiter(1.0), 21 | "xai/grok-4-fast": Limiter(1.0), 22 | "xai/grok-4": Limiter(1.0), 23 | "anthropic/claude-sonnet-4-5": Limiter(1.0), 24 | "anthropic/claude-haiku-4-5": Limiter(1.0), 25 | "gemini/gemini-2.5-pro": Limiter(2.0), 26 | "gemini/gemini-3-pro-preview": Limiter(1.0), 27 | } 28 | 29 | props: dict[Models, dict] = { 30 | "groq/openai/gpt-oss-120b": {}, 31 | "openai/gpt-5": {"reasoning_effort": "high"}, 32 | "openai/gpt-5.1": {"reasoning_effort": "high"}, 33 | "xai/grok-4-fast": {}, 34 | "xai/grok-4": {}, 35 | "anthropic/claude-sonnet-4-5": {"thinking": {"type": "enabled", "budget_tokens": 32_000}}, 36 | "anthropic/claude-haiku-4-5": {"thinking": {"type": "enabled", "budget_tokens": 32_000}}, 37 | "gemini/gemini-2.5-pro": {"thinking": {"type": "enabled", "budget_tokens": 16_000}}, 38 | "gemini/gemini-3-pro-preview": {}, 39 | } 40 | 41 | 42 | async def llm( 43 | model: Models, 44 | message: str, 45 | temperature, 46 | request_timeout: int | None, 47 | max_remaining_time: float | None, 48 | max_remaining_timeouts: int | None, 49 | problem_id: str | None = None, 50 | retries: int = RETRIES, 51 | ) -> tuple[str, float, float | None, int | None, int, int]: 52 | attempt = 1 53 | while attempt <= retries: 54 | await limiters[model].wait() 55 | 56 | current_request_timeout = request_timeout or 15 * 60 57 | if max_remaining_time is not None: 58 | current_request_timeout = min(current_request_timeout, max_remaining_time) 59 | 60 | start_time = asyncio.get_event_loop().time() 61 | try: 62 | resp: Any = await acompletion( 63 | model=model, 64 | messages=[{"role": "user", "content": message}], 65 | temperature=temperature, 66 | timeout=current_request_timeout, 67 | num_retries=0, 68 | **props[model], 69 | ) 70 | end_time = asyncio.get_event_loop().time() 71 | duration = end_time - start_time 72 | if max_remaining_time is not None: 73 | max_remaining_time -= duration 74 | 75 | prompt_tokens = resp.model_extra.get("usage").prompt_tokens 76 | completion_tokens = resp.model_extra.get("usage").completion_tokens 77 | 78 | return ( 79 | resp["choices"][0]["message"]["content"].strip(), 80 | duration, 81 | max_remaining_time, 82 | max_remaining_timeouts, 83 | prompt_tokens, 84 | completion_tokens 85 | ) 86 | 87 | except ( 88 | litellm_exceptions.RateLimitError, 89 | litellm_exceptions.InternalServerError, 90 | litellm_exceptions.ServiceUnavailableError, 91 | litellm_exceptions.APIConnectionError, 92 | litellm_exceptions.APIError, 93 | litellm.RouterRateLimitError, 94 | litellm.RouterRateLimitErrorBasic, 95 | ) as e: 96 | # None of these exceptions should prevent the problem from being solved, so don't let them count against the allotted retries. 
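            # The `continue` below skips the `attempt += 1` at the bottom of the loop,
            # so these transient errors do not consume one of the allotted retries.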
97 | print(f"{problem_id or ''} Ignoring {type(e).__name__} and retrying attempt {attempt}: {e}") 98 | await asyncio.sleep(RETRY_DELAY_SEC) 99 | continue 100 | 101 | except Exception as e: 102 | end_time = asyncio.get_event_loop().time() 103 | duration = end_time - start_time 104 | if max_remaining_time is not None: 105 | max_remaining_time -= duration 106 | 107 | if "Timeout" in str(e): 108 | if max_remaining_timeouts is not None: 109 | max_remaining_timeouts -= 1 110 | print( 111 | f"{problem_id or ''} Timed out. Remaining timeouts: {max_remaining_timeouts}" 112 | ) 113 | if max_remaining_timeouts is not None and max_remaining_timeouts <= 0: 114 | raise RuntimeError("Exceeded timeouts allotted to the request") 115 | 116 | if attempt == retries: 117 | return ( 118 | "Timeout", 119 | duration, 120 | max_remaining_time, 121 | max_remaining_timeouts, 122 | 0, 123 | 0 124 | ) 125 | if max_remaining_time is not None and max_remaining_time <= 0: 126 | raise RuntimeError("Exceeded time allotted to the request") 127 | 128 | if attempt == retries: 129 | print(f"{problem_id or ''} Max retry limit reached. Last exception during call:") 130 | print(str(e)) 131 | raise e 132 | 133 | print(str(e)) 134 | print(f"Exception during request for problem: {problem_id or ''}. Retry number {attempt}.") 135 | await asyncio.sleep(RETRY_DELAY_SEC) 136 | 137 | # Increment attempt at the end of the loop. 138 | attempt += 1 139 | 140 | raise RuntimeError("Retries exceeded") 141 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import resource 5 | import time 6 | import traceback 7 | from datetime import datetime 8 | from typing import Optional 9 | 10 | from dotenv import load_dotenv 11 | 12 | from arc_agi.config import CONFIG_LIST 13 | from arc_agi.io import build_kaggle_two_attempts 14 | from arc_agi.scoring import score_task 15 | from arc_agi.solve import solve 16 | 17 | load_dotenv() 18 | 19 | 20 | # time the run started, so multiple runs don't collide 21 | TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 22 | 23 | # challenge input file 24 | DATA_CHALLENGES = os.path.join(os.path.dirname(__file__), "data", "arc-prize-2024", "arc-agi_evaluation_challenges.json") 25 | # optional challenge solution file 26 | DATA_SOLUTIONS = os.path.join(os.path.dirname(__file__), "data", "arc-prize-2024", "arc-agi_evaluation_solutions.json") 27 | # where to write outputs 28 | OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "output") 29 | OUTPUT = os.path.join(OUTPUT_DIR, f"submission_{TIMESTAMP}.json") 30 | 31 | # number of problems (None = all) 32 | NUM_PROBLEMS = None 33 | # select particular problems 34 | SELECTED_PROBLEMS = [] # e.g. 
['b7999b51'] 35 | 36 | 37 | async def _eval_task_data(task_id: str, task: dict) -> tuple[str, Optional[list[dict]], Optional[dict], Optional[str], float]: 38 | """ 39 | Returns: (task_id, kaggle_preds | None on error, tokens | None on error, error, elapsed_seconds) 40 | """ 41 | start = time.time() 42 | try: 43 | train = task.get("train", []) 44 | test = task.get("test", []) 45 | train_in = [ex["input"] for ex in train] 46 | train_out = [ex["output"] for ex in train] 47 | test_in = [ex["input"] for ex in test] 48 | 49 | results = await solve(train_in, train_out, test_in, problem_id=task_id) 50 | kaggle_preds = build_kaggle_two_attempts(results, test_in) 51 | 52 | prompt_tokens = sum(r['prompt_tokens'] or 0 for r in results if r) 53 | completion_tokens = sum(r['completion_tokens'] or 0 for r in results if r) 54 | tokens = { 55 | "prompt": prompt_tokens, 56 | "completion": completion_tokens, 57 | "total": prompt_tokens + completion_tokens 58 | } 59 | 60 | return task_id, kaggle_preds, tokens, None, time.time() - start 61 | except Exception: 62 | return task_id, None, None, traceback.format_exc(), time.time() - start 63 | 64 | 65 | async def main(): 66 | # Ensure we don't run out of file handles 67 | # Get current soft and hard limits 68 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 69 | # Set a new soft limit (cannot exceed hard limit) 70 | new_soft = 65536 71 | resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard)) 72 | 73 | os.makedirs(os.path.dirname(OUTPUT), exist_ok=True) 74 | 75 | print(f"Writing config_{TIMESTAMP}.json to output directory...") 76 | with open(os.path.join(OUTPUT_DIR, f"config_{TIMESTAMP}.json"), "w", encoding="utf-8") as f: 77 | json.dump(CONFIG_LIST, f, indent=4) 78 | 79 | # Load challenges 80 | with open(DATA_CHALLENGES, "r", encoding="utf-8") as f: 81 | challenges_blob: dict[str, dict] = json.load(f) 82 | 83 | # Load solutions if present; disable scoring if missing/unreadable 84 | solutions_blob: Optional[dict[str, list]] = None 85 | if DATA_SOLUTIONS and os.path.exists(DATA_SOLUTIONS): 86 | try: 87 | with open(DATA_SOLUTIONS, "r", encoding="utf-8") as f: 88 | solutions_blob = json.load(f) 89 | except Exception as e: 90 | print(f"WARNING: Could not load solutions file '{DATA_SOLUTIONS}': {e}\nScoring will be disabled.") 91 | 92 | items = list(challenges_blob.items()) 93 | if SELECTED_PROBLEMS: 94 | sel = set(SELECTED_PROBLEMS) 95 | items = [it for it in items if it[0] in sel] 96 | if NUM_PROBLEMS is not None: 97 | items = items[:NUM_PROBLEMS] 98 | 99 | 100 | print(f"Running {len(items)} problems from {DATA_CHALLENGES}...") 101 | print("Scoring:", "enabled" if solutions_blob is not None else "disabled (no solutions)") 102 | 103 | start = time.time() 104 | 105 | submission: dict[str, list[dict]] = {} 106 | tokens_data: dict[str, dict] = {} 107 | 108 | # running scores only if solutions available 109 | per_task_scores: dict[str, float] = {} 110 | total = 0 111 | correct = 0.0 112 | incorrect = 0.0 113 | 114 | tasks = [asyncio.create_task(_eval_task_data(task_id, task)) for task_id, task in items] 115 | 116 | for coro in asyncio.as_completed(tasks): 117 | task_id, preds, tokens, err, elapsed = await coro 118 | 119 | if err is not None or preds is None: 120 | print(f"! 
{task_id} (error in {round(elapsed)}s)\n{err}") 121 | submission[task_id] = [] 122 | else: 123 | submission[task_id] = preds 124 | if tokens: 125 | tokens_data[task_id] = tokens 126 | 127 | # running scores if solutions available 128 | if solutions_blob is not None and task_id in solutions_blob: 129 | gt_outputs = solutions_blob[task_id] 130 | task_score = score_task(preds, gt_outputs) 131 | per_task_scores[task_id] = task_score 132 | total += 1 133 | correct += task_score 134 | incorrect += 1 - task_score 135 | mark = "✓" if task_score == 1.0 else "✗" 136 | print(f"{mark} {task_id} ({round(elapsed)}s) [{correct}/{total}]") 137 | else: 138 | print(f"· {task_id} ({round(elapsed)}s)") 139 | 140 | # write cumulative Kaggle output after each task 141 | try: 142 | with open(OUTPUT, "w", encoding="utf-8") as f: 143 | json.dump(submission, f) 144 | with open(os.path.join(OUTPUT_DIR, f"tokens_{TIMESTAMP}.json"), "w", encoding="utf-8") as f: 145 | json.dump(tokens_data, f) 146 | except Exception as e: 147 | print(f"WARNING: Failed to write partial output to {OUTPUT}: {e}") 148 | 149 | total_time = time.time() - start 150 | 151 | print("\n=== Summary ===") 152 | print(f"Data file: {DATA_CHALLENGES}") 153 | print(f"Problems: {len(items)}") 154 | if solutions_blob is not None and per_task_scores: 155 | acc = correct / total 156 | print(f"Correct: {correct}") 157 | print(f"Incorrect: {incorrect}") 158 | print(f"Accuracy: {acc * 100:.3f}") 159 | else: 160 | print("Scoring: disabled or no tasks matched in solutions.") 161 | print(f"Total time: {round(total_time)}s") 162 | 163 | # final write just in case 164 | try: 165 | with open(OUTPUT, "w", encoding="utf-8") as f: 166 | json.dump(submission, f) 167 | print(f"\nWrote Kaggle submission to: {OUTPUT}") 168 | with open(os.path.join(OUTPUT_DIR, f"tokens_{TIMESTAMP}.json"), "w", encoding="utf-8") as f: 169 | json.dump(tokens_data, f) 170 | print(f"Wrote token usage to: {os.path.join(OUTPUT_DIR, f'tokens_{TIMESTAMP}.json')}") 171 | except Exception as e: 172 | print(f"ERROR: Final write to {OUTPUT} failed: {e}") 173 | 174 | if __name__ == "__main__": 175 | asyncio.run(main()) 176 | -------------------------------------------------------------------------------- /data/arc-prize-2024/sample_submission.json: -------------------------------------------------------------------------------- 1 | {"007bbfb7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "00d62c1b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "017c7c7b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "025d127b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "045e512c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0520fde7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "05269061": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "05f2a901": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "06df4c85": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "08ed6ac7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "09629e4f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0962bcdd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0a938d79": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0b148d64": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0ca9ddb6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0d3d703e": 
[{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0dfd9992": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0e206a2e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "10fcaaa3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "11852cab": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1190e5a7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "137eaa0f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "150deff5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "178fcbfb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1a07d186": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1b2d62fb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1b60fb0c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1bfc4729": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1c786137": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1caeab9d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1cf80156": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e0a9b12": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e32b0e9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f0c79e5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f642eb9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f85a75f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f876c06": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1fad071e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2013d3e2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2204b7a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22168020": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22233c11": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2281f1f4": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "228f6490": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22eb0ac0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "234bbc79": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "23581191": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "239be575": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "23b5c85d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "253bf280": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25d487eb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25d8a9c8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25ff71a9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "264363fd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "272f95fa": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "27a28665": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "28bf18c6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "28e73c20": [{"attempt_1": [[0, 
0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29623171": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29c11459": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29ec7d0e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2bcee788": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2bee17df": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2c608aff": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2dc579da": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2dd70a9a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2dee498d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "31aa019c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "321b1fc6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "32597951": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3345333e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3428a4f5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3618c87e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3631a71a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "363442ee": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "36d67576": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "36fdfd69": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3906de3d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "39a8645d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "39e1d7f9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3aa6fb7a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3ac3eb23": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3af2c5a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3bd67248": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3bdb4ada": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3befdf3e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3c9b0459": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3de23699": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3e980e27": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3eda0437": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3f7978a0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "40853293": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4093f84a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "41e4d17e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4258a5f9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4290ef0e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "42a50994": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4347f46a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "444801d8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "445eab21": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}]} -------------------------------------------------------------------------------- /arc_agi/solve_coding.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import string 4 | from typing import Any, Optional 5 | 6 | import numpy as np 7 | 8 | from arc_agi.llm import llm 9 | from arc_agi.sandbox import run 10 | from arc_agi.types import ARCAGIResult, ARCAGISolution, ExpertConfig, RunResult 11 | 12 | 13 | async def solve_coding( 14 | *, 15 | train_in: list[list[list[int]]], 16 | train_out: list[list[list[int]]], 17 | test_in: list[list[list[int]]], 18 | config: ExpertConfig, 19 | problem_id: str | None = None, 20 | ) -> ARCAGIResult: 21 | solver_prompt = config["solver_prompt"] 22 | feedback_prompt = config["feedback_prompt"] 23 | llm_model = config["llm_id"] 24 | max_iterations = int(config["max_iterations"]) 25 | solver_temperature = float(config["solver_temperature"]) 26 | max_solutions = int(config.get("max_solutions")) 27 | selection_probability = float(config.get("selection_probability")) 28 | seed = int(config.get("seed")) 29 | timeout_sandbox = float(config.get("timeout_s", 5)) 30 | shuffle_examples = bool(config.get("shuffle_examples")) 31 | improving_order = bool(config.get("improving_order")) 32 | return_best = bool(config.get("return_best_result")) 33 | request_timeout = config.get("request_timeout") 34 | max_total_timeouts = config.get("max_total_timeouts") 35 | max_total_time = config.get("max_total_time") 36 | per_iteration_retries = config.get("per_iteration_retries") 37 | 38 | total_prompt_tokens = 0 39 | total_completion_tokens = 0 40 | 41 | best_train_score = -1.0 42 | best_result: Optional[ARCAGIResult] = None 43 | last_train: list[RunResult] = [ 44 | RunResult( 45 | success=False, 46 | output="", 47 | soft_score=0.0, 48 | error="Unexpected use of initial empty train result", 49 | code="", 50 | ) 51 | ] 52 | last_test: Optional[list[RunResult]] = None 53 | 54 | rng = np.random.default_rng(seed) 55 | solutions: list[ARCAGISolution] = [] 56 | 57 | for it in range(max_iterations): 58 | example = _make_example(train_in, train_out, test_in) 59 | problem_str = format_problem(example, shuffle_examples, seed + it) 60 | message = _build_prompt(solver_prompt, problem=problem_str) 61 | 62 | selected = [] 63 | if solutions: 64 | mask = rng.uniform(size=len(solutions)) < selection_probability 65 | selected = [s for s, keep in zip(solutions, mask, strict=False) if keep] 66 | 67 | if selected: 68 | examples_block = create_examples( 69 | selected, max_examples=max_solutions, improving_order=improving_order 70 | ) 71 | message += "\n\n" + _build_prompt(feedback_prompt, feedback=examples_block) 72 | 73 | try: 74 | response, duration, max_total_time, max_total_timeouts, prompt_tokens, completion_tokens = await llm( 75 | llm_model, 76 | message=message, 77 | temperature=solver_temperature, 78 | request_timeout=request_timeout, 79 | max_remaining_time=max_total_time, 80 | max_remaining_timeouts=max_total_timeouts, 81 | problem_id=problem_id, 82 | retries=per_iteration_retries, 83 | ) 84 | total_prompt_tokens += prompt_tokens 85 | total_completion_tokens += completion_tokens 86 | except Exception as e: 87 | if "Exceeded timeouts allotted to the request" in str(e) or "Exceeded time allotted to the request" in str(e): 88 | # Exceeded max_remaining_timeouts or max_remaining_time 89 | print("Exiting early due to exceeding allotted time or timeouts on problem", problem_id) 90 | break 91 | # Just exceeded per_iteration_retries, so try the next iteration 92 | continue 93 | 94 | code = _parse_code_from_llm(response) 95 | if not code: 96 | 
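            # No fenced ```python block was found in the model response; move on to the next iteration.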
continue 97 | 98 | train_res, test_res = await _eval_on_train_and_test( 99 | code, train_in, train_out, test_in, timeout_s=timeout_sandbox 100 | ) 101 | 102 | last_train, last_test = train_res, test_res 103 | 104 | if all(r["success"] for r in train_res): 105 | return ARCAGIResult( 106 | train_results=train_res, 107 | results=test_res, 108 | iteration=it + 1, 109 | prompt_tokens=total_prompt_tokens, 110 | completion_tokens=total_completion_tokens, 111 | ) 112 | 113 | feedback, score = _build_feedback(train_res, train_in, train_out) 114 | solutions.append(ARCAGISolution(code=code, feedback=feedback, score=score)) 115 | 116 | if score >= best_train_score: 117 | best_train_score = score 118 | best_result = ARCAGIResult( 119 | train_results=train_res, 120 | results=test_res, 121 | iteration=it + 1, 122 | prompt_tokens=None, 123 | completion_tokens=None, 124 | ) 125 | 126 | if return_best and best_result is not None: 127 | best_result['prompt_tokens'] = total_prompt_tokens 128 | best_result['completion_tokens'] = total_completion_tokens 129 | return best_result 130 | if last_test is None: 131 | last_test = [ 132 | RunResult( 133 | success=False, 134 | output="", 135 | soft_score=0.0, 136 | error="Failed to generate any valid solutions.", 137 | code="", 138 | ) 139 | ] 140 | return ARCAGIResult( 141 | train_results=last_train, 142 | results=last_test, 143 | iteration=max_iterations, 144 | prompt_tokens=total_prompt_tokens, 145 | completion_tokens=total_completion_tokens, 146 | ) 147 | 148 | 149 | def create_examples(solutions, max_examples=3, improving_order: bool = False): 150 | template = string.Template(""" 151 | 152 | 153 | ```python 154 | $code 155 | ``` 156 | 157 | 158 | $feedback 159 | 160 | 161 | $score 162 | 163 | 164 | """) 165 | if not solutions: 166 | return "" 167 | scores = [x["score"] for x in solutions] 168 | inds = np.argsort(scores)[::-1] 169 | inds = inds[: min(max_examples, len(inds))] 170 | if improving_order: 171 | inds = inds[::-1] 172 | 173 | blocks: list[str] = [] 174 | for k, idx in enumerate(inds, start=1): 175 | e = solutions[idx] 176 | blocks.append( 177 | template.substitute( 178 | index=k, 179 | code=e["code"], 180 | feedback=e["feedback"], 181 | score=f"{e['score']:.2f}", 182 | ) 183 | ) 184 | return "\n".join(blocks) 185 | 186 | 187 | def _build_prompt(base_prompt: str, **fields: str) -> str: 188 | s = base_prompt 189 | for k, v in fields.items(): 190 | s = s.replace(f"$${k}$$", v) 191 | return s 192 | 193 | 194 | def _array_diff(arr1: np.ndarray, arr2: np.ndarray) -> str: 195 | rows, cols = arr1.shape 196 | out = [] 197 | for i in range(rows): 198 | row = [] 199 | for j in range(cols): 200 | if arr1[i, j] == arr2[i, j]: 201 | row.append(str(int(arr1[i, j]))) 202 | else: 203 | row.append(f"{int(arr1[i, j])}/{int(arr2[i, j])}") 204 | out.append(" ".join(row)) 205 | return "\n".join(out) 206 | 207 | 208 | def _parse_code_from_llm(response: str) -> Optional[str]: 209 | m = re.search(r"```python\s*(.*?)```", response, re.DOTALL | re.IGNORECASE) 210 | return m.group(1) if m else None 211 | 212 | 213 | def _soft_score(pred: np.ndarray, truth: np.ndarray) -> float: 214 | if pred.shape != truth.shape: 215 | return 0.0 216 | if truth.size == 0: 217 | return 1.0 218 | raw = np.mean(pred == truth) 219 | return float(np.nan_to_num(raw, posinf=0.0, neginf=0.0)) 220 | 221 | 222 | def _json_to_ndarray(s: str) -> Optional[np.ndarray]: 223 | try: 224 | obj = json.loads(s) 225 | arr = np.array(obj) 226 | if arr.ndim < 2: 227 | arr = np.expand_dims(arr, axis=list(range(2 - 
arr.ndim))) 228 | return arr.astype(int, copy=False) 229 | except Exception: 230 | return None 231 | 232 | 233 | def _make_example(train_in, train_out, test_in) -> dict[str, Any]: 234 | train = [ 235 | {"input": iin, "output": oout} 236 | for iin, oout in zip(train_in, train_out, strict=True) 237 | ] 238 | test = [{"input": iin} for iin in test_in] 239 | return {"train": train, "test": test} 240 | 241 | 242 | def format_problem( 243 | problem: dict[str, Any], 244 | shuffle: bool = False, 245 | seed: Optional[int] = None, 246 | ) -> str: 247 | train = list(problem["train"]) 248 | test = list(problem["test"]) 249 | 250 | if shuffle and len(train) > 1: 251 | rng = np.random.default_rng(seed if seed is not None else 0) 252 | perm = rng.permutation(len(train)) 253 | train = [train[i] for i in perm] 254 | 255 | example_str = "" 256 | challenge_str = "" 257 | 258 | for example_num, example in enumerate(train, start=1): 259 | example_str += f""" 260 | Example #{example_num} 261 | Input: 262 | 263 | {_example_to_diagram(example["input"])} 264 | 265 | 266 | Output: 267 | 268 | {_example_to_diagram(example["output"])} 269 | 270 | """ 271 | 272 | for challenge_num, challenge in enumerate(test, start=1): 273 | challenge_str += f""" 274 | Challenge #{challenge_num} 275 | Input: 276 | 277 | {_example_to_diagram(challenge["input"])} 278 | 279 | """ 280 | 281 | return example_str + challenge_str 282 | 283 | 284 | def _example_to_diagram(example: list[list[int]] | np.ndarray) -> str: 285 | """Converts an ARC-AGI example (list of lists) to a diagram (ascii grid).""" 286 | diagram = "" 287 | for row in example: 288 | row_str = " ".join([str(col) for col in row]) + "\n" 289 | diagram += row_str 290 | return diagram[:-1] # Strip final \n 291 | 292 | 293 | async def _eval_on_train_and_test( 294 | code: str, 295 | train_in: list[list[list[int]]], 296 | train_out: list[list[list[int]]], 297 | test_in: list[list[list[int]]], 298 | *, 299 | timeout_s: float = 1.5, 300 | ) -> tuple[list[RunResult], list[RunResult]]: 301 | # Train 302 | train_results: list[RunResult] = [] 303 | for i, (iin, oout) in enumerate(zip(train_in, train_out, strict=True)): 304 | ok, out_str = await run(code, iin, timeout_s=timeout_s) 305 | success = False 306 | soft = 0.0 307 | err: Optional[str] = None 308 | if not ok: 309 | err = out_str or "Execution failed." 
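        # Otherwise, parse the sandbox's JSON output and score it against the expected grid.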
310 | else: 311 | arr = _json_to_ndarray(out_str) 312 | if arr is None: 313 | err = ( 314 | f"Failed to parse output as JSON 2D array.\nOutput was:\n{out_str}" 315 | ) 316 | else: 317 | truth = np.array(oout) 318 | success = bool(arr.shape == truth.shape and np.array_equal(arr, truth)) 319 | soft = _soft_score(arr, truth) 320 | train_results.append( 321 | RunResult(success=success, output=out_str, soft_score=soft, error=err, code=code) 322 | ) 323 | 324 | # Test 325 | test_results: list[RunResult] = [] 326 | for i, iin in enumerate(test_in): 327 | ok, out_str = await run(code, iin, timeout_s=timeout_s) 328 | err = None if ok else (out_str or "Execution failed.") 329 | test_results.append( 330 | RunResult(success=False, output=out_str, soft_score=0.0, error=err, code=code) 331 | ) 332 | return train_results, test_results 333 | 334 | 335 | def _parse_json_array_no_expand(s: str) -> Optional[np.ndarray]: 336 | """Parse JSON into a NumPy array without changing rank or dtype.""" 337 | try: 338 | return np.array(json.loads(s)) 339 | except Exception: 340 | return None 341 | 342 | 343 | def _build_feedback( 344 | train_results: list[RunResult], train_in, train_out 345 | ) -> tuple[str, float]: 346 | feedback_parts: list[str] = [] 347 | per_example_scores: list[float] = [] 348 | 349 | for i, rr in enumerate(train_results): 350 | if rr["success"]: 351 | feedback_parts.append(f"Solves Example #{i + 1} correctly. ") 352 | per_example_scores.append(1.0) 353 | continue 354 | 355 | msg_lines: list[str] = [f"Solves Example #{i + 1} incorrectly. "] 356 | 357 | pred_raw = _parse_json_array_no_expand(rr["output"]) if rr["output"] else None 358 | truth = np.array(train_out[i]) 359 | 360 | if pred_raw is None: 361 | per_example_scores.append(0.0) 362 | msg_lines.append("\nThe output has to be a rectangular grid of numbers.\n") 363 | else: 364 | pred_for_display = pred_raw 365 | if pred_for_display.ndim < 2: 366 | pred_for_display = np.expand_dims( 367 | pred_for_display, axis=list(range(2 - pred_for_display.ndim)) 368 | ) 369 | 370 | if pred_raw.shape != truth.shape: 371 | per_example_scores.append(0.0) 372 | msg_lines.append( 373 | f"\n\nShape mismatch: your prediction's shape was {pred_raw.shape}, " 374 | f"while the correct shape was {truth.shape}." 375 | ) 376 | else: 377 | # Same shape: show diff grid and compute soft score. 378 | msg_lines.append( 379 | "\nYour code's output does not match the expected output." 
380 | "\n\nBelow is a visualization of the 2D array your code produced as well as the expected output.\n" 381 | "Correctly predicted values are shown as-is while the incorrectly predicted values are shown " 382 | "in the format 'prediction/correct':\n" 383 | ) 384 | diff = _array_diff(pred_for_display, truth) 385 | msg_lines.append(f"\n```\n{diff}\n```\n") 386 | 387 | example_score = float(np.mean(pred_raw == truth)) 388 | example_score = float( 389 | np.nan_to_num(example_score, posinf=0.0, neginf=0.0) 390 | ) 391 | per_example_scores.append(example_score) 392 | msg_lines.append( 393 | f"Output accuracy: {example_score:.2f} (0 is worst, 1 is best).\n" 394 | ) 395 | 396 | if rr["error"]: 397 | msg_lines.append( 398 | f"\n\nYour code produced the following error:\n{rr['error']}\n" 399 | ) 400 | 401 | feedback_parts.append("".join(msg_lines)) 402 | 403 | full_feedback = "\n\n".join(feedback_parts) 404 | mean_score = ( 405 | float(np.mean(np.nan_to_num(per_example_scores, posinf=0.0, neginf=0.0))) 406 | if per_example_scores 407 | else 0.0 408 | ) 409 | return full_feedback, mean_score 410 | -------------------------------------------------------------------------------- /arc_agi/prompts.py: -------------------------------------------------------------------------------- 1 | SOLVER_PROMPT_1 = ''' 2 | You are an expert in solving Abstract Reasoning Corpus (ARC) tasks by writing Python code. Your goal is to analyze input-output examples and create a 'transform' function that correctly transforms any given input grid into the corresponding output grid. 3 | 4 | Here's how to approach the problem: 5 | 6 | **1. Analyze the Examples:** 7 | * Identify the key objects in the input and output grids (e.g., shapes, lines, regions). 8 | * Determine the relationships between these objects (e.g., spatial arrangement, color, size). 9 | * Identify the operations that transform the input objects and relationships into the output objects and relationships (e.g., rotation, reflection, color change, object addition/removal). 10 | * Consider the grid dimensions, symmetries, and other visual features. 11 | 12 | **2. Formulate a Hypothesis:** 13 | * Based on your analysis, formulate a transformation rule that works consistently across all examples. 14 | * Express the rule as a sequence of image manipulation operations. 15 | * Prioritize simpler rules first. 16 | * Consider these types of transformations: 17 | * **Object Manipulation:** Moving, rotating, reflecting, or resizing objects. 18 | * **Color Changes:** Changing the color of specific objects or regions. 19 | * **Spatial Arrangements:** Rearranging the objects in a specific pattern. 20 | * **Object Addition/Removal:** Adding or removing objects based on certain criteria. 21 | 22 | **3. Implement the Code:** 23 | * Write a Python function called `transform(grid: np.ndarray) -> np.ndarray` that implements your transformation rule. 24 | * Use NumPy for array manipulations. Other standard libraries are also available. 25 | * Write modular code with clear variable names and comments to explain the logic behind each step. 26 | * Document your code clearly, explaining the transformation rule in the docstring. 27 | * Handle edge cases and invalid inputs gracefully. 28 | 29 | **4. Test and Refine:** 30 | * Test your code on all examples. If it fails for any example, refine your hypothesis and code. 31 | * Use debugging techniques to identify and fix errors. 32 | * Ensure your code handles edge cases and invalid inputs gracefully. 33 | 34 | **5. 
Output:** 35 | * Provide a brief explanation of your solution. 36 | * Include the complete Python code for the `transform` function within a single markdown code block. 37 | * Do not include any `__name__ == "__main__"` block or any code outside the function definition. 38 | 39 | **Examples:** 40 | 41 | **Example 1:** 42 | 43 | **Input:** 44 | ``` 45 | [[1, 1, 1], 46 | [1, 0, 1], 47 | [1, 1, 1]] 48 | ``` 49 | 50 | **Output:** 51 | ``` 52 | [[0, 0, 0], 53 | [0, 1, 0], 54 | [0, 0, 0]] 55 | ``` 56 | 57 | **Explanation:** 58 | Replace the border with 0s. 59 | 60 | **Code:** 61 | ```python 62 | import numpy as np 63 | 64 | def transform(grid: np.ndarray) -> np.ndarray: 65 | """Replace the border with 0s.""" 66 | grid[0, :] = 0 67 | grid[-1, :] = 0 68 | grid[:, 0] = 0 69 | grid[:, -1] = 0 70 | return grid 71 | ``` 72 | 73 | **Example 2:** 74 | 75 | **Input:** 76 | ``` 77 | [[1, 2, 3], 78 | [4, 5, 6], 79 | [7, 8, 9]] 80 | ``` 81 | 82 | **Output:** 83 | ``` 84 | [[9, 8, 7], 85 | [6, 5, 4], 86 | [3, 2, 1]] 87 | ``` 88 | 89 | **Explanation:** 90 | Reverse the order of elements in each row and then reverse the order of the rows themselves. 91 | 92 | **Code:** 93 | ```python 94 | import numpy as np 95 | 96 | def transform(grid: np.ndarray) -> np.ndarray: 97 | """Reverses the order of elements in each row and then reverses the order of the rows.""" 98 | new_grid = grid[:, ::-1][::-1] 99 | return new_grid 100 | ``` 101 | 102 | **Example 3:** 103 | 104 | **Input:** 105 | ``` 106 | [[0, 0, 0, 0, 0], 107 | [0, 1, 1, 1, 0], 108 | [0, 1, 0, 1, 0], 109 | [0, 1, 1, 1, 0], 110 | [0, 0, 0, 0, 0]] 111 | ``` 112 | 113 | **Output:** 114 | ``` 115 | [[0, 0, 0, 0, 0], 116 | [0, 0, 0, 0, 0], 117 | [0, 0, 1, 0, 0], 118 | [0, 0, 0, 0, 0], 119 | [0, 0, 0, 0, 0]] 120 | ``` 121 | 122 | **Explanation:** 123 | Keep only the center pixel if it is 1, otherwise make the grid all zeros. 124 | 125 | **Code:** 126 | ```python 127 | import numpy as np 128 | 129 | def transform(grid: np.ndarray) -> np.ndarray: 130 | """Keep only the center pixel if it is 1, otherwise make the grid all zeros.""" 131 | center_row, center_col = grid.shape[0] // 2, grid.shape[1] // 2 132 | if grid[center_row, center_col] == 1: 133 | new_grid = np.zeros_like(grid) 134 | new_grid[center_row, center_col] = 1 135 | return new_grid 136 | else: 137 | return np.zeros_like(grid) 138 | ``` 139 | 140 | **PROBLEM:** 141 | 142 | Below is a textual representation of the input-output examples and the challenge to be solved. 143 | 144 | $$problem$$ 145 | ''' 146 | 147 | SOLVER_PROMPT_2 = ''' 148 | You are a world-class expert in solving Abstract Reasoning Corpus (ARC) tasks. Your approach is methodical, creative, and highly effective. You are also a master Python coder, producing elegant, efficient, and well-documented solutions. 149 | 150 | Your goal is to analyze a set of input-output examples and devise a Python function that accurately transforms any input grid into its corresponding output grid. The key is to identify a *single, consistent transformation rule* that generalizes across *all* examples. Do not give up until you find a correct solution. 151 | 152 | Follow this iterative process: 153 | 154 | **Part 1: Initial Analysis and Hypothesis Generation** 155 | 156 | 1. **Example Inspection:** Carefully examine the input and output grids for each example. Note their dimensions, color palettes, and any prominent visual features (shapes, symmetries, patterns). Use visualization techniques to aid your analysis. 157 | 2. 
**Transformation Hypotheses:** Formulate several candidate transformation rules. Start with simpler rules and gradually increase complexity. Consider these categories: 158 | * **Color Transformations:** Replacing colors based on specific criteria (e.g., adjacency, frequency). For example, replace all 0s with 1s, or replace the most frequent color with the least frequent color. 159 | * **Object Isolation:** Identifying and isolating objects based on color, shape, or position. For example, extract the largest connected component of a certain color, or isolate objects based on their spatial relationships. 160 | * **Spatial Operations:** Rotating, reflecting, resizing, or moving objects. For example, rotate the grid by 90 degrees, reflect the grid horizontally or vertically, or resize the grid by a certain factor. 161 | * **Pattern Generation:** Replicating or extending existing patterns. For example, repeat a certain pattern across the grid, or generate a new pattern based on the existing patterns. 162 | 3. **Symmetry Analysis:** Identify any symmetries (rotational, reflectional) in the input and output grids. Determine if the transformation preserves or alters these symmetries. 163 | 164 | **Part 2: Iterative Testing and Refinement** 165 | 166 | 1. **Code Implementation:** Implement your strongest candidate rule as a Python function. The function *must* accept a 2D numpy array as input and return a 2D numpy array as output. 167 | 2. **Rigorous Testing:** Test your code against *all* training examples. A single failure indicates an incorrect rule. 168 | 3. **Feedback Analysis:** If your code fails, carefully analyze the feedback. Identify the specific examples that failed and the nature of the errors. Use print statements to debug intermediate values and verify your assumptions. 169 | 4. **Hypothesis Refinement:** Based on the feedback, refine your transformation rule. This may involve adjusting parameters, adding new conditions, or discarding the rule altogether and starting with a new hypothesis. 170 | 5. **Repeat:** Continue this iterative process of coding, testing, and refining until you find a rule that works for all training examples. Do not give up until you find a correct solution. 171 | 172 | **Part 3: Coding Guidelines** 173 | 174 | 1. **Available Libraries:** You can use `numpy`, `cv2` (OpenCV), and any library from the standard Python library. 175 | 2. **Computer Vision Techniques:** Consider using `cv2` for tasks involving object detection, edge detection, or image filtering. 176 | 3. **Utility Functions:** Write reusable utility functions to improve code modularity and readability. 177 | 4. **Error Handling:** Implement robust error handling to gracefully manage edge cases and invalid inputs. 178 | 5. **Code Clarity:** Write clean, well-documented code with meaningful variable names and comments. 179 | 180 | **Part 4: Output Requirements** 181 | 182 | 1. **Output Format:** 183 | * Begin with a concise paragraph explaining the proposed solution, followed by a Python code section. 184 | * You *must* provide a code output representing your best attempt. Do not give up or refuse to produce code. 185 | * **The code section must be a single, valid Python code block in markdown fenced code block format and nothing else.** 186 | * The main transform function must have the signature `def transform(grid: np.ndarray) -> np.ndarray`. 187 | * Document the transformation rule implemented in the docstring of the transform function. 188 | * Do not include any `__name__ == "__main__"` block. 
This will be added later by the user. You are writing a library function. 189 | 190 | **Example:** 191 | 192 | **Problem:** 193 | Input: 194 | 195 | 0 0 1 196 | 0 1 0 197 | 1 0 0 198 | 199 | 200 | Output: 201 | 202 | 1 1 1 203 | 1 1 1 204 | 1 1 1 205 | 206 | 207 | **Explanation:** 208 | Replace all 0s with 1s. 209 | 210 | ```python 211 | import numpy as np 212 | 213 | def transform(grid: np.ndarray) -> np.ndarray: 214 | """Replace all 0s with 1s.""" 215 | return np.where(grid == 0, 1, grid) 216 | ``` 217 | 218 | **PROBLEM:** 219 | 220 | Below is a textual representation of the input-output examples and the challenge to be solved. 221 | 222 | $$problem$$ 223 | ''' 224 | 225 | SOLVER_PROMPT_3 = ''' 226 | You are a world-class expert in solving Abstract Reasoning Corpus (ARC) tasks. Your approach is methodical, creative, and highly effective. You are also a master Python coder, producing elegant, efficient, and well-documented solutions. 227 | 228 | Your goal is to analyze a set of input-output examples and devise a Python function that accurately transforms any input grid into its corresponding output grid. The key is to identify a *single, consistent transformation rule* that generalizes across *all* examples. Do not give up until you find a correct solution. 229 | 230 | Follow this iterative process: 231 | 232 | **Part 1: Initial Analysis and Hypothesis Generation** 233 | 234 | 1. **Example Inspection:** Carefully examine the input and output grids for each example. Note their dimensions, color palettes, and any prominent visual features (shapes, symmetries, patterns). Use visualization techniques to aid your analysis. 235 | 2. **Transformation Hypotheses:** Formulate several candidate transformation rules. Start with simpler rules and gradually increase complexity. Consider these categories: 236 | * **Color Transformations:** Replacing colors based on specific criteria (e.g., adjacency, frequency). For example, replace all 0s with 1s, or replace the most frequent color with the least frequent color. 237 | * **Object Isolation:** Identifying and isolating objects based on color, shape, or position. For example, extract the largest connected component of a certain color, or isolate objects based on their spatial relationships. 238 | * **Spatial Operations:** Rotating, reflecting, resizing, or moving objects. For example, rotate the grid by 90 degrees, reflect the grid horizontally or vertically, or resize the grid by a certain factor. 239 | * **Pattern Generation:** Replicating or extending existing patterns. For example, repeat a certain pattern across the grid, or generate a new pattern based on the existing patterns. 240 | 3. **Symmetry Analysis:** Identify any symmetries (rotational, reflectional) in the input and output grids. Determine if the transformation preserves or alters these symmetries. 241 | 242 | **Part 2: Iterative Testing and Refinement** 243 | 244 | 1. **Code Implementation:** Implement your strongest candidate rule as a Python function. The function *must* accept a 2D numpy array as input and return a 2D numpy array as output. 245 | 2. **Rigorous Testing:** Test your code against *all* training examples. A single failure indicates an incorrect rule. 246 | 3. **Feedback Analysis:** If your code fails, carefully analyze the feedback. Identify the specific examples that failed and the nature of the errors. Use print statements to debug intermediate values and verify your assumptions. 247 | 4. **Hypothesis Refinement:** Based on the feedback, refine your transformation rule. 
This may involve adjusting parameters, adding new conditions, or discarding the rule altogether and starting with a new hypothesis. 248 | 5. **Repeat:** Continue this iterative process of coding, testing, and refining until you find a rule that works for all training examples. Do not give up until you find a correct solution. 249 | 250 | **Part 3: Coding Guidelines** 251 | 252 | 1. **Available Libraries:** You can use `numpy`, `cv2` (OpenCV), and any library from the standard Python library. 253 | 2. **Computer Vision Techniques:** Consider using `cv2` for tasks involving object detection, edge detection, or image filtering. 254 | 3. **Utility Functions:** Write reusable utility functions to improve code modularity and readability. 255 | 4. **Error Handling:** Implement robust error handling to gracefully manage edge cases and invalid inputs. 256 | 5. **Code Clarity:** Write clean, well-documented code with meaningful variable names and comments. The code should be as concise as possible. 257 | 258 | **Part 4: Output Requirements** 259 | 260 | 1. **Output Format:** 261 | * Begin with a concise paragraph explaining the proposed solution, followed by a Python code section. 262 | * You *must* provide a code output representing your best attempt. Do not give up or refuse to produce code. 263 | * **The code section must be a single, valid Python code block in markdown fenced code block format and nothing else.** 264 | * The main transform function must have the signature `def transform(grid: np.ndarray) -> np.ndarray`. 265 | * Document the transformation rule implemented in the docstring of the transform function. 266 | * Do not include any `__name__ == "__main__"` block. This will be added later by the user. You are writing a library function. 267 | 268 | **Example 1:** 269 | 270 | **Problem:** 271 | Input: 272 | 273 | 0 0 1 274 | 0 1 0 275 | 1 0 0 276 | 277 | 278 | Output: 279 | 280 | 1 1 1 281 | 1 1 1 282 | 1 1 1 283 | 284 | 285 | **Explanation:** 286 | Replace all 0s with 1s. 287 | 288 | ```python 289 | import numpy as np 290 | 291 | def transform(grid: np.ndarray) -> np.ndarray: 292 | """Replace all 0s with 1s.""" 293 | return np.where(grid == 0, 1, grid) 294 | ``` 295 | 296 | **Example 2:** 297 | 298 | **Problem:** 299 | Input: 300 | 301 | 0 0 0 302 | 0 1 0 303 | 0 0 0 304 | 305 | 306 | Output: 307 | 308 | 0 1 0 309 | 1 1 1 310 | 0 1 0 311 | 312 | 313 | **Explanation:** 314 | Replace all neighbors of 1 with 1. 315 | 316 | ```python 317 | import numpy as np 318 | 319 | def transform(grid: np.ndarray) -> np.ndarray: 320 | """Replace all neighbors of 1 with 1.""" 321 | new_grid = grid.copy() 322 | for i in range(1, grid.shape[0] - 1): 323 | for j in range(1, grid.shape[1] - 1): 324 | if grid[i][j] == 1: 325 | new_grid[i-1][j] = 1 326 | new_grid[i+1][j] = 1 327 | new_grid[i][j-1] = 1 328 | new_grid[i][j+1] = 1 329 | return new_grid 330 | ``` 331 | 332 | **Example 3:** 333 | 334 | **Problem:** 335 | Input: 336 | 337 | 1 2 3 338 | 4 5 6 339 | 7 8 9 340 | 341 | 342 | Output: 343 | 344 | 9 8 7 345 | 6 5 4 346 | 3 2 1 347 | 348 | 349 | **Explanation:** 350 | Reverse the grid. 351 | 352 | ```python 353 | import numpy as np 354 | 355 | def transform(grid: np.ndarray) -> np.ndarray: 356 | """Reverse the grid.""" 357 | return np.flip(grid) 358 | ``` 359 | 360 | **PROBLEM:** 361 | 362 | Below is a textual representation of the input-output examples and the challenge to be solved. 
363 | 364 | $$problem$$ 365 | ''' 366 | 367 | FEEDBACK_PROMPT = ''' 368 | **EXISTING PARTIAL/INCORRECT SOLUTIONS:** 369 | 370 | Following are some of the best, though not completely correct, solutions so far. For each solution, its code, corresponding feedback regarding its output on the example problems, and a numeric score between 0. (worst) and 1. (best) indicating the quality of outputs is also provided. Study these solutions and corresponding feedback and produce a new solution fixing all the issues. Make sure to follow the output format specified earlier. 371 | 372 | $$feedback$$ 373 | ''' 374 | -------------------------------------------------------------------------------- /data/arc-prize-2025/sample_submission.json: -------------------------------------------------------------------------------- 1 | {"00576224": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "007bbfb7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "009d5c81": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "00d62c1b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "00dbd492": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "017c7c7b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "025d127b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "03560426": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "045e512c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0520fde7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "05269061": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "05a7bcf2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "05f2a901": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0607ce86": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0692e18c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "06df4c85": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "070dd51e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "08ed6ac7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "09629e4f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0962bcdd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "09c534e7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0a1d4ef5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0a2355a6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0a938d79": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0b148d64": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0b17323b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0bb8deee": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0becf7df": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0c786b71": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0c9aba6e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0ca9ddb6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0d3d703e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0d87d2a6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0e206a2e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "0e671a1a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], 
"0f63c0b9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "103eff5b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "10fcaaa3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "11852cab": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1190bc91": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1190e5a7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "11dc524f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "11e1fe23": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "12422b43": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "12997ef3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "12eac192": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "13713586": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "137eaa0f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "137f0df0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "13f06aa5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "140c817e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "14754a24": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1478ab18": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "14b8e18c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "150deff5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "15113be4": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "15660dd6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "15663ba9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "15696249": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "17829a00": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "178fcbfb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "17b80ad2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "17b866bd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "17cae0c1": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "18286ef8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "182e5d0f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "18419cfa": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "18447a8d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "184a9768": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "195ba7dc": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1990f7a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "19bb5feb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1a07d186": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1a244afd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1a2e2828": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1a6449f1": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1acc24af": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1b2d62fb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1b59e163": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 
0]]}], "1b60fb0c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1b8318e3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1be83260": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1bfc4729": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1c02dbbe": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1c0d0a4b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1c56ad9f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1c786137": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1caeab9d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1cf80156": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1d0a4b61": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1d398264": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1d61978c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1da012fc": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e0a9b12": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e32b0e9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e5d6875": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1e81d6f9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1efba499": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f0c79e5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f642eb9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f85a75f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1f876c06": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "1fad071e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2013d3e2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2037f2c7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2072aba6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "20818e16": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "20981f0e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "20fb2937": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "212895b5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "21f83797": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2204b7a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22168020": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22208ba4": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22233c11": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22425bda": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22806e14": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2281f1f4": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "228f6490": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22a4bbc2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "22eb0ac0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "230f2e48": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "234bbc79": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 
0], [0, 0]]}], "23581191": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "239be575": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "23b5c85d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25094a63": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "252143c9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "253bf280": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2546ccf6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "256b0a75": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25c199f5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25d487eb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25d8a9c8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25e02866": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "25ff71a9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2601afb7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "264363fd": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2685904e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2697da3f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "272f95fa": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2753e76c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "278e5215": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "27a28665": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "27a77e38": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "27f8ce4f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "281123b4": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "28bf18c6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "28e73c20": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "292dd178": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29623171": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29700607": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "29c11459": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2a28add5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2a5f8217": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2b01abd0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2b9ef948": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2bcee788": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2bee17df": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2c0b0aff": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2c608aff": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2c737e39": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2ccd9fef": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], 
"attempt_2": [[0, 0], [0, 0]]}], "2dc579da": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2dd70a9a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2de01db2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2dee498d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2e65ae53": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2f0c5170": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2f767503": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "2faf500b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "305b1341": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "30f42897": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "310f3251": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3194b014": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "319f2597": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "31aa019c": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "31adaf00": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "31d5ba1a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "320afe60": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "321b1fc6": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "32597951": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "32e9702f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "33067df9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "332202d5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "332efdb3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3345333e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "337b420f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3391f8c0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "33b52de3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3428a4f5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "342ae2ed": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "342dd610": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3490cc26": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "34b99a2b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "34cfa167": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "351d6448": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "358ba94e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3618c87e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "363442ee": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "36d67576": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "36fdfd69": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "37ce87bb": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "37d3e8b2": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], 
"3906de3d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "396d80d7": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3979b1a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "39a8645d": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "39e1d7f9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3a301edc": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3aa6fb7a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3ac3eb23": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3ad05f52": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3af2c5a8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3b4c2228": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3bd292e8": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3bd67248": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3bdb4ada": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3befdf3e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3c9b0459": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3cd86f4f": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}, {"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3d31c5b3": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3d588dc9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3d6c6e23": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3de23699": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3e980e27": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3eda0437": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3ee1011a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3f23242b": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "3f7978a0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4093f84a": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "40f6cd08": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "412b6263": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "414297c0": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "41ace6b5": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "41e4d17e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "423a55dc": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4258a5f9": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "4290ef0e": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}], "42918530": [{"attempt_1": [[0, 0], [0, 0]], "attempt_2": [[0, 0], [0, 0]]}]} --------------------------------------------------------------------------------