├── requirements.txt ├── .gitignore ├── data └── download.sh ├── CONTRIBUTING.md ├── src ├── agents │ ├── random.py │ ├── crossword.py │ ├── grid_world.py │ ├── chess.py │ └── tic_tac_toe.py ├── config.py ├── interfaces.py ├── prompts.py ├── constants.py ├── environments │ ├── chess.py │ ├── dm_control.py │ ├── grid_world.py │ ├── crossword.py │ └── tic_tac_toe.py ├── main.py ├── bagz.py └── evaluate.py ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | chess 3 | dm-control 4 | dm-env 5 | etils 6 | imageio 7 | immutabledict 8 | importlib_resources 9 | jax 10 | matplotlib 11 | numpy 12 | tqdm 13 | typing-extensions 14 | zstandard 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | set -ex 19 | 20 | wget https://storage.googleapis.com/lm_act/LICENSE 21 | wget https://storage.googleapis.com/lm_act/lm_act.zip 22 | unzip lm_act.zip 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /src/agents/random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Agent that randomly chooses a legal action."""

import dataclasses
from typing import Any

import numpy as np

from lm_act.src import config as config_lib
from lm_act.src import interfaces


@dataclasses.dataclass(frozen=True, kw_only=True)
class RandomAgentConfig(config_lib.Agent):
  """Configuration for the random agent."""

  name: str = 'random'


class RandomAgent(interfaces.Agent):
  """Agent that uniformly samples one of the environment's legal actions."""

  def __init__(self, config: RandomAgentConfig) -> None:
    # The random agent is stateless, so the config is intentionally unused.
    pass

  def step(
      self,
      observation: Any,
      environment: interfaces.Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns a legal action sampled via the environment; ignores observation."""
    return environment.sample_legal_action(rng)
"""Oracle agent for the crossword environment."""

from collections.abc import Mapping
import dataclasses
from typing import Any

import numpy as np

from lm_act.src import config as config_lib
from lm_act.src import interfaces


@dataclasses.dataclass(frozen=True, kw_only=True)
class OracleAgentConfig(config_lib.Agent):
  """Configuration for the oracle agent."""

  name: str = 'crossword_oracle'


class OracleAgent(interfaces.Agent):
  """Oracle agent that reads the solution embedded in the observation."""

  # NOTE: the class docstring previously said "Interface for agents." — a
  # copy-paste error from the interfaces module.

  def __init__(self, config: OracleAgentConfig) -> None:
    # The oracle is stateless, so the config is intentionally unused.
    pass

  def step(
      self,
      observation: Mapping[str, Any],
      environment: interfaces.Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns one of the missing words.

    The observation's 'puzzle' entry is expected to contain a section headed
    by a 'Solution:' line, with one '<word> (<clue>)' entry per line; the bare
    words are extracted and one is sampled uniformly.
    """
    # Everything after the 'Solution:' marker is the answer list; drop the
    # parenthesised clue of each entry to recover the bare word.
    solutions = observation['puzzle'].split('Solution:\n')[1].split('\n')
    solutions = [sol.split(' (')[0] for sol in solutions]
    return rng.choice(solutions)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Configuration for the experiment.""" 17 | 18 | import dataclasses 19 | 20 | 21 | @dataclasses.dataclass(frozen=True, kw_only=True) 22 | class Agent: 23 | """Configuration for the agent.""" 24 | 25 | name: str 26 | action_type: str = 'txt' 27 | 28 | 29 | @dataclasses.dataclass(frozen=True, kw_only=True) 30 | class Environment: 31 | """Configuration for the environment.""" 32 | 33 | name: str 34 | observation_type: str 35 | action_type: str = 'txt' 36 | 37 | 38 | @dataclasses.dataclass(frozen=True, kw_only=True) 39 | class Prompt: 40 | """Configuration for the prompt.""" 41 | 42 | show_legal_actions: bool 43 | use_chain_of_thought: bool 44 | include_past_actions: bool = True 45 | 46 | 47 | @dataclasses.dataclass(frozen=True, kw_only=True) 48 | class Experiment: 49 | """Configuration for the experiment.""" 50 | 51 | num_demonstrations: int = 0 52 | num_evaluation_steps: int = 100 53 | num_evaluation_episodes: int = 100 54 | replay_episode: bool = False 55 | 56 | agent: Agent 57 | environment: Environment 58 | prompt: Prompt 59 | 60 | def __post_init__(self): 61 | if self.agent.action_type != self.environment.action_type: 62 | raise ValueError('The agent and environment action types must match.') 63 | 64 | if self.replay_episode and self.num_demonstrations != 1: 65 | raise ValueError('Replaying an episode requires exactly 1 demonstration.') 66 | -------------------------------------------------------------------------------- 
"""Shortest path agent for grid world."""

from collections.abc import Mapping
import dataclasses
from typing import Any

import numpy as np

from lm_act.src import config as config_lib
from lm_act.src import interfaces


@dataclasses.dataclass(frozen=True, kw_only=True)
class ShortestPathAgentConfig(config_lib.Agent):
  """Configuration for the shortest path agent."""

  # NOTE: the docstring previously said "Configuration for the minimax
  # agent." — a copy-paste error from the tic-tac-toe agent.

  name: str = 'grid_world_shortest_path'


class ShortestPathAgent(interfaces.Agent):
  """Shortest path agent for grid world."""

  def __init__(self, config: ShortestPathAgentConfig) -> None:
    # The agent is stateless, so the config is intentionally unused.
    pass

  def step(
      self,
      observation: Mapping[str, Any],
      environment: interfaces.Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns an optimal action for the observation and legal actions.

    Every move that reduces the Manhattan distance to the target is optimal,
    so the (at most two) distance-reducing moves are collected and one is
    sampled uniformly.

    NOTE(review): if the player already sits on the target, `optimal_actions`
    is empty and `rng.choice` raises — presumably the environment terminates
    the episode before that can happen; confirm with the caller.
    """
    player_y, player_x = observation['player']
    target_y, target_x = observation['target']

    optimal_actions = list()

    if target_x < player_x:
      optimal_actions.append('left')
    elif player_x < target_x:
      optimal_actions.append('right')

    if target_y < player_y:
      optimal_actions.append('up')
    elif player_y < target_y:
      optimal_actions.append('down')

    return rng.choice(optimal_actions)
  def step(
      self,
      observation: Any,
      environment: interfaces.Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns Stockfish's action for the board in the observation.

    The `environment` and `rng` arguments are unused; the move is computed
    directly from `observation['board']`.
    """
    # The Stockfish binary is expected in a sibling checkout relative to the
    # current working directory — TODO confirm this path holds in deployment.
    bin_path = os.path.join(
        os.getcwd(),
        '../Stockfish/src/stockfish',
    )
    # A fresh engine process is spawned per call; the context manager ensures
    # it is shut down even if `play` raises. NOTE(review): presumably the
    # per-call startup cost is acceptable here — confirm if this is hot.
    with chess.engine.SimpleEngine.popen_uci(bin_path) as engine:
      engine.configure({'Skill Level': self._skill_level})
      # Convert the engine's Move to the configured notation ('uci'/'san').
      return chess_env.action_to_string(
          action_notation=self._action_notation,
          action=engine.play(observation['board'], limit=self._limit).move,
          board=observation['board'],
      )
"""Defines the interfaces for agents and environments."""

import abc
from collections.abc import Mapping, Sequence
import pathlib
from typing import Any

import dm_env
import numpy as np

from lm_act.src import config as config_lib


class Environment(dm_env.Environment):
  """Interface for environments."""

  @abc.abstractmethod
  def __init__(
      self,
      config: config_lib.Environment,
      opening_paths: Sequence[pathlib.Path] | None = None,
  ) -> None:
    """Initializes the environment."""

  @property
  @abc.abstractmethod
  def legal_actions(self) -> list[str]:
    """Returns the legal actions."""

  def sample_legal_action(self, rng: np.random.Generator) -> str:
    """Returns a random legal action."""
    # By default, we just return one of the legal actions.
    return rng.choice(self.legal_actions)

  def action_is_illegal(self, action: str) -> bool:
    """Returns whether the action is illegal."""
    # By default, we just check if the action is in the legal actions.
    return action not in self.legal_actions

  def action_is_invalid(self, action: str) -> bool:
    """Returns whether the action is invalid."""
    # (Docstring previously said "valid", the inverse of what is returned.)
    # By default, we just check if the action is a single word.
    return len(action.split()) != 1

  def observation_spec(self) -> Any:
    # BUG FIX: this previously *returned* a NotImplementedError instance
    # instead of raising it, so callers would silently receive an exception
    # object as the spec.
    raise NotImplementedError('Unnecessary for this interface.')

  def action_spec(self) -> Any:
    # BUG FIX: same as observation_spec — raise instead of return.
    raise NotImplementedError('Unnecessary for this interface.')


class Agent(abc.ABC):
  """Interface for agents."""

  @abc.abstractmethod
  def __init__(self, config: config_lib.Agent) -> None:
    """Initializes the agent."""

  @abc.abstractmethod
  def step(
      self,
      observation: Mapping[str, Any],
      env: Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns the agent's action for the observation, environment, and rng."""
def build_trajectory_prompt(
    trajectory: str,
    legal_actions: list[str],
    config: config_lib.Experiment,
) -> str:
  """Returns the prompt for the current trajectory.

  The prompt shows the trajectory, optionally the legal actions, and then
  instructs the model how to format its answer depending on whether chain of
  thought is enabled.
  """
  show_legal = config.prompt.show_legal_actions

  parts = [f'\nThis is the current trajectory:\n\n{trajectory}\n']

  if show_legal:
    parts.append(
        '\nIn this situation, this is the list of all the actions that are'
        f' legal:\n\n{", ".join(legal_actions)}\n'
    )

  # Only mention demonstrations when some were actually provided.
  demonstrations_clause = (
      'demonstrations and the ' if 0 < config.num_demonstrations else ''
  )
  parts.append(
      f'\nGiven the {demonstrations_clause}current trajectory, you should'
      ' infer the next logical action.'
  )

  if show_legal:
    parts.append(
        '\nCheck that the chosen action is in the set of legal actions.'
    )

  if config.prompt.use_chain_of_thought:
    parts.append(
        '\nThink step by step and very briefly explain your reasoning for'
        ' choosing this action.\nYou must answer with the reasoning followed by'
        ' the action in the following format:\nReasoning: ...\nAction: ...'
    )
  elif show_legal:
    parts.append(
        '\nYou must answer with one of the legal actions only, without any'
        ' other text.'
    )
  else:
    parts.append(
        '\nYou must answer with the action only, without any other text,'
        ' following exactly the same format as the previous actions.'
    )

  return ''.join(parts).strip()
class MinimaxAgent(interfaces.Agent):
  """Minimax agent for tic tac toe."""

  def __init__(self, config: MinimaxAgentConfig):
    # Minimax needs no per-instance state, so the config is unused.
    pass

  def step(
      self,
      observation: Any,
      environment: interfaces.Environment,
      rng: np.random.Generator,
  ) -> str:
    """Returns the best move for the given board."""
    symbol = observation['symbol']
    best_move, _ = _minimax(
        player_symbol=symbol,
        board=observation['board'],
        current_symbol=symbol,
        depth=0,
        is_max=True,
    )
    if best_move is None:
      raise ValueError('No optimal action found.')
    return best_move
def get_rgb_shape(environment_name: str) -> tuple[int, int, int]:
  """Returns the (height, width, channels) of an environment's RGB frames.

  Raises:
    ValueError: if the environment name is not recognised.
  """
  # 'chess' is matched exactly; the other families are matched by prefix.
  if environment_name == 'chess':
    return (256, 256, 3)

  shapes_by_prefix = (
      ('atari', (210, 160, 3)),
      ('tic_tac_toe', (256, 256, 3)),
      ('grid_world', (192, 192, 3)),
  )
  for prefix, shape in shapes_by_prefix:
    if environment_name.startswith(prefix):
      return shape

  raise ValueError(f'Unknown environment name: {environment_name}.')
# The interfaces are abstract, so we need to use TypeVar to make them generic.
Agent = TypeVar('Agent', bound=interfaces.Agent)
Environment = TypeVar('Environment', bound=interfaces.Environment)


def get_agent_builder(agent_name: str) -> type[Agent]:
  """Returns the agent class registered under the given name."""
  if agent_name == 'random':
    return random_agents.RandomAgent
  elif agent_name == 'chess_stockfish':
    return chess_agents.StockfishAgent
  elif agent_name == 'crossword_oracle':
    return crossword_agents.OracleAgent
  elif agent_name == 'grid_world_shortest_path':
    return grid_world_agents.ShortestPathAgent
  elif agent_name == 'tic_tac_toe_minimax':
    return tic_tac_toe_agents.MinimaxAgent
  else:
    raise ValueError(f'Unknown agent name: {agent_name}.')


def get_environment_builder(environment_name: str) -> type[Environment]:
  """Returns the environment builder for the given environment name."""
  if environment_name.startswith('atari'):
    raise NotImplementedError('atari environments are not yet supported.')

  # 'chess' and 'crossword' are matched exactly; the rest by prefix.
  if environment_name == 'chess':
    return chess_env.Chess
  if environment_name == 'crossword':
    return crossword_env.Crossword

  builders_by_prefix = (
      ('dm_control', dm_control_env.DMControl),
      ('grid_world', grid_world_env.GridWorld),
      ('tic_tac_toe', tic_tac_toe_env.TicTacToe),
  )
  for prefix, builder in builders_by_prefix:
    if environment_name.startswith(prefix):
      return builder

  raise ValueError(f'Unknown environment name: {environment_name}.')
"""Chess environment."""

import copy
import dataclasses
import io
import os
import pathlib
from typing import Literal

import chess
import chess.pgn
import dm_env
from numpy import random

from lm_act.src import bagz
from lm_act.src import config as config_lib
from lm_act.src import interfaces


# NOTE(review): this module uses `chess.engine` (Limit, SimpleEngine) but
# never imports it explicitly — it appears to rely on `chess.pgn` importing
# `chess.engine` transitively; confirm, and consider importing it directly.
ActionNotation = Literal['uci', 'san']


def action_to_string(
    action: chess.Move,
    board: chess.Board,
    action_notation: ActionNotation,
) -> str:
  """Returns the string representation of a chess action."""
  match action_notation:
    case 'uci':
      return str(action)
    case 'san':
      # SAN depends on the position, hence the board argument.
      return board.san(action)


@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  name: str = 'chess'
  observation_type: Literal['fen', 'pgn', 'txt', 'png'] = 'fen'
  action_type: ActionNotation = 'san'

  # Limits and skill level for the built-in Stockfish opponent.
  stockfish_node_limit: int = 1
  stockfish_time_limit: float = 0.001
  stockfish_skill_level: int = 0


class Chess(interfaces.Environment):
  """A simple chess environment to play against stockfish."""

  def __init__(
      self,
      config: EnvironmentConfig,
      opening_paths: list[pathlib.Path] | None = None,
      openings: list[chess.pgn.Game] | None = None,
  ) -> None:
    """Initializes the environment.

    Args:
      config: The environment configuration.
      opening_paths: Directories each containing an 'observations_pgn.bag'
        whose first record is a PGN opening; ignored if `openings` is given.
      openings: Pre-parsed opening games; takes precedence over
        `opening_paths`.

    Raises:
      ValueError: If neither `openings` nor `opening_paths` is provided.
    """
    self._stockfish_limit = chess.engine.Limit(
        nodes=config.stockfish_node_limit,
        time=config.stockfish_time_limit,
    )
    self._stockfish_skill_level = config.stockfish_skill_level
    self._action_notation = config.action_type

    if openings is not None:
      self._openings = openings
    elif opening_paths is not None:
      self._openings = list()
      for opening_path in opening_paths:
        pgns = bagz.BagReader(
            (opening_path / 'observations_pgn.bag').as_posix()
        )
        # Only the first record of each bag is used as the opening.
        pgn = io.StringIO(pgns[0].decode('utf-8'))
        if (game := chess.pgn.read_game(pgn)) is not None:
          self._openings.append(game)
    else:
      raise ValueError('Either `openings` or `opening_paths` must be provided.')

    # Set by `reset`; None until the first episode starts.
    self._game: chess.pgn.Game | None = None
    self._node: chess.pgn.GameNode | None = None
    self._player_is_white: bool | None = None

  def reset(self) -> dm_env.TimeStep:
    # Each episode consumes the next opening; the agent plays whichever side
    # is to move at the end of the opening.
    self._game = self._openings.pop(0)
    self._node = self._game.end()
    self._player_is_white = self._board.turn == chess.WHITE
    return dm_env.restart(observation=self._observation)

  @property
  def _board(self):
    # Rebuilt from the current PGN node on every access — each call returns
    # a fresh board, so mutating it does not affect the game record.
    return self._node.board()

  @property
  def _observation(self):
    # All supported observation renderings of the current position.
    return {
        'board': copy.deepcopy(self._board),
        'fen': self._board.fen(en_passant='fen'),
        'pgn': str(self._game),
        'txt': str(self._board),
    }

  def _turn(
      self,
      action: str,
      notation: ActionNotation,
  ) -> None | dm_env.TimeStep:
    """Plays one half-move; returns a termination step if the game ended."""
    match notation:
      case 'uci':
        action = chess.Move.from_uci(action)
      case 'san':
        # `push_san` parses the SAN string against the current position and
        # returns the Move; the pushed board is a throwaway copy (see _board).
        action = self._board.push_san(action)
    # Record the move on the game tree; this advances the current node.
    self._node = self._node.add_main_variation(action)

    if (
        claim_draw := self._board.can_claim_draw()
    ) or self._board.is_game_over():
      # `outcome` is non-None here since the game is over or a draw can be
      # claimed; rewards are from the player's perspective (+1 win, -1 loss).
      match (outcome := self._board.outcome(claim_draw=claim_draw)).winner:  # pytype: disable=attribute-error
        case chess.WHITE:
          reward = 1 if self._player_is_white else -1
        case chess.BLACK:
          reward = -1 if self._player_is_white else 1
        case None:
          reward = 0
        case _:
          raise ValueError(f'Unknown outcome: {outcome}')
      return dm_env.termination(observation=self._observation, reward=reward)

    return None

  def step(self, action: str) -> dm_env.TimeStep:
    """Plays the agent's move, then Stockfish's reply."""
    outcome = self._turn(action=action, notation=self._action_notation)
    if outcome is not None:
      return outcome

    # The Stockfish binary is expected in a sibling checkout relative to the
    # current working directory — TODO confirm this path holds in deployment.
    bin_path = os.path.join(
        os.getcwd(),
        '../Stockfish/src/stockfish',
    )
    with chess.engine.SimpleEngine.popen_uci(bin_path) as engine:
      engine.configure({'Skill Level': self._stockfish_skill_level})
      # The engine's reply is applied in UCI notation regardless of the
      # configured agent notation.
      action = str(engine.play(self._board, limit=self._stockfish_limit).move)

    if (outcome := self._turn(action=action, notation='uci')) is not None:
      return outcome

    return dm_env.transition(observation=self._observation, reward=0)

  @property
  def legal_actions(self) -> list[str]:
    # Sorted for a deterministic presentation order in prompts.
    return sorted(
        action_to_string(
            action=action,
            board=self._board,
            action_notation=self._action_notation,
        )
        for action in self._board.legal_moves
    )

  def sample_legal_action(self, rng: random.Generator) -> str:
    return action_to_string(
        action=rng.choice(list(self._board.legal_moves)),
        board=self._board,
        action_notation=self._action_notation,
    )
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """DM Control environment.""" 17 | 18 | import ast 19 | from collections.abc import Mapping 20 | import dataclasses 21 | import pathlib 22 | from typing import Any, Literal 23 | 24 | from dm_control import suite 25 | import dm_env 26 | import numpy as np 27 | from numpy import random 28 | 29 | from lm_act.src import bagz 30 | from lm_act.src import config as config_lib 31 | from lm_act.src import interfaces 32 | 33 | 34 | @dataclasses.dataclass(frozen=True, kw_only=True) 35 | class EnvironmentConfig(config_lib.Environment): 36 | """Configuration for the environment.""" 37 | 38 | name: str = 'dm_control' 39 | domain: str = 'cheetah' 40 | task: str = 'run' 41 | observation_type: Literal['dict'] = 'dict' 42 | seed: int = 0 43 | 44 | def __post_init__(self): 45 | object.__setattr__(self, 'name', f'{self.name}_{self.domain}_{self.task}') 46 | 47 | 48 | class DMControl(interfaces.Environment): 49 | """DM Control environment.""" 50 | 51 | def __init__( 52 | self, 53 | config: EnvironmentConfig, 54 | opening_paths: list[pathlib.Path] | None = None, 55 | openings: list[Mapping[str, Any]] | None = None, 56 | ) -> None: 57 | if openings is not None: 58 | self._openings = openings 59 | elif opening_paths is not None: 60 | self._openings = list() 61 | for opening_path in opening_paths: 62 | observations = bagz.BagReader( 63 | (opening_path / 'observations_dict.bag').as_posix() 64 | ) 65 | 
self._openings.append(ast.literal_eval(observations[0].decode('utf-8'))) 66 | else: 67 | raise ValueError('Either `openings` or `opening_paths` must be provided.') 68 | 69 | self._config = config 70 | self._env = suite.load( 71 | domain_name=config.domain, 72 | task_name=config.task, 73 | task_kwargs=dict(random=config.seed), 74 | ) 75 | self._action_spec = self._env.action_spec() 76 | 77 | @property 78 | def domain(self) -> str: 79 | return self._config.domain 80 | 81 | @property 82 | def task(self) -> str: 83 | return self._config.task 84 | 85 | def _prepare_observation(self, time_step: dm_env.TimeStep) -> dm_env.TimeStep: 86 | time_step.observation['dict'] = str( 87 | {k: v.tolist() for k, v in time_step.observation.items()} 88 | ) 89 | return time_step 90 | 91 | def _set_init_state(self, observation: Mapping[str, Any]) -> dm_env.TimeStep: 92 | """Returns the time step after setting the initial environment state.""" 93 | time_step = self._env.reset() 94 | 95 | with self._env.physics.reset_context(): 96 | match self._config.domain: 97 | case 'cheetah' | 'hopper': 98 | self._env.physics.data.qpos[0] = 0 99 | self._env.physics.data.qpos[1:] = observation['position'] 100 | self._env.physics.data.qvel[:] = observation['velocity'] 101 | case 'point_mass': 102 | self._env.physics.data.qpos[:] = observation['position'] 103 | self._env.physics.data.qvel[:] = observation['velocity'] 104 | case _: 105 | raise ValueError(f'Unknown domain: {self._config.domain}') 106 | 107 | self._env.physics.after_reset() 108 | env_observation = self._env.task.get_observation(self._env.physics) 109 | 110 | # Check that the observation from the environment matches `observation`. 
111 | for key, value in observation.items(): 112 | np.testing.assert_equal(env_observation[key], value) 113 | 114 | time_step = dm_env.TimeStep( 115 | step_type=time_step.step_type, 116 | reward=time_step.reward, 117 | discount=time_step.discount, 118 | observation=env_observation, 119 | ) 120 | return self._prepare_observation(time_step=time_step) 121 | 122 | def reset(self) -> dm_env.TimeStep: 123 | """Resets the environment.""" 124 | observation = self._openings.pop(0) 125 | return self._set_init_state(observation) 126 | 127 | def step(self, action: str) -> dm_env.TimeStep: 128 | """Steps the environment.""" 129 | actions = self._extract_actions(action) 130 | return self._prepare_observation(self._env.step(actions)) 131 | 132 | @property 133 | def legal_actions(self) -> list[str]: 134 | """Returns the legal actions.""" 135 | return [ 136 | 'A comma-separated list (enclosed by square brackets) of' 137 | f' {self._action_spec.shape[0]} values between' 138 | f' {self._action_spec.minimum.tolist()} and' 139 | f' {self._action_spec.maximum.tolist()}.' 140 | ] 141 | 142 | def sample_legal_action(self, rng: random.Generator) -> str: 143 | min_values = self._action_spec.minimum 144 | max_values = self._action_spec.maximum 145 | return str( 146 | ( 147 | (max_values - min_values) * rng.random(len(min_values)) + min_values 148 | ).tolist() 149 | ) 150 | 151 | def action_is_illegal(self, action: str) -> bool: 152 | """Returns whether the action is legal.""" 153 | # For DM Control, we treat invalid and illegal actions as the same. 
154 | return self.action_is_invalid(action) 155 | 156 | def action_is_invalid(self, action: str) -> bool: 157 | """Returns whether the action is valid.""" 158 | try: 159 | self._extract_actions(action) 160 | return False 161 | except ValueError: 162 | return True 163 | 164 | def _extract_actions(self, action: str) -> np.ndarray: 165 | if not action.startswith('[') or not action.endswith(']'): 166 | raise ValueError(f'Invalid action: {action}') 167 | action = action[1:-1] 168 | values = np.fromiter(map(float, action.split(',')), dtype=np.float64) 169 | self._action_spec.validate(values) 170 | return values 171 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Evaluates an agent on the LMAct benchmark.""" 17 | 18 | from collections.abc import Sequence 19 | import logging 20 | 21 | from absl import app 22 | from absl import flags 23 | import immutabledict 24 | import numpy as np 25 | import tqdm 26 | 27 | from lm_act.src import config as config_lib 28 | from lm_act.src import evaluate 29 | from lm_act.src.agents import chess as chess_agent 30 | from lm_act.src.agents import crossword as crossword_agent 31 | from lm_act.src.agents import grid_world as grid_world_agent 32 | from lm_act.src.agents import random as random_agent 33 | from lm_act.src.agents import tic_tac_toe as tic_tac_toe_agent 34 | from lm_act.src.environments import chess 35 | from lm_act.src.environments import crossword 36 | from lm_act.src.environments import dm_control 37 | from lm_act.src.environments import grid_world 38 | from lm_act.src.environments import tic_tac_toe 39 | 40 | 41 | _ENVIRONMENT = flags.DEFINE_enum( 42 | name='environment', 43 | default='tic_tac_toe', 44 | enum_values=[ 45 | 'chess', 46 | 'crossword', 47 | 'dm_control', 48 | 'grid_world', 49 | 'tic_tac_toe', 50 | ], 51 | help='The environment to evaluate.', 52 | ) 53 | _OBSERVATION_TYPE = flags.DEFINE_enum( 54 | name='observation_type', 55 | default='txt', 56 | enum_values=['coords', 'dict', 'fen', 'pgn', 'png', 'rgb', 'txt'], 57 | help='The observation representation to evaluate.', 58 | ) 59 | _ACTION_TYPE = flags.DEFINE_enum( 60 | name='action_type', 61 | default='txt', 62 | enum_values=['txt', 'san'], 63 | help='The action representation to evaluate.', 64 | ) 65 | _AGENT = flags.DEFINE_enum( 66 | name='agent', 67 | default='random', 68 | enum_values=[ 69 | 'random', 70 | 'chess_stockfish', 71 | 'crossword_oracle', 72 | 'grid_world_shortest_path', 73 | 'tic_tac_toe_minimax', 74 | ], 75 | help='The agent to evaluate.', 76 | ) 77 | _NUM_DEMONSTRATIONS = flags.DEFINE_integer( 78 | 
name='num_demonstrations', 79 | default=0, 80 | help='The number of demonstrations to use.', 81 | ) 82 | _NUM_EVALUTION_EPISODES = flags.DEFINE_integer( 83 | name='num_evaluation_episodes', 84 | default=100, 85 | help='The number of episodes to evaluate.', 86 | ) 87 | _NUM_EVALUATION_STEPS = flags.DEFINE_integer( 88 | name='num_evaluation_steps', 89 | default=100, 90 | help='The number of steps to evaluate.', 91 | ) 92 | 93 | _CONFIG_BY_ENVIRONMENT = immutabledict.immutabledict({ 94 | 'chess': chess.EnvironmentConfig, 95 | 'crossword': crossword.EnvironmentConfig, 96 | 'dm_control': dm_control.EnvironmentConfig, 97 | 'grid_world': grid_world.EnvironmentConfig, 98 | 'tic_tac_toe': tic_tac_toe.EnvironmentConfig, 99 | }) 100 | _CONFIG_BY_AGENT = immutabledict.immutabledict({ 101 | 'random': random_agent.RandomAgentConfig, 102 | 'chess_stockfish': chess_agent.StockfishAgentConfig, 103 | 'crossword_oracle': crossword_agent.OracleAgentConfig, 104 | 'grid_world_shortest_path': grid_world_agent.ShortestPathAgentConfig, 105 | 'tic_tac_toe_minimax': tic_tac_toe_agent.MinimaxAgentConfig, 106 | }) 107 | 108 | 109 | def main(argv: Sequence[str]) -> None: 110 | if len(argv) > 1: 111 | raise app.UsageError('Too many command-line arguments.') 112 | 113 | logging.getLogger().setLevel(logging.WARNING) 114 | 115 | print(f'Environment: {_ENVIRONMENT.value}') 116 | print(f'Observation type: {_OBSERVATION_TYPE.value}') 117 | print(f'Agent: {_AGENT.value}') 118 | print(f'Num evaluation episodes: {_NUM_EVALUTION_EPISODES.value}') 119 | 120 | scores = list() 121 | num_steps = list() 122 | num_invalid_actions = list() 123 | num_illegal_actions = list() 124 | num_empty_actions = list() 125 | 126 | for episode in tqdm.trange(_NUM_EVALUTION_EPISODES.value): 127 | ( 128 | episode_score, 129 | episode_num_steps, 130 | episode_num_invalid_actions, 131 | episode_num_illegal_actions, 132 | episode_num_empty_actions, 133 | ) = evaluate.evaluate_episode( 134 | episode_idx=episode, 135 | 
config=config_lib.Experiment( 136 | num_demonstrations=_NUM_DEMONSTRATIONS.value, 137 | num_evaluation_steps=_NUM_EVALUATION_STEPS.value, 138 | agent=_CONFIG_BY_AGENT[_AGENT.value]( 139 | action_type=_ACTION_TYPE.value, 140 | ), 141 | environment=_CONFIG_BY_ENVIRONMENT[_ENVIRONMENT.value]( 142 | observation_type=_OBSERVATION_TYPE.value, 143 | action_type=_ACTION_TYPE.value, 144 | ), 145 | prompt=config_lib.Prompt( 146 | show_legal_actions=None, 147 | use_chain_of_thought=None, 148 | ), 149 | ), 150 | ) 151 | 152 | scores.append(episode_score) 153 | num_steps.append(episode_num_steps) 154 | num_invalid_actions.append(episode_num_invalid_actions) 155 | num_illegal_actions.append(episode_num_illegal_actions) 156 | num_empty_actions.append(episode_num_empty_actions) 157 | 158 | logging.info({ 159 | 'episode': episode, 160 | 'score': episode_score, 161 | 'num_steps': episode_num_steps, 162 | 'num_invalid_actions': episode_num_invalid_actions, 163 | 'num_illegal_actions': episode_num_illegal_actions, 164 | 'num_empty_actions': episode_num_empty_actions, 165 | }) 166 | 167 | print(f'Average score: {np.mean(scores):.2f}') 168 | print(f'Average num steps: {np.mean(num_steps):.2f}') 169 | print(f'Average num invalid actions: {np.mean(num_invalid_actions):.2f}') 170 | print(f'Average num illegal actions: {np.mean(num_illegal_actions):.2f}') 171 | print(f'Average num empty actions: {np.mean(num_empty_actions):.2f}') 172 | 173 | 174 | if __name__ == '__main__': 175 | app.run(main) 176 | -------------------------------------------------------------------------------- /src/environments/grid_world.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Grid World environment.""" 17 | 18 | from collections.abc import Mapping 19 | import copy 20 | import dataclasses 21 | import os 22 | import pathlib 23 | from typing import Literal 24 | 25 | import dm_env 26 | import imageio 27 | import jax 28 | import numpy as np 29 | 30 | from lm_act.src import bagz 31 | from lm_act.src import config as config_lib 32 | from lm_act.src import interfaces 33 | 34 | 35 | _ASSETS_PATH = pathlib.Path( 36 | os.path.join(os.getcwd(), '../crafter/crafter/assets') 37 | ) 38 | 39 | 40 | @dataclasses.dataclass(frozen=True, kw_only=True) 41 | class EnvironmentConfig(config_lib.Environment): 42 | """Configuration for the environment.""" 43 | 44 | name: str = 'grid_world' 45 | observation_type: Literal['rgb', 'txt', 'coords'] = 'rgb' 46 | height: int = 12 47 | width: int = 12 48 | 49 | def __post_init__(self): 50 | object.__setattr__(self, 'name', f'grid_world_{self.height}x{self.width}') 51 | 52 | 53 | class GridWorld(interfaces.Environment): 54 | """2D grid world environment with a player and a target.""" 55 | 56 | def __init__( 57 | self, 58 | config: EnvironmentConfig, 59 | opening_paths: list[pathlib.Path] | None = None, 60 | openings: list[tuple[np.ndarray, np.ndarray]] | None = None, 61 | ) -> None: 62 | if openings is not None: 63 | self._openings = openings 64 | elif opening_paths is not None: 65 | self._openings = list() 66 | for opening_path in opening_paths: 67 | player_coordinates = 
bagz.BagReader( 68 | (opening_path / 'observations_player.bag').as_posix() 69 | ) 70 | target_coordinates = bagz.BagReader( 71 | (opening_path / 'observations_target.bag').as_posix() 72 | ) 73 | player = np.frombuffer(player_coordinates[0], dtype=np.int64) 74 | target = np.frombuffer(target_coordinates[0], dtype=np.int64) 75 | self._openings.append((player, target)) 76 | else: 77 | raise ValueError('Either `openings` or `opening_paths` must be provided.') 78 | 79 | self._width = config.width 80 | self._height = config.height 81 | 82 | self._player: np.ndarray = None 83 | self._target: np.ndarray = None 84 | 85 | self._walls = np.full((self._height, self._width), False, dtype=bool) 86 | self._walls[0, :] = True 87 | self._walls[-1, :] = True 88 | self._walls[:, 0] = True 89 | self._walls[:, -1] = True 90 | 91 | with open(_ASSETS_PATH / 'food.png', 'r') as f: 92 | target_sprite = imageio.imread(f)[:, :, :-1] 93 | with open(_ASSETS_PATH / 'player.png', 'r') as f: 94 | player_sprite = imageio.imread(f)[:, :, :-1] 95 | with open(_ASSETS_PATH / 'stone.png', 'r') as f: 96 | wall_sprite = imageio.imread(f) 97 | 98 | sprite_matrix = np.array([wall_sprite, target_sprite, player_sprite]) 99 | self._sprite_matrix = np.transpose(sprite_matrix[:, ::-1], [1, 2, 0, 3]) 100 | 101 | self._rgb = jax.jit(self._rgb) 102 | 103 | def reset(self) -> dm_env.TimeStep: 104 | self._player, self._target = self._openings.pop(0) 105 | return dm_env.restart(observation=self._observation) 106 | 107 | def _rgb(self, state: np.ndarray) -> jax.Array: 108 | return jax.lax.conv_transpose( 109 | state[None], # NHWC 110 | self._sprite_matrix, # HWIO 111 | (16, 16), 112 | 'SAME', 113 | )[0] 114 | 115 | @property 116 | def _text(self) -> str: 117 | scene = list() 118 | 119 | for row in range(self._height): 120 | row_str = '|' 121 | for col in range(self._width): 122 | if self._walls[row, col]: 123 | tile = 'wall' 124 | elif self._player[0] == row and self._player[1] == col: 125 | tile = 'player' 126 | elif 
self._target[0] == row and self._target[1] == col: 127 | tile = 'target' 128 | else: 129 | tile = 'tile' 130 | row_str += tile.center(8) + '|' 131 | scene.append(row_str) 132 | scene.append('-' * len(row_str)) 133 | 134 | scene = ['-' * len(scene[0])] + scene 135 | 136 | return '\n'.join(scene) 137 | 138 | @property 139 | def _observation(self) -> Mapping[str, np.ndarray | str]: 140 | player_state = np.zeros((self._height, self._width), dtype=np.bool_) 141 | player_state[self._player[0], self._player[1]] = True 142 | target_state = np.zeros((self._height, self._width), dtype=np.bool_) 143 | target_state[self._target[0], self._target[1]] = True 144 | state = np.stack( 145 | [self._walls, target_state, player_state], 146 | axis=-1, 147 | dtype=np.uint8, 148 | ) 149 | return { 150 | 'player': copy.deepcopy(self._player), 151 | 'target': copy.deepcopy(self._target), 152 | 'rgb': copy.deepcopy(np.array(self._rgb(state), dtype=np.uint8)), 153 | 'txt': self._text, 154 | 'coords': str( 155 | dict(player=self._player.tolist(), target=self._target.tolist()) 156 | ), 157 | } 158 | 159 | def step(self, action: str) -> dm_env.TimeStep: 160 | next_y, next_x = self._player 161 | 162 | match action: 163 | case 'left': 164 | next_x -= 1 165 | case 'right': 166 | next_x += 1 167 | case 'up': 168 | next_y -= 1 169 | case 'down': 170 | next_y += 1 171 | case _: 172 | raise ValueError(f'Unsupported action: {action}') 173 | 174 | if not self._walls[next_y, next_x]: 175 | self._player = np.array([next_y, next_x]) 176 | if np.all(self._player == self._target): 177 | return dm_env.termination(observation=self._observation, reward=1) 178 | 179 | return dm_env.transition(observation=self._observation, reward=0) 180 | 181 | @property 182 | def legal_actions(self) -> list[str]: 183 | """Returns the legal actions.""" 184 | return ['left', 'right', 'up', 'down'] 185 | -------------------------------------------------------------------------------- /src/environments/crossword.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Crossword environment.""" 17 | 18 | import ast 19 | import collections 20 | from collections.abc import Mapping 21 | import dataclasses 22 | import logging 23 | import os 24 | import pathlib 25 | import re 26 | from typing import Literal 27 | 28 | import dm_env 29 | 30 | from lm_act.src import bagz 31 | from lm_act.src import config as config_lib 32 | from lm_act.src import interfaces 33 | 34 | 35 | _BASE_DIR_PATH = pathlib.Path( 36 | os.path.join( 37 | os.getcwd(), 38 | 'data/lm_act/', 39 | ) 40 | ) 41 | 42 | 43 | @dataclasses.dataclass(frozen=True, kw_only=True) 44 | class EnvironmentConfig(config_lib.Environment): 45 | """Configuration for the environment.""" 46 | 47 | name: str = 'crossword' 48 | observation_type: Literal['txt'] = 'txt' 49 | 50 | 51 | class Crossword(interfaces.Environment): 52 | """A simple crossword environment.""" 53 | 54 | def __init__( 55 | self, 56 | config: EnvironmentConfig, 57 | opening_paths: list[pathlib.Path] | None = None, 58 | openings: list[str] | None = None, 59 | ) -> None: 60 | if openings is not None: 61 | self._openings = openings 62 | elif opening_paths is not None: 63 | self._openings = list() 64 | for opening_path in opening_paths: 
65 | puzzles = bagz.BagReader( 66 | (opening_path / 'observations_puzzle.bag').as_posix() 67 | ) 68 | self._openings.append(puzzles[0].decode('utf-8')) 69 | else: 70 | raise ValueError('Either `openings` or `opening_paths` must be provided.') 71 | 72 | self._puzzle: str = None 73 | self._grid: str = None 74 | self._clues: str = None 75 | self._coords_by_solution: Mapping[str, tuple[int, int]] = None 76 | self._unfound_solutions: list[str] = None 77 | 78 | self._words_by_length = collections.defaultdict(set) 79 | 80 | with open(_BASE_DIR_PATH / 'crossword' / 'words.txt', 'r') as f: 81 | for word in f: 82 | word = word.strip() 83 | self._words_by_length[len(word)].add(word) 84 | 85 | def reset(self) -> dm_env.TimeStep: 86 | self._puzzle = self._openings.pop(0) 87 | self._grid, self._clues, solutions, *_ = self._puzzle.split('\n\n') 88 | 89 | if (grid_len := len(self._grid)) != 554: 90 | raise ValueError(f'Invalid grid size, should be 554 but is {grid_len}.') 91 | 92 | self._coords_by_solution = dict() 93 | for solution in solutions.split('Solution:\n')[1].split('\n'): 94 | # The solution has the format "A1: word (row, col)", so we need some regex 95 | # magic to separate the word and the coordinates. 
96 | if (match := re.match(r'(.*?)\s*(\(\d+,\s*\d+\))', solution)) is not None: 97 | sol, coords = match.groups() 98 | self._coords_by_solution[sol] = ast.literal_eval(coords) 99 | 100 | self._unfound_solutions = list(self._coords_by_solution.keys()) 101 | 102 | return dm_env.restart(observation=self._observation) 103 | 104 | @property 105 | def _observation(self): 106 | return { 107 | 'txt': self._grid + '\n\n' + self._clues, 108 | 'puzzle': ( 109 | self._grid 110 | + '\n\n' 111 | + self._clues 112 | + '\n\nSolution:\n' 113 | + '\n'.join( 114 | f'{sol} {self._coords_by_solution[sol]}' 115 | for sol in self._unfound_solutions 116 | ) 117 | ), 118 | } 119 | 120 | def _update_grid(self, solution: str) -> None: 121 | is_vertical = solution[0] == 'D' 122 | row, col = self._coords_by_solution[solution] 123 | 124 | # Account for the offsets. 125 | row = row * 2 + 1 126 | col = col * 5 + 3 127 | 128 | for character in solution.split(': ')[1]: 129 | idx = row * 37 + col 130 | self._grid = ( 131 | self._grid[: idx - 1] + ' ' + character + self._grid[idx + 1 :] 132 | ) 133 | row += 2 * is_vertical 134 | col += 5 * (1 - is_vertical) 135 | 136 | def step(self, action: str) -> dm_env.TimeStep: 137 | # The environment should not be case-sensitive. 138 | action = action.upper().strip() 139 | 140 | if action in self._unfound_solutions: 141 | self._unfound_solutions.remove(action) 142 | self._update_grid(action) 143 | 144 | if self._unfound_solutions: 145 | return dm_env.transition(reward=1, observation=self._observation) 146 | return dm_env.termination(reward=1, observation=self._observation) 147 | 148 | if action in self._coords_by_solution.keys(): 149 | return dm_env.transition(reward=0, observation=self._observation) 150 | 151 | # We continue if the proposed word is incorrect but has the correct length. 
152 | try: 153 | idx, word = action.split(': ') 154 | solution = [ 155 | sol for sol in self._coords_by_solution if sol.startswith(idx) 156 | ][0] 157 | if len(solution.split(': ')[1]) == len(word): 158 | return dm_env.transition(reward=0, observation=self._observation) 159 | except (IndexError, ValueError): 160 | logging.info('Invalid action: %s', action) 161 | return dm_env.termination(reward=-1, observation=self._observation) 162 | 163 | # Otherwise, we terminate the game. 164 | return dm_env.termination(reward=-1, observation=self._observation) 165 | 166 | def action_is_invalid(self, action: str) -> bool: 167 | """Returns whether the action in the format `A0: word` or `D1: word`.""" 168 | return not re.match(r'^[AD]\d+:\s\w+$', action) 169 | 170 | @property 171 | def legal_actions(self) -> list[str]: 172 | """Returns the legal actions (all possible words of the correct length).""" 173 | legal_actions = list() 174 | for solution in self._unfound_solutions: 175 | idx, fill = solution.split(': ') 176 | for word in self._words_by_length[len(fill)]: 177 | legal_actions.append(f'{idx}: {word}') 178 | return legal_actions 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LMAct: A Benchmark for In-Context Imitation Learning with Long Multimodal Demonstrations 2 | 3 |

4 | Overview figure 5 |

6 | 7 | This repository provides an implementation of our ICML 2025 paper [LMAct: A Benchmark for In-Context Imitation Learning with Long Multimodal Demonstrations](https://arxiv.org/abs/2412.01441). 8 | 9 | > In this paper, we present a benchmark to pressure-test today’s frontier 10 | models’ multimodal decision-making capabilities in the very long-context regime 11 | (up to one million tokens) and investigate whether these models can learn from 12 | large numbers of expert demonstrations in their context. 13 | We evaluate the performance of Claude 3.5 Sonnet, Gemini 1.5 Flash, Gemini 1.5 14 | Pro, Gemini 2.0 Flash Experimental, GPT-4o, o1-mini, o1-preview, and o1 as 15 | policies across a battery of simple interactive decision-making tasks: playing 16 | tic-tac-toe, chess, and Atari, navigating grid worlds, solving crosswords, and 17 | controlling a simulated cheetah. 18 | We study increasing amounts of expert demonstrations in the context — from no 19 | demonstrations to 512 full episodes. 20 | Across our tasks, models rarely manage to fully reach expert performance, and 21 | often, presenting more demonstrations has little effect. 22 | Some models steadily improve with more demonstrations on a few tasks. 23 | We investigate the effect of encoding observations as text or images and the 24 | impact of chain-of-thought prompting. 25 | To help quantify the impact of other approaches and future innovations, we open 26 | source our benchmark that covers the zero-, few-, and many-shot regimes in a 27 | unified evaluation. 28 | 29 | ## Contents 30 | 31 | ``` 32 | . 
33 | | 34 | ├── crafter - Crafter (needs to be downloaded) 35 | | 36 | ├── data - Expert demonstrations (need to be downloaded) 37 | | 38 | ├── src 39 | | ├── agents 40 | | │   ├── chess.py - Stockfish agent (chess expert) 41 | | │   ├── crossword.py - Oracle agent (crossword expert) 42 | | │   ├── grid_world.py - Shortest path agent (grid world expert) 43 | | │   ├── random.py - Random action agent 44 | | │   └── tic_tac_toe.py - Minimax agent (tic-tac-toe expert) 45 | | | 46 | | ├── bagz.py - Readers for our .bag data files 47 | | ├── config.py - Experiment configurations 48 | | ├── constants.py - Project constants 49 | | | 50 | | ├── environments 51 | | │   ├── chess.py - Chess environment 52 | | │   ├── crossword.py - Crossword environment 53 | | │   ├── dm_control.py - DM Control environment 54 | | │   ├── grid_world.py - Grid world environment 55 | | │   └── tic_tac_toe.py - Tic-tac-toe environment 56 | | | 57 | | ├── evaluate.py - Evaluation loop 58 | | ├── interfaces.py - Project interfaces 59 | | ├── main.py - Experiment launch script 60 | | └── prompts.py - Prompt-building functionality 61 | | 62 | ├── Stockfish - Stockfish (needs to be installed) 63 | | 64 | ├── README.md 65 | └── requirements.txt - Dependencies 66 | ``` 67 | 68 | ## Installation 69 | 70 | Clone the source code into a local directory: 71 | 72 | ```bash 73 | git clone https://github.com/google-deepmind/lm_act.git 74 | cd lm_act 75 | ``` 76 | 77 | This repository requires Python 3.11. 78 | `pip install -r requirements.txt` will install all required dependencies. 79 | This is best done inside a [conda environment](https://www.anaconda.com/). 80 | To that end, install [Anaconda](https://www.anaconda.com/download#downloads). 
81 | Then, create and activate the conda environment: 82 | 83 | ```bash 84 | conda create --name lm_act python=3.11 85 | conda activate lm_act 86 | ``` 87 | 88 | Install `pip` and use it to install all the dependencies: 89 | 90 | ```bash 91 | conda install pip 92 | pip install -r requirements.txt 93 | ``` 94 | 95 | ### Installing Crafter 96 | 97 | Download the crafter repository: 98 | 99 | ```bash 100 | git clone https://github.com/danijar/crafter.git 101 | ``` 102 | 103 | ### Installing Stockfish 104 | 105 | Download and compile the latest version of Stockfish (for Unix-like systems): 106 | 107 | ```bash 108 | git clone https://github.com/official-stockfish/Stockfish.git 109 | cd Stockfish/src 110 | make -j profile-build ARCH=x86-64-avx2 111 | cd ../.. 112 | ``` 113 | 114 | ### Downloading the Expert Demonstrations 115 | 116 | To download our expert demonstrations to the correct locations, run the 117 | following command: 118 | 119 | ```bash 120 | cd data 121 | ./download.sh 122 | cd .. 123 | ``` 124 | 125 | ## Usage 126 | 127 | Before running any code, make sure to activate the conda environment and set the 128 | `PYTHONPATH`: 129 | 130 | ```bash 131 | conda activate lm_act 132 | export PYTHONPATH=$(pwd)/.. 
```

To evaluate an agent, run the following command:
```bash
python src/main.py \
  --environment=tic_tac_toe \
  --observation_type=txt \
  --agent=random \
  --num_demonstrations=0
```

## Citing this work

```latex
@inproceedings{ruoss2025lmact,
  author       = {Anian Ruoss and
                  Fabio Pardo and
                  Harris Chan and
                  Bonnie Li and
                  Volodymyr Mnih and
                  Tim Genewein},
  title        = {{LMAct}: A Benchmark for In-Context Imitation Learning with
                  Long Multimodal Demonstrations},
  booktitle    = {{ICML}},
  year         = {2025},
}
```

## License and disclaimer

Copyright 2024 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
you may not use this file except in compliance with the Apache 2.0 license.
You may obtain a copy of the Apache 2.0 license at:
https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0
International License (CC-BY). You may obtain a copy of the CC-BY license at:
https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and
materials distributed here under the Apache 2.0 or CC-BY licenses are
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied. See the licenses for the specific language governing
permissions and limitations under those licenses.

This is not an official Google product.
181 | -------------------------------------------------------------------------------- /src/environments/tic_tac_toe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tic-Tac-Toe environment.""" 17 | 18 | import copy 19 | import dataclasses 20 | import io 21 | import pathlib 22 | from typing import Literal 23 | 24 | import dm_env 25 | from matplotlib import pyplot as plt 26 | import numpy as np 27 | from PIL import Image 28 | from PIL import ImageDraw 29 | 30 | from lm_act.src import bagz 31 | from lm_act.src import config as config_lib 32 | from lm_act.src import interfaces 33 | 34 | 35 | ROWS = ['a', 'b', 'c'] 36 | COLS = ['1', '2', '3'] 37 | AXES = [ 38 | # Rows. 39 | ('row_0', np.array([0, 1, 2])), 40 | ('row_1', np.array([3, 4, 5])), 41 | ('row_2', np.array([6, 7, 8])), 42 | # Columns. 43 | ('col_0', np.array([0, 3, 6])), 44 | ('col_1', np.array([1, 4, 7])), 45 | ('col_2', np.array([2, 5, 8])), 46 | # Negative diagonal. 47 | ('neg_diag', np.array([0, 4, 8])), 48 | # Positive diagonal. 
    ('pos_diag', np.array([6, 4, 2])),
]


def _draw_board(
    board: np.ndarray,
    size: int = 768,
    color: str = 'black',
    background: str = 'white',
    render_coordinates: bool = True,
) -> tuple[bytes, np.ndarray]:
  """Returns the board as a PNG image and RGB array.

  Args:
    board: 3x3 array of cell symbols (' ', 'x', or 'o').
    size: Width and height of the square output image, in pixels.
    color: Color used for the grid lines and the symbols.
    background: Background color of the image.
    render_coordinates: Whether to overlay row/column labels via
      `_add_coordinates`.

  Returns:
    A tuple of (PNG-encoded image bytes, RGB image array).
  """
  image = Image.new('RGB', size=(size, size), color=background)
  draw = ImageDraw.Draw(image)
  # Stroke width scales with the image size (a fixed fraction of the side).
  width = size // 25

  # Draw the grid: two vertical and two horizontal lines at the thirds.
  draw.line(
      ((size // 3, 0), (size // 3, size)),
      fill=color,
      width=width,
  )
  draw.line(
      ((2 * size // 3, 0), (2 * size // 3, size)),
      fill=color,
      width=width,
  )
  draw.line(
      ((0, size // 3), (size, size // 3)),
      fill=color,
      width=width,
  )
  draw.line(
      ((0, 2 * size // 3), (size, 2 * size // 3)),
      fill=color,
      width=width,
  )

  # Draw the symbols, cell by cell; empty cells (' ') draw nothing.
  for row_idx, row in enumerate(board):
    for col_idx, symbol in enumerate(row):

      # Maps a cell index and an offset (1 = near edge, 3 = far edge, in
      # twelfths of the image) to a pixel coordinate inside the cell.
      def _to_coord(idx: int, offset: int) -> int:
        return (4 * idx + offset) * size // 12

      # NOTE(review): the 'top_right' and 'bottom_left' entries appear to
      # have their x/y offsets swapped relative to their names. Both are only
      # used together as the endpoints of the X's second diagonal, so the
      # rendered output is unaffected — confirm before reusing elsewhere.
      coords = {
          'top_left': (_to_coord(col_idx, 1), _to_coord(row_idx, 1)),
          'top_right': (_to_coord(col_idx, 1), _to_coord(row_idx, 3)),
          'bottom_left': (_to_coord(col_idx, 3), _to_coord(row_idx, 1)),
          'bottom_right': (_to_coord(col_idx, 3), _to_coord(row_idx, 3)),
      }
      match symbol:
        case 'o':
          # Circle inscribed in the cell's inner bounding box.
          draw.ellipse(
              [coords['top_left'], coords['bottom_right']],
              outline=color,
              width=width,
          )
        case 'x':
          # Two crossing diagonals.
          draw.line(
              (coords['top_left'], coords['bottom_right']),
              fill=color,
              width=width,
          )
          draw.line(
              (coords['top_right'], coords['bottom_left']),
              fill=color,
              width=width,
          )

  if render_coordinates:
    image = Image.fromarray(_add_coordinates(np.array(image)))

  # Encode to PNG and also return the raw RGB pixels.
  with io.BytesIO() as buffer:
    image.save(buffer, format='PNG')
    return buffer.getvalue(), np.array(image)
def _add_coordinates(rgb_image: np.ndarray) -> np.ndarray:
  """Adds coordinates to the image.

  Renders the board image through a matplotlib figure with row labels
  ('a'-'c') and column labels ('1'-'3') on all four sides.

  Args:
    rgb_image: The RGB image array to add coordinates to.

  Returns:
    The RGB image array with coordinates added.
  """
  rgb_image = rgb_image.astype(np.uint8)
  height, width, _ = rgb_image.shape

  # Tick positions at the centers of the three rows/columns.
  x_ticks = np.linspace(height / 6, 5 * height / 6, 3)
  y_ticks = np.linspace(width / 6, 5 * width / 6, 3)

  x_tick_labels = ['1', '2', '3']
  y_tick_labels = ['a', 'b', 'c']

  # Scale the label font with the image resolution.
  font_size = round(8 / 256 * height)

  fig = plt.figure(figsize=(height / 100, width / 100))
  plt.imshow(rgb_image)

  plt.xticks(x_ticks, x_tick_labels, fontsize=font_size)
  plt.yticks(y_ticks, y_tick_labels, fontsize=font_size)

  # Show labels on all four sides and hide the tick marks themselves.
  plt.tick_params(
      axis='both',
      which='both',
      labeltop=True,
      labelright=True,
      labelbottom=True,
      labelleft=True,
      length=0,
  )

  fig.tight_layout()
  fig.canvas.draw()
  new_width, new_height = fig.canvas.get_width_height()
  # `FigureCanvas.tostring_rgb` was deprecated in matplotlib 3.8 and removed
  # in 3.10; fall back to `buffer_rgba` and drop the alpha channel there.
  try:
    rgb_buffer = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    labeled_image = rgb_buffer.reshape((new_height, new_width, 3))
  except AttributeError:
    rgba = np.asarray(fig.canvas.buffer_rgba(), dtype=np.uint8)
    labeled_image = rgba.reshape((new_height, new_width, 4))[..., :3].copy()

  plt.close(fig)

  return labeled_image


def legal_actions(board: np.ndarray) -> list[str]:
  """Returns the empty squares (e.g. 'a1') in row-major order."""
  return [
      ROWS[coords[0]] + COLS[coords[1]] for coords in np.argwhere(board == ' ')
  ]


@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  # Environment name; rewritten in `__post_init__` when coordinates are drawn.
  name: str = 'tic_tac_toe'
  # Observation rendering: ASCII text ('txt') or PNG image ('png').
  observation_type: Literal['txt', 'png'] = 'png'
  # Whether to overlay row/column labels on rendered boards (png only).
  render_coordinates: bool = False
  # Random seed (consumed by the environment — TODO confirm exact use).
  seed: int = 0

  def __post_init__(self):
    if self.render_coordinates:
      if self.observation_type == 'txt':
        raise ValueError(
            'Rendering coordinates is only supported for `png` observations.'
        )
      # Frozen dataclass: bypass immutability to record the variant name.
      object.__setattr__(self, 'name', 'tic_tac_toe_with_coordinates')
only supported for `png` observations.' 193 | ) 194 | object.__setattr__(self, 'name', 'tic_tac_toe_with_coordinates') 195 | 196 | 197 | class TicTacToe(interfaces.Environment): 198 | """A simple tic-tac-toe environment to play against a random policy.""" 199 | 200 | def __init__( 201 | self, 202 | config: EnvironmentConfig, 203 | opening_paths: list[pathlib.Path] | None = None, 204 | openings: list[tuple[np.ndarray, bool]] | None = None, 205 | ) -> None: 206 | if openings is not None: 207 | self._openings = openings 208 | elif opening_paths is not None: 209 | self._openings = list() 210 | for opening_path in opening_paths: 211 | boards = bagz.BagReader( 212 | (opening_path / 'observations_board.bag').as_posix() 213 | ) 214 | board = np.frombuffer(boards[0], dtype=np.dtype(' dm_env.TimeStep: 226 | self._board, self._player_is_x = self._openings.pop(0) 227 | self._board = copy.deepcopy(self._board) 228 | return dm_env.restart(observation=self._observation) 229 | 230 | @property 231 | def _observation(self): 232 | return { 233 | 'board': copy.deepcopy(self._board), 234 | 'txt': '\n----------\n'.join(' | '.join(row) for row in self._board), 235 | # Gemini resizes all images to 768x768, so we might as well do it here. 236 | 'png': _draw_board( 237 | board=self._board, 238 | size=768, 239 | render_coordinates=self._render_coordinates, 240 | )[0], 241 | # In contrast, Sequence Storage can only render images up to 256x256. 
          'rgb': _draw_board(
              board=self._board,
              size=256,
              render_coordinates=self._render_coordinates,
          )[1],
          'symbol': self.symbol(is_player=True),
      }

  def symbol(self, is_player: bool) -> str:
    """Returns the mark ('x' or 'o') of the player or of their opponent."""
    if is_player:
      return 'x' if self._player_is_x else 'o'
    return 'o' if self._player_is_x else 'x'

  def _turn(
      self,
      action: str,
      is_player: bool,
  ) -> None | dm_env.TimeStep:
    """Applies one move and returns a terminal step if the game ended.

    Args:
      action: Board coordinate such as 'b2' (row letter + column digit).
      is_player: Whether the move belongs to the player (True) or to the
        random adversary (False).

    Returns:
      A termination step with reward +1 (player win), -1 (adversary win), or
      0 (draw), or None if the game continues.
    """
    symbol = self.symbol(is_player=is_player)
    self._board[ROWS.index(action[0]), COLS.index(action[1])] = symbol

    # Check the eight winning lines for three of the mover's symbols.
    for _, axis in AXES:
      line = np.take_along_axis(self._board.flatten(), axis, axis=None)
      if (line == symbol).all():
        return dm_env.termination(
            observation=self._observation,
            reward=1 if is_player else -1,
        )

    # No winner and no empty square left: draw.
    if ' ' not in self._board:
      return dm_env.termination(observation=self._observation, reward=0)

    return None

  def step(self, action: str) -> dm_env.TimeStep:
    """Plays the player's move, then (if non-terminal) a random reply.

    Raises:
      ValueError: If `action` is not currently legal.
    """
    if action not in self.legal_actions:
      raise ValueError(f'Action {action} is illegal.')

    if (outcome := self._turn(action=action, is_player=True)) is not None:
      return outcome

    # The adversary randomly chooses a legal action.
    move = self._rng.choice(self.legal_actions)

    if (outcome := self._turn(action=move, is_player=False)) is not None:
      return outcome

    return dm_env.transition(observation=self._observation, reward=0)

  @property
  def legal_actions(self) -> list[str]:
    """Coordinates of the currently empty squares."""
    return legal_actions(self._board)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Bagz file reader/writer and PyGrain-compatible data source for POSIX systems. 17 | 18 | Bagz is a file format for storing a sequence of string records, typically 19 | serialised protocol buffers. It supports fast index based look-up. 20 | """ 21 | 22 | import bisect 23 | from collections.abc import Sequence 24 | import itertools 25 | import mmap 26 | import os 27 | import re 28 | import shutil 29 | import struct 30 | from typing import Any, SupportsIndex 31 | 32 | from etils import epath 33 | from typing_extensions import Self 34 | import zstandard as zstd 35 | 36 | 37 | class BagFileReader(Sequence[bytes]): 38 | """Reader for single Bagz files.""" 39 | 40 | def __init__( 41 | self, 42 | filename: str, 43 | *, 44 | separate_limits: bool = False, 45 | decompress: bool | None = None, 46 | ) -> None: 47 | """Creates a BagFileReader. 48 | 49 | Args: 50 | filename: The name of the single Bagz file to read. 51 | separate_limits: Whether the limits are stored in a separate file. 52 | decompress: Whether to decompress the records. If None, uses the file 53 | extension to determine whether to decompress. 
54 | """ 55 | if decompress or (decompress is None and filename.endswith('.bagz')): 56 | self._process = lambda x: zstd.decompress(x) if x else x 57 | else: 58 | self._process = lambda x: x 59 | self._filename = filename 60 | fd = os.open(filename, os.O_RDONLY) 61 | try: 62 | self._records = mmap.mmap(fd, 0, access=mmap.ACCESS_READ) 63 | file_size = self._records.size() 64 | except ValueError: 65 | self._records = b'' 66 | file_size = 0 67 | finally: 68 | os.close(fd) 69 | if separate_limits: 70 | directory, name = os.path.split(filename) 71 | fd = os.open(os.path.join(directory, 'limits.' + name), os.O_RDONLY) 72 | try: 73 | self._limits = mmap.mmap(fd, 0, access=mmap.ACCESS_READ) 74 | index_size = self._limits.size() 75 | except ValueError: 76 | self._limits = b'' 77 | index_size = 0 78 | finally: 79 | os.close(fd) 80 | index_start = 0 81 | else: 82 | if 0 < file_size < 8: 83 | raise ValueError('Bagz file too small') 84 | self._limits = self._records 85 | if file_size: 86 | (index_start,) = struct.unpack('= index_start 90 | index_size = file_size - index_start 91 | assert index_size % 8 == 0 92 | self._num_records = index_size // 8 93 | self._limits_start = index_start 94 | 95 | def __len__(self) -> int: 96 | """Returns the number of records in the Bagz file.""" 97 | return self._num_records 98 | 99 | def __getitem__(self, index: SupportsIndex) -> bytes: 100 | """Returns a record from the Bagz file.""" 101 | i = index.__index__() 102 | if not 0 <= i < self._num_records: 103 | raise IndexError('bagz.BragReader index out of range') 104 | end = i * 8 + self._limits_start 105 | if i: 106 | rec_range = struct.unpack('<2q', self._limits[end - 8 : end + 8]) 107 | else: 108 | rec_range = (0, *struct.unpack(' None: 122 | """Creates a BagShardReader. 123 | 124 | Args: 125 | filename: The name of the sharded Bagz file to read. 126 | separate_limits: Whether the limits are stored in a separate file. 127 | decompress: Whether to decompress the records. 
If None, uses the file 128 | extension to determine whether to decompress. 129 | """ 130 | matches = re.findall(r'@(\d+)', filename) 131 | assert len(matches) == 1 132 | num_files = int(matches[0]) 133 | assert num_files < 100_000 134 | self._bags = tuple( 135 | BagFileReader( 136 | filename=re.sub( 137 | r'@(\d+)', f'-{idx:05d}-of-{num_files:05d}', filename 138 | ), 139 | separate_limits=separate_limits, 140 | decompress=decompress, 141 | ) 142 | for idx in range(num_files) 143 | ) 144 | self._accum = tuple(itertools.accumulate(map(len, self._bags))) 145 | 146 | def __len__(self) -> int: 147 | """Returns the number of records in the Bagz file.""" 148 | return self._accum[-1] 149 | 150 | def __getitem__(self, index: int) -> bytes: 151 | if index < 0: 152 | index += self._accum[-1] 153 | if seqn := bisect.bisect_left(self._accum, index + 1): 154 | index -= self._accum[seqn - 1] 155 | return self._bags[seqn][index] 156 | 157 | 158 | class BagReader(Sequence[bytes]): 159 | """Reader for Bagz files.""" 160 | 161 | def __init__( 162 | self, 163 | filename: str, 164 | *, 165 | separate_limits: bool = False, 166 | decompress: bool | None = None, 167 | ) -> None: 168 | """Creates a BagReader. 169 | 170 | Args: 171 | filename: The name of the Bagz file to read. Supports the @N shard syntax 172 | (where @0 corresponds to the single file case). If the shard syntax does 173 | not parse, then `filename` is treated as a single file. 174 | separate_limits: Whether the limits are stored in a separate file. 175 | decompress: Whether to decompress the records. If None, uses the file 176 | extension to determine whether to decompress. 
177 | """ 178 | if matches := re.findall(r'@(\d+)', filename): 179 | assert len(matches) == 1 180 | if int(matches[0]) != '0': 181 | reader_class = BagShardReader 182 | else: 183 | filename = filename.replace(matches[0], '') 184 | reader_class = BagFileReader 185 | else: 186 | reader_class = BagFileReader 187 | 188 | self._reader = reader_class( 189 | filename=filename, 190 | separate_limits=separate_limits, 191 | decompress=decompress, 192 | ) 193 | 194 | def __len__(self) -> int: 195 | """Returns the number of records in the Bagz file.""" 196 | return len(self._reader) 197 | 198 | def __getitem__(self, index: SupportsIndex) -> bytes: 199 | """Returns a record from the Bagz file.""" 200 | return self._reader[index] 201 | 202 | 203 | class BagWriter: 204 | """Writer for Bagz files.""" 205 | 206 | def __init__( 207 | self, 208 | filename: str, 209 | *, 210 | separate_limits: bool = False, 211 | compress: bool | None = None, 212 | compression_level: int = 0, 213 | ) -> None: 214 | """Creates a BagWriter. 215 | 216 | Args: 217 | filename: The name of the Bagz file to write. 218 | separate_limits: Whether to keep the limits in a separate file. 219 | compress: Whether to compress the records. If None, uses the file 220 | extension to determine whether to compress. 221 | compression_level: The compression level to use when compressing. 222 | """ 223 | if compress or (compress is None and filename.endswith('.bagz')): 224 | self._process = zstd.ZstdCompressor(level=compression_level).compress 225 | else: 226 | self._process = lambda x: x 227 | self._separate_limits = separate_limits 228 | directory, name = os.path.split(filename) 229 | self._records = open(filename, 'wb') 230 | self._limits = open(os.path.join(directory, 'limits.' 
+ name), 'wb+') 231 | 232 | def write(self, data: bytes) -> None: 233 | """Writes a record to the Bagz file.""" 234 | if data: 235 | self._records.write(self._process(data)) 236 | self._limits.write(struct.pack(' None: 239 | """Flushes the Bagz file.""" 240 | self._records.flush() 241 | self._limits.flush() 242 | 243 | def __enter__(self) -> Self: 244 | return self 245 | 246 | def __exit__(self, exc_type, exc_value, traceback) -> None: 247 | """Ensures the Bagz file is closed when exiting a context.""" 248 | self.close() 249 | 250 | def close(self) -> None: 251 | """Concatenates the limits file to the end of the data file.""" 252 | if self._separate_limits: 253 | self._records.close() 254 | self._limits.close() 255 | else: 256 | self._limits.seek(0) 257 | shutil.copyfileobj(self._limits, self._records) 258 | self._records.close() 259 | os.unlink(self._limits.name) 260 | self._limits.close() 261 | 262 | 263 | class BagDataSource: 264 | """PyGrain-compatible data source for bagz files.""" 265 | 266 | def __init__(self, path: epath.PathLike) -> None: 267 | """Creates a new BagDataSource object. 268 | 269 | Args: 270 | path: The path to the bag file. 
class BagDataSource:
  """PyGrain-compatible data source for bagz files."""

  def __init__(self, path: 'epath.PathLike') -> None:
    """Creates a new BagDataSource object.

    Args:
      path: The path to the bag file.
    """
    self._path = os.fspath(path)
    self._reader = BagReader(self._path)
    self._num_records = len(self._reader)

  def __len__(self) -> int:
    return self._num_records

  def __getitem__(self, record_key: SupportsIndex) -> bytes:
    return self._reader[record_key]

  def __getstate__(self) -> dict[str, Any]:
    # The mmap-backed reader is not picklable: drop it and rebuild on load.
    state = self.__dict__.copy()
    del state['_reader']
    return state

  def __setstate__(self, state) -> None:
    self.__dict__.update(state)
    self._reader = BagReader(self._path)

  def __repr__(self) -> str:
    # Fixed: the f-string was missing its closing parenthesis.
    return f'BagDataSource(path={self._path!r})'
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Root directory containing the downloaded expert demonstrations
# (populated by data/download.sh).
_BASE_DIR_PATH = pathlib.Path(
    os.path.join(
        os.getcwd(),
        'data/lm_act/',
    )
)


def _load_demonstrations_and_opening_path(
    rng: np.random.Generator,
    config: config_lib.Experiment,
) -> tuple[list[list[Any]], list[list[Any]], pathlib.Path]:
  """Loads the demonstrations (observations & actions) and the opening path.

  Args:
    rng: The random number generator.
    config: The experiment configuration.

  Returns:
    - The demonstrations episodes, consisting of observations and actions.
    - The opening path, i.e., the path to the opening (which is used to set the
      initial state of the environment for the evaluation episode).

  Raises:
    ValueError: If there are insufficient demonstrations in the directory.
  """
  base_dir_path = _BASE_DIR_PATH / config.environment.name
  demonstration_names = [
      file_name
      for file_name in os.listdir(base_dir_path)
      if file_name.startswith('demonstration')
  ]
  if len(demonstration_names) < config.num_demonstrations + 1:
    raise ValueError(
        f'Insufficient demonstrations in {base_dir_path}: Need at least'
        f' {config.num_demonstrations + 1} but only found'
        f' {len(demonstration_names)}.'
    )

  if config.replay_episode:
    # Replay mode reuses the single demonstration itself as the opening.
    assert config.num_demonstrations == 1
    num_openings = config.num_demonstrations
  else:
    # We need to add 1 to account for the opening that will be evaluated.
    num_openings = config.num_demonstrations + 1
  # Sample without replacement; the last draw becomes the evaluation opening.
  demonstration_names = rng.choice(
      demonstration_names,
      size=num_openings,
      replace=False,
      shuffle=False,
  )
  opening_name = demonstration_names[-1]
  demonstration_names = demonstration_names[: config.num_demonstrations]

  demo_observations = list()
  demo_actions = list()

  # Pick the per-record decoder for the configured observation encoding.
  match config.environment.observation_type:
    case 'rgb':
      rgb_shape = constants.get_rgb_shape(config.environment.name)
      observation_decode_fn = lambda x: np.frombuffer(
          x,
          dtype=np.uint8,
      ).reshape(rgb_shape)
    case 'png':
      # PNG data does not need to be decoded.
      observation_decode_fn = lambda x: x
    case _:
      # Text-like observation types are stored as UTF-8 bytes.
      observation_decode_fn = lambda x: x.decode('utf-8')
  action_decode_fn = lambda x: x.decode('utf-8')

  for demonstration_name in demonstration_names:
    demo_dir_path = base_dir_path / demonstration_name
    observations_path = (
        demo_dir_path
        / f'observations_{config.environment.observation_type}.bag'
    )
    actions_path = (
        demo_dir_path / f'actions_{config.environment.action_type}.bag'
    )
    observations = bagz.BagReader(observations_path.as_posix())
    actions = bagz.BagReader(actions_path.as_posix())
    # Every observation must have a corresponding action.
    assert len(observations) == len(actions)
    demo_observations.append(list(map(observation_decode_fn, observations)))
    demo_actions.append(list(map(action_decode_fn, actions)))

  return demo_observations, demo_actions, base_dir_path / opening_name


def _create_demonstration_prompt(
    config: config_lib.Experiment,
    demo_observations: list[list[Any]],
    demo_actions: list[list[Any]],
) -> tuple[str, dict[str, Any]]:
  """Returns the demonstration prompt and content for the given config."""
  content_by_tag = dict()
  demo_prompts = list()

  for demo_idx, (observations, actions) in enumerate(
zip(demo_observations, demo_actions) 132 | ): 133 | for step_idx, (observation, action) in enumerate( 134 | zip(observations, actions) 135 | ): 136 | match config.environment.observation_type: 137 | case 'fen' | 'coords': 138 | demo_prompt = f'Observation: {observation} ' 139 | case 'dict': 140 | demo_prompt = f'Observation: {observation}\n' 141 | case 'pgn' | 'txt': 142 | demo_prompt = f'Observation:\n{observation}\n' 143 | case 'rgb' | 'png': 144 | tag = f'' 145 | content_by_tag[tag] = observation 146 | demo_prompt = f'Observation: {tag} ' 147 | case _: 148 | raise ValueError( 149 | 'Unsupported observation type:' 150 | f' {config.environment.observation_type}' 151 | ) 152 | 153 | demo_prompt += f'Action: {action}' 154 | demo_prompts.append(demo_prompt) 155 | demo_prompts.append('\n') 156 | demo_prompts.append('\n') 157 | 158 | demonstration_prompt = prompts.build_demonstration_prompt( 159 | demonstrations=''.join(demo_prompts), 160 | ) 161 | logging.info('Demonstration prompt: %s', demonstration_prompt) 162 | return demonstration_prompt, content_by_tag 163 | 164 | 165 | def _create_trajectory_prompt( 166 | config: config_lib.Experiment, 167 | observations: list[Any], 168 | actions: list[Any], 169 | legal_actions: list[str], 170 | ) -> tuple[str, dict[str, Any]]: 171 | """Returns the trajectory prompt and content for the given config.""" 172 | content_by_tag = dict() 173 | trajectory_prompts = list() 174 | 175 | # The first action is a dummy action so we place it at the end of the list. 
176 | actions = np.roll(copy.deepcopy(actions), -1) 177 | 178 | for step_idx, (observation, action) in enumerate(zip(observations, actions)): 179 | match config.environment.observation_type: 180 | case 'fen' | 'coords': 181 | trajectory_prompt = f'Observation: {observation} ' 182 | case 'dict': 183 | trajectory_prompt = f'Observation: {observation}\n' 184 | case 'pgn' | 'txt': 185 | trajectory_prompt = f'Observation:\n{observation}\n' 186 | case 'rgb' | 'png': 187 | tag = f'' 188 | content_by_tag[tag] = observation 189 | trajectory_prompt = f'Observation: {tag} ' 190 | case _: 191 | raise ValueError( 192 | 'Unsupported observation type:' 193 | f' {config.environment.observation_type}' 194 | ) 195 | 196 | if config.prompt.include_past_actions and step_idx < len(actions) - 1: 197 | trajectory_prompt += f'Action: {action}' 198 | 199 | if trajectory_prompt: 200 | trajectory_prompts.append(trajectory_prompt) 201 | trajectory_prompts.append('\n') 202 | 203 | trajectory_prompt = prompts.build_trajectory_prompt( 204 | config=config, 205 | trajectory=''.join(trajectory_prompts), 206 | legal_actions=legal_actions, 207 | ) 208 | logging.info('Current trajectory prompt: %s', trajectory_prompt) 209 | return trajectory_prompt, content_by_tag 210 | 211 | 212 | def evaluate_episode_replay( 213 | episode_idx: int, 214 | config: config_lib.Experiment, 215 | ) -> int: 216 | """Returns the number of correctly replayed actions for a single episode.""" 217 | 218 | # Every episode has to initialize the RNG with a different seed. 
219 | rng = np.random.default_rng(seed=episode_idx) 220 | 221 | logging.info('Setting up the agent: %s.', config.agent.name) 222 | agent = constants.get_agent_builder(config.agent.name)(config=config.agent) 223 | 224 | logging.info('Loading the demonstrations and the evaluation opening name.') 225 | demo_observations, demo_actions, opening_path = ( 226 | _load_demonstrations_and_opening_path(rng=rng, config=config) 227 | ) 228 | assert len(demo_observations) == 1 229 | assert len(demo_actions) == 1 230 | 231 | logging.info('Replaying episode %d (opening %s).', episode_idx, opening_path) 232 | 233 | logging.info('Creating the demonstration chunks.') 234 | demonstration_prompt, demonstration_prompt_data = ( 235 | _create_demonstration_prompt( 236 | config=config, 237 | demo_observations=demo_observations, 238 | demo_actions=demo_actions, 239 | ) 240 | ) 241 | 242 | num_correctly_replayed_actions = 0 243 | 244 | for step, (demo_observation, demo_action) in enumerate( 245 | zip(demo_observations[0], demo_actions[0]) 246 | ): 247 | trajectory_prompt, trajectory_prompt_data = _create_trajectory_prompt( 248 | config=config, 249 | observations=demo_observations[0][: step + 1], 250 | actions=[None] + demo_actions[0][:step], # Dummy initial action. 251 | legal_actions=list(), # We cannot compute the legal actions. 
252 | ) 253 | sample = agent.step( 254 | observation={ 255 | 'prompt': demonstration_prompt + trajectory_prompt, 256 | 'prompt_data': demonstration_prompt_data | trajectory_prompt_data, 257 | }, 258 | environment=None, 259 | rng=rng, 260 | ) 261 | replayed_action_is_correct = sample == demo_action 262 | num_correctly_replayed_actions += replayed_action_is_correct 263 | 264 | logging.info({ 265 | 'demo_observation': demo_observation, 266 | 'demo_action': demo_action, 267 | 'sample': sample, 268 | 'replayed_action_is_correct': replayed_action_is_correct, 269 | }) 270 | 271 | return num_correctly_replayed_actions 272 | 273 | 274 | def evaluate_episode( 275 | episode_idx: int, 276 | config: config_lib.Experiment, 277 | ) -> tuple[float, int, int, int, int]: 278 | """Evaluates a single episode.""" 279 | 280 | # Every episode has to initialize the RNG with a different seed. 281 | rng = np.random.default_rng(seed=episode_idx) 282 | 283 | logging.info('Setting up the agent: %s.', config.agent.name) 284 | agent = constants.get_agent_builder(config.agent.name)(config=config.agent) 285 | 286 | logging.info('Loading the demonstrations and the evaluation opening name.') 287 | demo_observations, demo_actions, opening_path = ( 288 | _load_demonstrations_and_opening_path(rng=rng, config=config) 289 | ) 290 | 291 | logging.info( 292 | 'Evaluating episode %d with opening %s.', episode_idx, opening_path 293 | ) 294 | 295 | logging.info('Creating the demonstration chunks.') 296 | demonstration_prompt, demonstration_prompt_data = ( 297 | _create_demonstration_prompt( 298 | config=config, 299 | demo_observations=demo_observations, 300 | demo_actions=demo_actions, 301 | ) 302 | ) 303 | 304 | logging.info('Setting up the environment: %s.', config.environment.name) 305 | env = constants.get_environment_builder(config.environment.name)( 306 | config=config.environment, 307 | opening_paths=[opening_path], 308 | ) 309 | time_step = env.reset() 310 | 311 | observations = 
[time_step.observation[config.environment.observation_type]] 312 | rewards = [time_step.reward] 313 | actions = [None] # Dummy action for the initial observation. 314 | 315 | num_illegal_actions = num_invalid_actions = num_empty_actions = 0 316 | 317 | for _ in range(config.num_evaluation_steps): 318 | if time_step.last(): 319 | break 320 | 321 | trajectory_prompt, trajectory_prompt_data = _create_trajectory_prompt( 322 | config=config, 323 | observations=observations, 324 | actions=actions, 325 | legal_actions=env.legal_actions, 326 | ) 327 | sample = agent.step( 328 | observation=time_step.observation 329 | | { 330 | 'prompt': demonstration_prompt + trajectory_prompt, 331 | 'prompt_data': demonstration_prompt_data | trajectory_prompt_data, 332 | }, 333 | environment=env, 334 | rng=rng, 335 | ) 336 | 337 | sample_is_empty = not sample 338 | num_empty_actions += sample_is_empty 339 | 340 | if sample_is_invalid := env.action_is_invalid(sample): 341 | num_invalid_actions += 1 342 | # If the sample is invalid, we also always consider it illegal. 
343 | sample_is_illegal = True 344 | num_illegal_actions += 1 345 | elif sample_is_illegal := env.action_is_illegal(sample): 346 | num_illegal_actions += 1 347 | 348 | action = env.sample_legal_action(rng) if sample_is_illegal else sample 349 | 350 | logging.info({ 351 | 'observation': time_step.observation[ 352 | config.environment.observation_type 353 | ], 354 | 'reward': time_step.reward, 355 | 'action': action, 356 | 'sample': sample, 357 | 'sample_is_invalid': sample_is_invalid, 358 | 'sample_is_illegal': sample_is_illegal, 359 | 'sample_is_empty': sample_is_empty, 360 | 'step_type': int(time_step.step_type), 361 | }) 362 | 363 | time_step = env.step(action) 364 | observations.append( 365 | time_step.observation[config.environment.observation_type] 366 | ) 367 | rewards.append(time_step.reward) 368 | actions.append(action) 369 | 370 | logging.info({ 371 | 'rgb': ( 372 | time_step.observation['rgb'] 373 | if 'rgb' in time_step.observation 374 | else None 375 | ), 376 | 'observation': time_step.observation[config.environment.observation_type], 377 | 'reward': time_step.reward, 378 | 'action': None, 379 | 'sample': None, 380 | 'sample_is_invalid': None, 381 | 'sample_is_illegal': None, 382 | 'sample_is_empty': None, 383 | 'step_type': int(time_step.step_type), 384 | }) 385 | 386 | score = sum(rewards[1:]) # Skip the first reward since it is always None. 387 | num_steps = len(rewards) - 1 388 | 389 | return ( 390 | score, 391 | num_steps, 392 | num_invalid_actions, 393 | num_illegal_actions, 394 | num_empty_actions, 395 | ) 396 | --------------------------------------------------------------------------------