├── requirements.txt
├── .gitignore
├── data
└── download.sh
├── CONTRIBUTING.md
├── src
├── agents
│ ├── random.py
│ ├── crossword.py
│ ├── grid_world.py
│ ├── chess.py
│ └── tic_tac_toe.py
├── config.py
├── interfaces.py
├── prompts.py
├── constants.py
├── environments
│ ├── chess.py
│ ├── dm_control.py
│ ├── grid_world.py
│ ├── crossword.py
│ └── tic_tac_toe.py
├── main.py
├── bagz.py
└── evaluate.py
├── README.md
└── LICENSE
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | chess
3 | dm-control
4 | dm-env
5 | etils
6 | imageio
7 | immutabledict
8 | importlib_resources
9 | jax
10 | matplotlib
11 | numpy
12 | tqdm
13 | typing-extensions
14 | zstandard
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/data/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | set -ex
19 |
20 | wget https://storage.googleapis.com/lm_act/LICENSE
21 | wget https://storage.googleapis.com/lm_act/lm_act.zip
22 | unzip lm_act.zip
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/src/agents/random.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Agent that randomly chooses a legal action."""
17 |
18 | import dataclasses
19 | from typing import Any
20 |
21 | import numpy as np
22 |
23 | from lm_act.src import config as config_lib
24 | from lm_act.src import interfaces
25 |
26 |
27 | @dataclasses.dataclass(frozen=True, kw_only=True)
28 | class RandomAgentConfig(config_lib.Agent):
29 | """Configuration for the random agent."""
30 |
31 | name: str = 'random'
32 |
33 |
34 | class RandomAgent(interfaces.Agent):
35 | """Random agent."""
36 |
37 | def __init__(self, config: RandomAgentConfig) -> None:
38 | pass
39 |
40 | def step(
41 | self,
42 | observation: Any,
43 | environment: interfaces.Environment,
44 | rng: np.random.Generator,
45 | ) -> str:
46 | return environment.sample_legal_action(rng)
47 |
--------------------------------------------------------------------------------
/src/agents/crossword.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Oracle agent for the crossword environment."""
17 |
18 | from collections.abc import Mapping
19 | import dataclasses
20 | from typing import Any
21 |
22 | import numpy as np
23 |
24 | from lm_act.src import config as config_lib
25 | from lm_act.src import interfaces
26 |
27 |
28 | @dataclasses.dataclass(frozen=True, kw_only=True)
29 | class OracleAgentConfig(config_lib.Agent):
30 | """Configuration for the oracle agent."""
31 |
32 | name: str = 'crossword_oracle'
33 |
34 |
35 | class OracleAgent(interfaces.Agent):
36 | """Interface for agents."""
37 |
38 | def __init__(self, config: OracleAgentConfig) -> None:
39 | pass
40 |
41 | def step(
42 | self,
43 | observation: Mapping[str, Any],
44 | environment: interfaces.Environment,
45 | rng: np.random.Generator,
46 | ) -> str:
47 | """Returns one of the missing words."""
48 | solutions = observation['puzzle'].split('Solution:\n')[1].split('\n')
49 | solutions = [sol.split(' (')[0] for sol in solutions]
50 | return rng.choice(solutions)
51 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Configuration for the experiment."""
17 |
18 | import dataclasses
19 |
20 |
21 | @dataclasses.dataclass(frozen=True, kw_only=True)
22 | class Agent:
23 | """Configuration for the agent."""
24 |
25 | name: str
26 | action_type: str = 'txt'
27 |
28 |
29 | @dataclasses.dataclass(frozen=True, kw_only=True)
30 | class Environment:
31 | """Configuration for the environment."""
32 |
33 | name: str
34 | observation_type: str
35 | action_type: str = 'txt'
36 |
37 |
38 | @dataclasses.dataclass(frozen=True, kw_only=True)
39 | class Prompt:
40 | """Configuration for the prompt."""
41 |
42 | show_legal_actions: bool
43 | use_chain_of_thought: bool
44 | include_past_actions: bool = True
45 |
46 |
47 | @dataclasses.dataclass(frozen=True, kw_only=True)
48 | class Experiment:
49 | """Configuration for the experiment."""
50 |
51 | num_demonstrations: int = 0
52 | num_evaluation_steps: int = 100
53 | num_evaluation_episodes: int = 100
54 | replay_episode: bool = False
55 |
56 | agent: Agent
57 | environment: Environment
58 | prompt: Prompt
59 |
60 | def __post_init__(self):
61 | if self.agent.action_type != self.environment.action_type:
62 | raise ValueError('The agent and environment action types must match.')
63 |
64 | if self.replay_episode and self.num_demonstrations != 1:
65 | raise ValueError('Replaying an episode requires exactly 1 demonstration.')
66 |
--------------------------------------------------------------------------------
/src/agents/grid_world.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Shortest path agent for grid world."""
17 |
18 | from collections.abc import Mapping
19 | import dataclasses
20 | from typing import Any
21 |
22 | import numpy as np
23 |
24 | from lm_act.src import config as config_lib
25 | from lm_act.src import interfaces
26 |
27 |
28 | @dataclasses.dataclass(frozen=True, kw_only=True)
29 | class ShortestPathAgentConfig(config_lib.Agent):
30 | """Configuration for the minimax agent."""
31 |
32 | name: str = 'grid_world_shortest_path'
33 |
34 |
35 | class ShortestPathAgent(interfaces.Agent):
36 | """Shortest path agent for grid world."""
37 |
38 | def __init__(self, config: ShortestPathAgentConfig) -> None:
39 | pass
40 |
41 | def step(
42 | self,
43 | observation: Mapping[str, Any],
44 | environment: interfaces.Environment,
45 | rng: np.random.Generator,
46 | ) -> str:
47 | """Returns an optimal action for the observation and legal actions."""
48 | player_y, player_x = observation['player']
49 | target_y, target_x = observation['target']
50 |
51 | optimal_actions = list()
52 |
53 | if target_x < player_x:
54 | optimal_actions.append('left')
55 | elif player_x < target_x:
56 | optimal_actions.append('right')
57 | if target_y < player_y:
58 | optimal_actions.append('up')
59 | elif player_y < target_y:
60 | optimal_actions.append('down')
61 |
62 | return rng.choice(optimal_actions)
63 |
--------------------------------------------------------------------------------
/src/agents/chess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Random and Stockfish agents for chess."""
17 |
18 | import dataclasses
19 | import os
20 | from typing import Any
21 |
22 | import chess
23 | import chess.engine
24 | import chess.pgn
25 | import chess.svg
26 | import numpy as np
27 |
28 | from lm_act.src import config as config_lib
29 | from lm_act.src import interfaces
30 | from lm_act.src.environments import chess as chess_env
31 |
32 |
33 | @dataclasses.dataclass(frozen=True, kw_only=True)
34 | class StockfishAgentConfig(config_lib.Agent):
35 | """Configuration for the Stockfish agent."""
36 |
37 | name: str = 'chess_stockfish'
38 | action_type: chess_env.ActionNotation = 'san'
39 | skill_level: int = 20
40 | time_limit: float = 0.05
41 | node_limit: int | None = None
42 |
43 |
44 | class StockfishAgent(interfaces.Agent):
45 | """Stockfish agent."""
46 |
47 | def __init__(
48 | self,
49 | config: StockfishAgentConfig,
50 | ) -> None:
51 | self._skill_level = config.skill_level
52 | self._action_notation = config.action_type
53 | self._limit = chess.engine.Limit(
54 | nodes=config.node_limit,
55 | time=config.time_limit,
56 | )
57 |
58 | def step(
59 | self,
60 | observation: Any,
61 | environment: interfaces.Environment,
62 | rng: np.random.Generator,
63 | ) -> str:
64 | """Returns Stockfish's action for the board in the observation."""
65 | bin_path = os.path.join(
66 | os.getcwd(),
67 | '../Stockfish/src/stockfish',
68 | )
69 | with chess.engine.SimpleEngine.popen_uci(bin_path) as engine:
70 | engine.configure({'Skill Level': self._skill_level})
71 | return chess_env.action_to_string(
72 | action_notation=self._action_notation,
73 | action=engine.play(observation['board'], limit=self._limit).move,
74 | board=observation['board'],
75 | )
76 |
--------------------------------------------------------------------------------
/src/interfaces.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Defines the interfaces for agents and environments."""
17 |
18 | import abc
19 | from collections.abc import Mapping, Sequence
20 | import pathlib
21 | from typing import Any
22 |
23 | import dm_env
24 | import numpy as np
25 |
26 | from lm_act.src import config as config_lib
27 |
28 |
29 | class Environment(dm_env.Environment):
30 | """Interface for environments."""
31 |
32 | @abc.abstractmethod
33 | def __init__(
34 | self,
35 | config: config_lib.Environment,
36 | opening_paths: Sequence[pathlib.Path] | None = None,
37 | ) -> None:
38 | """Initializes the environment."""
39 |
40 | @property
41 | @abc.abstractmethod
42 | def legal_actions(self) -> list[str]:
43 | """Returns the legal actions."""
44 |
45 | def sample_legal_action(self, rng: np.random.Generator) -> str:
46 | """Returns a random legal action."""
47 | # By default, we just return one of the legal actions.
48 | return rng.choice(self.legal_actions)
49 |
50 | def action_is_illegal(self, action: str) -> bool:
51 | """Returns whether the action is illegal."""
52 | # By default, we just check if the action is in the legal actions.
53 | return action not in self.legal_actions
54 |
55 | def action_is_invalid(self, action: str) -> bool:
56 | """Returns whether the action is valid."""
57 | # By default, we just check if the action is a single word.
58 | return len(action.split()) != 1
59 |
60 | def observation_spec(self) -> Any:
61 | return NotImplementedError('Unnecessary for this interface.')
62 |
63 | def action_spec(self) -> Any:
64 | return NotImplementedError('Unnecessary for this interface.')
65 |
66 |
67 | class Agent(abc.ABC):
68 | """Interface for agents."""
69 |
70 | @abc.abstractmethod
71 | def __init__(self, config: config_lib.Agent) -> None:
72 | """Initializes the agent."""
73 |
74 | @abc.abstractmethod
75 | def step(
76 | self,
77 | observation: Mapping[str, Any],
78 | env: Environment,
79 | rng: np.random.Generator,
80 | ) -> str:
81 | """Returns the agent's action for the observation, environment, and rng."""
82 |
--------------------------------------------------------------------------------
/src/prompts.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Builds the prompts for the experiment."""
17 |
18 | from lm_act.src import config as config_lib
19 |
20 |
21 | def build_demonstration_prompt(
22 | demonstrations: str,
23 | ) -> str:
24 | """Returns the prompt for the demonstrations."""
25 | if not demonstrations:
26 | return (
27 | 'You are an intelligent agent operating in a dynamic environment. Based'
28 | ' on the series of observations provided, you need to determine the'
29 | ' optimal action that maximizes the expected reward or achieves the'
30 | ' desired goal. Carefully consider all the given observations, infer'
31 | ' the current state of the environment, and select the most appropriate'
32 | ' action.\n\n'
33 | )
34 |
35 | return (
36 | 'You are a powerful reinforcement learning agent. You can effectively'
37 | ' identify a policy exposed by demonstrations and reproduce it in a new'
38 | ' situation.\n\nHere are a number of'
39 | f' demonstrations:\n\n{demonstrations}\n'
40 | )
41 |
42 |
43 | def build_trajectory_prompt(
44 | trajectory: str,
45 | legal_actions: list[str],
46 | config: config_lib.Experiment,
47 | ) -> str:
48 | """Returns the prompt for the current trajectory."""
49 | prompt = f'\nThis is the current trajectory:\n\n{trajectory}\n'
50 |
51 | if config.prompt.show_legal_actions:
52 | prompt += (
53 | '\nIn this situation, this is the list of all the actions that are'
54 | f' legal:\n\n{", ".join(legal_actions)}\n'
55 | )
56 |
57 | prompt += '\nGiven the '
58 | if 0 < config.num_demonstrations:
59 | prompt += 'demonstrations and the '
60 | prompt += 'current trajectory, you should infer the next logical action.'
61 |
62 | if config.prompt.show_legal_actions:
63 | prompt += '\nCheck that the chosen action is in the set of legal actions.'
64 |
65 | if config.prompt.use_chain_of_thought:
66 | prompt += (
67 | '\nThink step by step and very briefly explain your reasoning for'
68 | ' choosing this action.\nYou must answer with the reasoning followed by'
69 | ' the action in the following format:\nReasoning: ...\nAction: ...'
70 | )
71 | else:
72 | if config.prompt.show_legal_actions:
73 | prompt += (
74 | '\nYou must answer with one of the legal actions only, without any'
75 | ' other text.'
76 | )
77 | else:
78 | prompt += (
79 | '\nYou must answer with the action only, without any other text,'
80 | ' following exactly the same format as the previous actions.'
81 | )
82 |
83 | return prompt.strip()
84 |
--------------------------------------------------------------------------------
/src/agents/tic_tac_toe.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Minimax agent for tic tac toe."""
17 |
18 | import copy
19 | import dataclasses
20 | from typing import Any
21 |
22 | import numpy as np
23 |
24 | from lm_act.src import config as config_lib
25 | from lm_act.src import interfaces
26 | from lm_act.src.environments import tic_tac_toe as tic_tac_toe_env
27 |
28 |
29 | @dataclasses.dataclass(frozen=True, kw_only=True)
30 | class MinimaxAgentConfig(config_lib.Agent):
31 | """Configuration for the minimax agent."""
32 |
33 | name: str = 'tic_tac_toe_minimax'
34 |
35 |
36 | def _minimax(
37 | player_symbol: str, # Constant across the recursion.
38 | board: np.ndarray,
39 | current_symbol: str,
40 | depth: int,
41 | is_max: bool,
42 | ) -> tuple[str | None, int]:
43 | """Returns the best move and its value for the given board and symbol."""
44 | # Check whether there is a winner.
45 | for _, axis in tic_tac_toe_env.AXES:
46 | line = np.take_along_axis(board.flatten(), axis, axis=None)
47 | if (line == player_symbol).all():
48 | return None, 10 - depth
49 | elif (line == ('o' if player_symbol == 'x' else 'x')).all():
50 | return None, -10 + depth
51 |
52 | # Check whether there is a draw.
53 | if ' ' not in board:
54 | return None, 0
55 |
56 | # Search all the legal moves.
57 | best_value = -100 if is_max else 100
58 | best_move = None
59 |
60 | for move in tic_tac_toe_env.legal_actions(board):
61 | row = tic_tac_toe_env.ROWS.index(move[0])
62 | col = tic_tac_toe_env.COLS.index(move[1])
63 | new_board = copy.deepcopy(board)
64 | new_board[row, col] = current_symbol
65 | _, value = _minimax(
66 | player_symbol=player_symbol,
67 | board=new_board,
68 | current_symbol='o' if current_symbol == 'x' else 'x',
69 | is_max=not is_max,
70 | depth=depth + 1,
71 | )
72 |
73 | if (is_max and best_value < value) or (not is_max and value < best_value):
74 | best_value = value
75 | best_move = move
76 |
77 | return best_move, best_value
78 |
79 |
80 | class MinimaxAgent(interfaces.Agent):
81 | """Minimax agent for tic tac toe."""
82 |
83 | def __init__(self, config: MinimaxAgentConfig):
84 | pass
85 |
86 | def step(
87 | self,
88 | observation: Any,
89 | environment: interfaces.Environment,
90 | rng: np.random.Generator,
91 | ) -> str:
92 | """Returns the best move for the given board."""
93 | action, _ = _minimax(
94 | player_symbol=observation['symbol'],
95 | board=observation['board'],
96 | current_symbol=observation['symbol'],
97 | depth=0,
98 | is_max=True,
99 | )
100 | if action is None:
101 | raise ValueError('No optimal action found.')
102 | return action
103 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Constants used in the LMAct project."""
17 |
18 | from typing import TypeVar
19 |
20 | from lm_act.src import interfaces
21 | from lm_act.src.agents import chess as chess_agents
22 | from lm_act.src.agents import crossword as crossword_agents
23 | from lm_act.src.agents import grid_world as grid_world_agents
24 | from lm_act.src.agents import random as random_agents
25 | from lm_act.src.agents import tic_tac_toe as tic_tac_toe_agents
26 | from lm_act.src.environments import chess as chess_env
27 | from lm_act.src.environments import crossword as crossword_env
28 | from lm_act.src.environments import dm_control as dm_control_env
29 | from lm_act.src.environments import grid_world as grid_world_env
30 | from lm_act.src.environments import tic_tac_toe as tic_tac_toe_env
31 |
32 |
33 | def get_rgb_shape(environment_name: str) -> tuple[int, int, int]:
34 | if environment_name.startswith('atari'):
35 | return (210, 160, 3)
36 | elif environment_name == 'chess':
37 | return (256, 256, 3)
38 | elif environment_name.startswith('tic_tac_toe'):
39 | return (256, 256, 3)
40 | elif environment_name.startswith('grid_world'):
41 | return (192, 192, 3)
42 | else:
43 | raise ValueError(f'Unknown environment name: {environment_name}.')
44 |
45 |
46 | # The interfaces are abstract, so we need to use TypeVar to make them generic.
47 | Agent = TypeVar('Agent', bound=interfaces.Agent)
48 | Environment = TypeVar('Environment', bound=interfaces.Environment)
49 |
50 |
51 | def get_agent_builder(agent_name: str) -> type[Agent]:
52 | match agent_name:
53 | case 'random':
54 | return random_agents.RandomAgent
55 | case 'chess_stockfish':
56 | return chess_agents.StockfishAgent
57 | case 'crossword_oracle':
58 | return crossword_agents.OracleAgent
59 | case 'grid_world_shortest_path':
60 | return grid_world_agents.ShortestPathAgent
61 | case 'tic_tac_toe_minimax':
62 | return tic_tac_toe_agents.MinimaxAgent
63 | case _:
64 | raise ValueError(f'Unknown agent name: {agent_name}.')
65 |
66 |
67 | def get_environment_builder(environment_name: str) -> type[Environment]:
68 | """Returns the environment builder for the given environment name."""
69 | if environment_name.startswith('atari'):
70 | raise NotImplementedError('atari environments are not yet supported.')
71 | elif environment_name == 'chess':
72 | return chess_env.Chess
73 | elif environment_name == 'crossword':
74 | return crossword_env.Crossword
75 | elif environment_name.startswith('dm_control'):
76 | return dm_control_env.DMControl
77 | elif environment_name.startswith('grid_world'):
78 | return grid_world_env.GridWorld
79 | elif environment_name.startswith('tic_tac_toe'):
80 | return tic_tac_toe_env.TicTacToe
81 | else:
82 | raise ValueError(f'Unknown environment name: {environment_name}.')
83 |
--------------------------------------------------------------------------------
/src/environments/chess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Chess environment."""
17 |
18 | import copy
19 | import dataclasses
20 | import io
21 | import os
22 | import pathlib
23 | from typing import Literal
24 |
import chess
import chess.engine  # Needed for chess.engine.Limit / SimpleEngine below.
import chess.pgn
import dm_env
from numpy import random
29 |
30 | from lm_act.src import bagz
31 | from lm_act.src import config as config_lib
32 | from lm_act.src import interfaces
33 |
34 |
35 | ActionNotation = Literal['uci', 'san']
36 |
37 |
38 | def action_to_string(
39 | action: chess.Move,
40 | board: chess.Board,
41 | action_notation: ActionNotation,
42 | ) -> str:
43 | """Returns the string representation of a chess action."""
44 | match action_notation:
45 | case 'uci':
46 | return str(action)
47 | case 'san':
48 | return board.san(action)
49 |
50 |
51 | @dataclasses.dataclass(frozen=True, kw_only=True)
52 | class EnvironmentConfig(config_lib.Environment):
53 | """Configuration for the environment."""
54 |
55 | name: str = 'chess'
56 | observation_type: Literal['fen', 'pgn', 'txt', 'png'] = 'fen'
57 | action_type: ActionNotation = 'san'
58 |
59 | stockfish_node_limit: int = 1
60 | stockfish_time_limit: float = 0.001
61 | stockfish_skill_level: int = 0
62 |
63 |
64 | class Chess(interfaces.Environment):
65 | """A simple chess environment to play against stockfish."""
66 |
67 | def __init__(
68 | self,
69 | config: EnvironmentConfig,
70 | opening_paths: list[pathlib.Path] | None = None,
71 | openings: list[chess.pgn.Game] | None = None,
72 | ) -> None:
73 | self._stockfish_limit = chess.engine.Limit(
74 | nodes=config.stockfish_node_limit,
75 | time=config.stockfish_time_limit,
76 | )
77 | self._stockfish_skill_level = config.stockfish_skill_level
78 | self._action_notation = config.action_type
79 |
80 | if openings is not None:
81 | self._openings = openings
82 | elif opening_paths is not None:
83 | self._openings = list()
84 | for opening_path in opening_paths:
85 | pgns = bagz.BagReader(
86 | (opening_path / 'observations_pgn.bag').as_posix()
87 | )
88 | pgn = io.StringIO(pgns[0].decode('utf-8'))
89 | if (game := chess.pgn.read_game(pgn)) is not None:
90 | self._openings.append(game)
91 | else:
92 | raise ValueError('Either `openings` or `opening_paths` must be provided.')
93 |
94 | self._game: chess.pgn.Game = None
95 | self._node: chess.pgn.GameNode = None
96 | self._player_is_white: bool = None
97 |
98 | def reset(self) -> dm_env.TimeStep:
99 | self._game = self._openings.pop(0)
100 | self._node = self._game.end()
101 | self._player_is_white = self._board.turn == chess.WHITE
102 | return dm_env.restart(observation=self._observation)
103 |
104 | @property
105 | def _board(self):
106 | return self._node.board()
107 |
108 | @property
109 | def _observation(self):
110 | return {
111 | 'board': copy.deepcopy(self._board),
112 | 'fen': self._board.fen(en_passant='fen'),
113 | 'pgn': str(self._game),
114 | 'txt': str(self._board),
115 | }
116 |
  def _turn(
      self,
      action: str,
      notation: ActionNotation,
  ) -> None | dm_env.TimeStep:
    """Applies one move to the game tree and checks for termination.

    Args:
      action: The move, encoded according to `notation`.
      notation: Either 'uci' or 'san'.

    Returns:
      A terminal time step if the game is over — reward is +1/-1/0 from the
      player's perspective for a win/loss/draw — or None if play continues.

    Raises:
      ValueError: If the game ends with an outcome winner other than
        chess.WHITE, chess.BLACK, or None.
    """
    match notation:
      case 'uci':
        action = chess.Move.from_uci(action)
      case 'san':
        # `_board` rebuilds the position from the game node on each access
        # (see `_board`), so `push_san` here serves to parse the SAN string
        # into a Move; the push lands on that throwaway board, not on the
        # game tree itself.
        action = self._board.push_san(action)
    self._node = self._node.add_main_variation(action)

    if (
        claim_draw := self._board.can_claim_draw()
    ) or self._board.is_game_over():
      # `claim_draw` mirrors whether the draw-claim check above succeeded.
      match (outcome := self._board.outcome(claim_draw=claim_draw)).winner:  # pytype: disable=attribute-error
        case chess.WHITE:
          reward = 1 if self._player_is_white else -1
        case chess.BLACK:
          reward = -1 if self._player_is_white else 1
        case None:
          # Draw.
          reward = 0
        case _:
          raise ValueError(f'Unknown outcome: {outcome}')
      return dm_env.termination(observation=self._observation, reward=reward)

    return None
144 |
  def step(self, action: str) -> dm_env.TimeStep:
    """Plays the agent's move, then Stockfish's reply.

    Args:
      action: The agent's move, encoded per the configured action notation.

    Returns:
      A terminal time step if either move ends the game, otherwise a
      transition with reward 0.
    """
    outcome = self._turn(action=action, notation=self._action_notation)
    if outcome is not None:
      return outcome

    # Expects a compiled Stockfish binary in a sibling checkout of this repo.
    bin_path = os.path.join(
        os.getcwd(),
        '../Stockfish/src/stockfish',
    )
    # NOTE(review): a fresh engine process is spawned and torn down on every
    # step; caching the engine on the instance would avoid the per-move
    # startup cost — confirm whether per-step isolation is intentional.
    with chess.engine.SimpleEngine.popen_uci(bin_path) as engine:
      engine.configure({'Skill Level': self._stockfish_skill_level})
      action = str(engine.play(self._board, limit=self._stockfish_limit).move)

    # Stockfish's reply is always applied in UCI notation.
    if (outcome := self._turn(action=action, notation='uci')) is not None:
      return outcome

    return dm_env.transition(observation=self._observation, reward=0)
162 |
163 | @property
164 | def legal_actions(self) -> list[str]:
165 | return sorted(
166 | action_to_string(
167 | action=action,
168 | board=self._board,
169 | action_notation=self._action_notation,
170 | )
171 | for action in self._board.legal_moves
172 | )
173 |
174 | def sample_legal_action(self, rng: random.Generator) -> str:
175 | return action_to_string(
176 | action=rng.choice(list(self._board.legal_moves)),
177 | board=self._board,
178 | action_notation=self._action_notation,
179 | )
180 |
--------------------------------------------------------------------------------
/src/environments/dm_control.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """DM Control environment."""
17 |
18 | import ast
19 | from collections.abc import Mapping
20 | import dataclasses
21 | import pathlib
22 | from typing import Any, Literal
23 |
24 | from dm_control import suite
25 | import dm_env
26 | import numpy as np
27 | from numpy import random
28 |
29 | from lm_act.src import bagz
30 | from lm_act.src import config as config_lib
31 | from lm_act.src import interfaces
32 |
33 |
@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  name: str = 'dm_control'  # Rewritten in __post_init__ to include domain/task.
  domain: str = 'cheetah'
  task: str = 'run'
  observation_type: Literal['dict'] = 'dict'
  seed: int = 0  # Seed for the dm_control task's random state.

  def __post_init__(self):
    # The dataclass is frozen, so bypass the immutability check to derive the
    # full environment name (e.g. 'dm_control_cheetah_run').
    object.__setattr__(self, 'name', f'{self.name}_{self.domain}_{self.task}')
46 |
47 |
class DMControl(interfaces.Environment):
  """DM Control environment.

  Wraps a `dm_control` suite task and replays recorded opening states so each
  episode starts from a state drawn from demonstration data. Actions are
  exchanged as text: a bracketed, comma-separated list of floats.
  """

  def __init__(
      self,
      config: EnvironmentConfig,
      opening_paths: list[pathlib.Path] | None = None,
      openings: list[Mapping[str, Any]] | None = None,
  ) -> None:
    """Initializes the environment.

    Args:
      config: The environment configuration (domain, task, seed).
      opening_paths: Paths to directories containing `observations_dict.bag`
        files; the first record of each bag is parsed as an opening state.
      openings: Pre-parsed opening states; takes precedence over
        `opening_paths` when both are provided.

    Raises:
      ValueError: If neither `openings` nor `opening_paths` is provided.
    """
    if openings is not None:
      self._openings = openings
    elif opening_paths is not None:
      self._openings = list()
      for opening_path in opening_paths:
        observations = bagz.BagReader(
            (opening_path / 'observations_dict.bag').as_posix()
        )
        # Each record holds the textual repr of a dict of Python literals, so
        # it can be recovered safely with `ast.literal_eval` (no eval).
        self._openings.append(ast.literal_eval(observations[0].decode('utf-8')))
    else:
      raise ValueError('Either `openings` or `opening_paths` must be provided.')

    self._config = config
    self._env = suite.load(
        domain_name=config.domain,
        task_name=config.task,
        task_kwargs=dict(random=config.seed),
    )
    self._action_spec = self._env.action_spec()

  @property
  def domain(self) -> str:
    """The dm_control domain name (e.g. 'cheetah')."""
    return self._config.domain

  @property
  def task(self) -> str:
    """The dm_control task name (e.g. 'run')."""
    return self._config.task

  def _prepare_observation(self, time_step: dm_env.TimeStep) -> dm_env.TimeStep:
    """Adds a textual 'dict' rendering of the observation, mutating in place."""
    time_step.observation['dict'] = str(
        {k: v.tolist() for k, v in time_step.observation.items()}
    )
    return time_step

  def _set_init_state(self, observation: Mapping[str, Any]) -> dm_env.TimeStep:
    """Returns the time step after setting the initial environment state."""
    time_step = self._env.reset()

    with self._env.physics.reset_context():
      match self._config.domain:
        case 'cheetah' | 'hopper':
          # qpos[0] is excluded from `observation['position']` (note the
          # `[1:]` below), so it is pinned to 0 — presumably the root
          # position; TODO confirm against the task's observation spec.
          self._env.physics.data.qpos[0] = 0
          self._env.physics.data.qpos[1:] = observation['position']
          self._env.physics.data.qvel[:] = observation['velocity']
        case 'point_mass':
          self._env.physics.data.qpos[:] = observation['position']
          self._env.physics.data.qvel[:] = observation['velocity']
        case _:
          raise ValueError(f'Unknown domain: {self._config.domain}')

    self._env.physics.after_reset()
    env_observation = self._env.task.get_observation(self._env.physics)

    # Check that the observation from the environment matches `observation`.
    for key, value in observation.items():
      np.testing.assert_equal(env_observation[key], value)

    time_step = dm_env.TimeStep(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=env_observation,
    )
    return self._prepare_observation(time_step=time_step)

  def reset(self) -> dm_env.TimeStep:
    """Resets the environment to the next queued opening state."""
    observation = self._openings.pop(0)
    return self._set_init_state(observation)

  def step(self, action: str) -> dm_env.TimeStep:
    """Steps the environment with a '[v1, v2, ...]'-formatted action."""
    actions = self._extract_actions(action)
    return self._prepare_observation(self._env.step(actions))

  @property
  def legal_actions(self) -> list[str]:
    """Returns the legal actions.

    The action space is continuous, so this is a single human-readable
    description of the valid range rather than an enumeration.
    """
    return [
        'A comma-separated list (enclosed by square brackets) of'
        f' {self._action_spec.shape[0]} values between'
        f' {self._action_spec.minimum.tolist()} and'
        f' {self._action_spec.maximum.tolist()}.'
    ]

  def sample_legal_action(self, rng: random.Generator) -> str:
    """Returns a uniformly sampled in-bounds action as a bracketed list."""
    min_values = self._action_spec.minimum
    max_values = self._action_spec.maximum
    return str(
        (
            (max_values - min_values) * rng.random(len(min_values)) + min_values
        ).tolist()
    )

  def action_is_illegal(self, action: str) -> bool:
    """Returns whether the action is illegal."""
    # For DM Control, we treat invalid and illegal actions as the same.
    return self.action_is_invalid(action)

  def action_is_invalid(self, action: str) -> bool:
    """Returns whether the action is invalid (unparseable or out of spec)."""
    try:
      self._extract_actions(action)
      return False
    except ValueError:
      return True

  def _extract_actions(self, action: str) -> np.ndarray:
    """Parses a '[v1, v2, ...]' string into a float64 array.

    Raises:
      ValueError: If the brackets are missing, a value is not a float, or the
        values do not satisfy the action spec.
    """
    if not action.startswith('[') or not action.endswith(']'):
      raise ValueError(f'Invalid action: {action}')
    action = action[1:-1]
    values = np.fromiter(map(float, action.split(',')), dtype=np.float64)
    # Spec validation rejects wrong shape/dtype/bounds.
    self._action_spec.validate(values)
    return values
171 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluates an agent on the LMAct benchmark."""
17 |
18 | from collections.abc import Sequence
19 | import logging
20 |
21 | from absl import app
22 | from absl import flags
23 | import immutabledict
24 | import numpy as np
25 | import tqdm
26 |
27 | from lm_act.src import config as config_lib
28 | from lm_act.src import evaluate
29 | from lm_act.src.agents import chess as chess_agent
30 | from lm_act.src.agents import crossword as crossword_agent
31 | from lm_act.src.agents import grid_world as grid_world_agent
32 | from lm_act.src.agents import random as random_agent
33 | from lm_act.src.agents import tic_tac_toe as tic_tac_toe_agent
34 | from lm_act.src.environments import chess
35 | from lm_act.src.environments import crossword
36 | from lm_act.src.environments import dm_control
37 | from lm_act.src.environments import grid_world
38 | from lm_act.src.environments import tic_tac_toe
39 |
40 |
# Command-line flags selecting the task, the observation/action
# representations, the agent, and the evaluation budget.
_ENVIRONMENT = flags.DEFINE_enum(
    name='environment',
    default='tic_tac_toe',
    enum_values=[
        'chess',
        'crossword',
        'dm_control',
        'grid_world',
        'tic_tac_toe',
    ],
    help='The environment to evaluate.',
)
_OBSERVATION_TYPE = flags.DEFINE_enum(
    name='observation_type',
    default='txt',
    enum_values=['coords', 'dict', 'fen', 'pgn', 'png', 'rgb', 'txt'],
    help='The observation representation to evaluate.',
)
_ACTION_TYPE = flags.DEFINE_enum(
    name='action_type',
    default='txt',
    enum_values=['txt', 'san'],
    help='The action representation to evaluate.',
)
_AGENT = flags.DEFINE_enum(
    name='agent',
    default='random',
    enum_values=[
        'random',
        'chess_stockfish',
        'crossword_oracle',
        'grid_world_shortest_path',
        'tic_tac_toe_minimax',
    ],
    help='The agent to evaluate.',
)
_NUM_DEMONSTRATIONS = flags.DEFINE_integer(
    name='num_demonstrations',
    default=0,
    help='The number of demonstrations to use.',
)
# NOTE(review): 'EVALUTION' is a typo for 'EVALUATION'; renaming the constant
# would need to touch every use in this module, so it is only flagged here.
_NUM_EVALUTION_EPISODES = flags.DEFINE_integer(
    name='num_evaluation_episodes',
    default=100,
    help='The number of episodes to evaluate.',
)
_NUM_EVALUATION_STEPS = flags.DEFINE_integer(
    name='num_evaluation_steps',
    default=100,
    help='The number of steps to evaluate.',
)

# Maps --environment / --agent flag values to their config constructors.
_CONFIG_BY_ENVIRONMENT = immutabledict.immutabledict({
    'chess': chess.EnvironmentConfig,
    'crossword': crossword.EnvironmentConfig,
    'dm_control': dm_control.EnvironmentConfig,
    'grid_world': grid_world.EnvironmentConfig,
    'tic_tac_toe': tic_tac_toe.EnvironmentConfig,
})
_CONFIG_BY_AGENT = immutabledict.immutabledict({
    'random': random_agent.RandomAgentConfig,
    'chess_stockfish': chess_agent.StockfishAgentConfig,
    'crossword_oracle': crossword_agent.OracleAgentConfig,
    'grid_world_shortest_path': grid_world_agent.ShortestPathAgentConfig,
    'tic_tac_toe_minimax': tic_tac_toe_agent.MinimaxAgentConfig,
})
107 |
108 |
def main(argv: Sequence[str]) -> None:
  """Runs the evaluation episodes and prints aggregate statistics."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.getLogger().setLevel(logging.WARNING)

  print(f'Environment: {_ENVIRONMENT.value}')
  print(f'Observation type: {_OBSERVATION_TYPE.value}')
  print(f'Agent: {_AGENT.value}')
  print(f'Num evaluation episodes: {_NUM_EVALUTION_EPISODES.value}')

  # One list of per-episode values for each reported metric, in the order
  # returned by `evaluate.evaluate_episode`.
  metric_names = (
      'score',
      'num_steps',
      'num_invalid_actions',
      'num_illegal_actions',
      'num_empty_actions',
  )
  metrics = {name: [] for name in metric_names}

  for episode in tqdm.trange(_NUM_EVALUTION_EPISODES.value):
    results = evaluate.evaluate_episode(
        episode_idx=episode,
        config=config_lib.Experiment(
            num_demonstrations=_NUM_DEMONSTRATIONS.value,
            num_evaluation_steps=_NUM_EVALUATION_STEPS.value,
            agent=_CONFIG_BY_AGENT[_AGENT.value](
                action_type=_ACTION_TYPE.value,
            ),
            environment=_CONFIG_BY_ENVIRONMENT[_ENVIRONMENT.value](
                observation_type=_OBSERVATION_TYPE.value,
                action_type=_ACTION_TYPE.value,
            ),
            prompt=config_lib.Prompt(
                show_legal_actions=None,
                use_chain_of_thought=None,
            ),
        ),
    )

    for name, value in zip(metric_names, results):
      metrics[name].append(value)

    logging.info({'episode': episode, **dict(zip(metric_names, results))})

  print(f'Average score: {np.mean(metrics["score"]):.2f}')
  print(f'Average num steps: {np.mean(metrics["num_steps"]):.2f}')
  print(f'Average num invalid actions: {np.mean(metrics["num_invalid_actions"]):.2f}')
  print(f'Average num illegal actions: {np.mean(metrics["num_illegal_actions"]):.2f}')
  print(f'Average num empty actions: {np.mean(metrics["num_empty_actions"]):.2f}')
172 |
173 |
if __name__ == '__main__':
  # `app.run` parses the absl flags before dispatching to `main`.
  app.run(main)
176 |
--------------------------------------------------------------------------------
/src/environments/grid_world.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Grid World environment."""
17 |
18 | from collections.abc import Mapping
19 | import copy
20 | import dataclasses
21 | import os
22 | import pathlib
23 | from typing import Literal
24 |
25 | import dm_env
26 | import imageio
27 | import jax
28 | import numpy as np
29 |
30 | from lm_act.src import bagz
31 | from lm_act.src import config as config_lib
32 | from lm_act.src import interfaces
33 |
34 |
# Path to the Crafter sprite assets, resolved relative to the current working
# directory — assumes the `crafter` repository is checked out next to `lm_act`
# (see the README's installation instructions).
_ASSETS_PATH = pathlib.Path(
    os.path.join(os.getcwd(), '../crafter/crafter/assets')
)
38 |
39 |
@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  name: str = 'grid_world'  # Rewritten in __post_init__ to include the size.
  observation_type: Literal['rgb', 'txt', 'coords'] = 'rgb'
  height: int = 12  # Grid height in cells, including the border walls.
  width: int = 12  # Grid width in cells, including the border walls.

  def __post_init__(self):
    # The dataclass is frozen, so bypass the immutability check to derive the
    # full environment name (e.g. 'grid_world_12x12').
    object.__setattr__(self, 'name', f'grid_world_{self.height}x{self.width}')
51 |
52 |
53 | class GridWorld(interfaces.Environment):
54 | """2D grid world environment with a player and a target."""
55 |
56 | def __init__(
57 | self,
58 | config: EnvironmentConfig,
59 | opening_paths: list[pathlib.Path] | None = None,
60 | openings: list[tuple[np.ndarray, np.ndarray]] | None = None,
61 | ) -> None:
62 | if openings is not None:
63 | self._openings = openings
64 | elif opening_paths is not None:
65 | self._openings = list()
66 | for opening_path in opening_paths:
67 | player_coordinates = bagz.BagReader(
68 | (opening_path / 'observations_player.bag').as_posix()
69 | )
70 | target_coordinates = bagz.BagReader(
71 | (opening_path / 'observations_target.bag').as_posix()
72 | )
73 | player = np.frombuffer(player_coordinates[0], dtype=np.int64)
74 | target = np.frombuffer(target_coordinates[0], dtype=np.int64)
75 | self._openings.append((player, target))
76 | else:
77 | raise ValueError('Either `openings` or `opening_paths` must be provided.')
78 |
79 | self._width = config.width
80 | self._height = config.height
81 |
82 | self._player: np.ndarray = None
83 | self._target: np.ndarray = None
84 |
85 | self._walls = np.full((self._height, self._width), False, dtype=bool)
86 | self._walls[0, :] = True
87 | self._walls[-1, :] = True
88 | self._walls[:, 0] = True
89 | self._walls[:, -1] = True
90 |
91 | with open(_ASSETS_PATH / 'food.png', 'r') as f:
92 | target_sprite = imageio.imread(f)[:, :, :-1]
93 | with open(_ASSETS_PATH / 'player.png', 'r') as f:
94 | player_sprite = imageio.imread(f)[:, :, :-1]
95 | with open(_ASSETS_PATH / 'stone.png', 'r') as f:
96 | wall_sprite = imageio.imread(f)
97 |
98 | sprite_matrix = np.array([wall_sprite, target_sprite, player_sprite])
99 | self._sprite_matrix = np.transpose(sprite_matrix[:, ::-1], [1, 2, 0, 3])
100 |
101 | self._rgb = jax.jit(self._rgb)
102 |
103 | def reset(self) -> dm_env.TimeStep:
104 | self._player, self._target = self._openings.pop(0)
105 | return dm_env.restart(observation=self._observation)
106 |
107 | def _rgb(self, state: np.ndarray) -> jax.Array:
108 | return jax.lax.conv_transpose(
109 | state[None], # NHWC
110 | self._sprite_matrix, # HWIO
111 | (16, 16),
112 | 'SAME',
113 | )[0]
114 |
115 | @property
116 | def _text(self) -> str:
117 | scene = list()
118 |
119 | for row in range(self._height):
120 | row_str = '|'
121 | for col in range(self._width):
122 | if self._walls[row, col]:
123 | tile = 'wall'
124 | elif self._player[0] == row and self._player[1] == col:
125 | tile = 'player'
126 | elif self._target[0] == row and self._target[1] == col:
127 | tile = 'target'
128 | else:
129 | tile = 'tile'
130 | row_str += tile.center(8) + '|'
131 | scene.append(row_str)
132 | scene.append('-' * len(row_str))
133 |
134 | scene = ['-' * len(scene[0])] + scene
135 |
136 | return '\n'.join(scene)
137 |
138 | @property
139 | def _observation(self) -> Mapping[str, np.ndarray | str]:
140 | player_state = np.zeros((self._height, self._width), dtype=np.bool_)
141 | player_state[self._player[0], self._player[1]] = True
142 | target_state = np.zeros((self._height, self._width), dtype=np.bool_)
143 | target_state[self._target[0], self._target[1]] = True
144 | state = np.stack(
145 | [self._walls, target_state, player_state],
146 | axis=-1,
147 | dtype=np.uint8,
148 | )
149 | return {
150 | 'player': copy.deepcopy(self._player),
151 | 'target': copy.deepcopy(self._target),
152 | 'rgb': copy.deepcopy(np.array(self._rgb(state), dtype=np.uint8)),
153 | 'txt': self._text,
154 | 'coords': str(
155 | dict(player=self._player.tolist(), target=self._target.tolist())
156 | ),
157 | }
158 |
159 | def step(self, action: str) -> dm_env.TimeStep:
160 | next_y, next_x = self._player
161 |
162 | match action:
163 | case 'left':
164 | next_x -= 1
165 | case 'right':
166 | next_x += 1
167 | case 'up':
168 | next_y -= 1
169 | case 'down':
170 | next_y += 1
171 | case _:
172 | raise ValueError(f'Unsupported action: {action}')
173 |
174 | if not self._walls[next_y, next_x]:
175 | self._player = np.array([next_y, next_x])
176 | if np.all(self._player == self._target):
177 | return dm_env.termination(observation=self._observation, reward=1)
178 |
179 | return dm_env.transition(observation=self._observation, reward=0)
180 |
181 | @property
182 | def legal_actions(self) -> list[str]:
183 | """Returns the legal actions."""
184 | return ['left', 'right', 'up', 'down']
185 |
--------------------------------------------------------------------------------
/src/environments/crossword.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Crossword environment."""
17 |
18 | import ast
19 | import collections
20 | from collections.abc import Mapping
21 | import dataclasses
22 | import logging
23 | import os
24 | import pathlib
25 | import re
26 | from typing import Literal
27 |
28 | import dm_env
29 |
30 | from lm_act.src import bagz
31 | from lm_act.src import config as config_lib
32 | from lm_act.src import interfaces
33 |
34 |
# Root of the downloaded benchmark data, resolved relative to the current
# working directory (populated by `data/download.sh`).
_BASE_DIR_PATH = pathlib.Path(
    os.path.join(
        os.getcwd(),
        'data/lm_act/',
    )
)
41 |
42 |
@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  name: str = 'crossword'
  # Crosswords only support the textual observation representation.
  observation_type: Literal['txt'] = 'txt'
49 |
50 |
class Crossword(interfaces.Environment):
  """A simple crossword environment.

  Observations are the (partially filled) ASCII grid plus the clues; actions
  are guesses of the form 'A1: WORD' (across) or 'D1: WORD' (down).
  """

  def __init__(
      self,
      config: EnvironmentConfig,
      opening_paths: list[pathlib.Path] | None = None,
      openings: list[str] | None = None,
  ) -> None:
    """Initializes the environment.

    Args:
      config: The environment configuration.
      opening_paths: Paths to directories containing `observations_puzzle.bag`
        files; the first record of each bag is used as a puzzle.
      openings: Pre-loaded puzzle strings; takes precedence over
        `opening_paths` when both are provided.

    Raises:
      ValueError: If neither `openings` nor `opening_paths` is provided.
    """
    if openings is not None:
      self._openings = openings
    elif opening_paths is not None:
      self._openings = list()
      for opening_path in opening_paths:
        puzzles = bagz.BagReader(
            (opening_path / 'observations_puzzle.bag').as_posix()
        )
        self._openings.append(puzzles[0].decode('utf-8'))
    else:
      raise ValueError('Either `openings` or `opening_paths` must be provided.')

    # Per-episode state, populated by `reset`.
    self._puzzle: str = None
    self._grid: str = None
    self._clues: str = None
    self._coords_by_solution: Mapping[str, tuple[int, int]] = None
    self._unfound_solutions: list[str] = None

    # Word list used to enumerate `legal_actions`, bucketed by word length.
    self._words_by_length = collections.defaultdict(set)

    with open(_BASE_DIR_PATH / 'crossword' / 'words.txt', 'r') as f:
      for word in f:
        word = word.strip()
        self._words_by_length[len(word)].add(word)

  def reset(self) -> dm_env.TimeStep:
    """Loads the next puzzle and parses its solutions."""
    self._puzzle = self._openings.pop(0)
    # A puzzle consists of grid, clues, and solutions, separated by blank
    # lines.
    self._grid, self._clues, solutions, *_ = self._puzzle.split('\n\n')

    if (grid_len := len(self._grid)) != 554:
      raise ValueError(f'Invalid grid size, should be 554 but is {grid_len}.')

    self._coords_by_solution = dict()
    for solution in solutions.split('Solution:\n')[1].split('\n'):
      # The solution has the format "A1: word (row, col)", so we need some regex
      # magic to separate the word and the coordinates.
      if (match := re.match(r'(.*?)\s*(\(\d+,\s*\d+\))', solution)) is not None:
        sol, coords = match.groups()
        self._coords_by_solution[sol] = ast.literal_eval(coords)

    self._unfound_solutions = list(self._coords_by_solution.keys())

    return dm_env.restart(observation=self._observation)

  @property
  def _observation(self):
    """The grid and clues; 'puzzle' also lists the remaining solutions."""
    return {
        'txt': self._grid + '\n\n' + self._clues,
        'puzzle': (
            self._grid
            + '\n\n'
            + self._clues
            + '\n\nSolution:\n'
            + '\n'.join(
                f'{sol} {self._coords_by_solution[sol]}'
                for sol in self._unfound_solutions
            )
        ),
    }

  def _update_grid(self, solution: str) -> None:
    """Writes a solved word into the ASCII grid string in place."""
    # Down clues start with 'D', across clues with 'A'.
    is_vertical = solution[0] == 'D'
    row, col = self._coords_by_solution[solution]

    # Account for the offsets.
    # NOTE(review): assumes a fixed rendering of 2 text rows per grid row and
    # 5 characters per grid column with the first cell at (1, 3), consistent
    # with the 554-character grid enforced in `reset` — confirm against the
    # puzzle format.
    row = row * 2 + 1
    col = col * 5 + 3

    for character in solution.split(': ')[1]:
      # Each rendered text line appears to be 37 characters wide — TODO
      # confirm.
      idx = row * 37 + col
      # Overwrites two characters: a space at idx - 1 and the letter at idx.
      self._grid = (
          self._grid[: idx - 1] + ' ' + character + self._grid[idx + 1 :]
      )
      row += 2 * is_vertical
      col += 5 * (1 - is_vertical)

  def step(self, action: str) -> dm_env.TimeStep:
    """Processes one guess.

    Reward scheme: +1 for each newly found solution (terminating when it was
    the last one), 0 for a repeated or length-consistent wrong guess (episode
    continues), and -1 with termination otherwise.
    """
    # The environment should not be case-sensitive.
    action = action.upper().strip()

    if action in self._unfound_solutions:
      self._unfound_solutions.remove(action)
      self._update_grid(action)

      if self._unfound_solutions:
        return dm_env.transition(reward=1, observation=self._observation)
      return dm_env.termination(reward=1, observation=self._observation)

    # Already-found solutions yield no reward but do not end the episode.
    if action in self._coords_by_solution.keys():
      return dm_env.transition(reward=0, observation=self._observation)

    # We continue if the proposed word is incorrect but has the correct length.
    try:
      idx, word = action.split(': ')
      solution = [
          sol for sol in self._coords_by_solution if sol.startswith(idx)
      ][0]
      if len(solution.split(': ')[1]) == len(word):
        return dm_env.transition(reward=0, observation=self._observation)
    except (IndexError, ValueError):
      logging.info('Invalid action: %s', action)
      return dm_env.termination(reward=-1, observation=self._observation)

    # Otherwise, we terminate the game.
    return dm_env.termination(reward=-1, observation=self._observation)

  def action_is_invalid(self, action: str) -> bool:
    """Returns whether the action is NOT of the form `A0: word` / `D1: word`."""
    return not re.match(r'^[AD]\d+:\s\w+$', action)

  @property
  def legal_actions(self) -> list[str]:
    """Returns the legal actions (all possible words of the correct length)."""
    legal_actions = list()
    for solution in self._unfound_solutions:
      idx, fill = solution.split(': ')
      for word in self._words_by_length[len(fill)]:
        legal_actions.append(f'{idx}: {word}')
    return legal_actions
179 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LMAct: A Benchmark for In-Context Imitation Learning with Long Multimodal Demonstrations
2 |
3 |
4 |
5 |
6 |
7 | This repository provides an implementation of our ICML 2025 paper [LMAct: A Benchmark for In-Context Imitation Learning with Long Multimodal Demonstrations](https://arxiv.org/abs/2412.01441).
8 |
9 | > In this paper, we present a benchmark to pressure-test today’s frontier
10 | models’ multimodal decision-making capabilities in the very long-context regime
11 | (up to one million tokens) and investigate whether these models can learn from
12 | large numbers of expert demonstrations in their context.
13 | We evaluate the performance of Claude 3.5 Sonnet, Gemini 1.5 Flash, Gemini 1.5
14 | Pro, Gemini 2.0 Flash Experimental, GPT-4o, o1-mini, o1-preview, and o1 as
15 | policies across a battery of simple interactive decision-making tasks: playing
16 | tic-tac-toe, chess, and Atari, navigating grid worlds, solving crosswords, and
17 | controlling a simulated cheetah.
18 | We study increasing amounts of expert demonstrations in the context — from no
19 | demonstrations to 512 full episodes.
20 | Across our tasks, models rarely manage to fully reach expert performance, and
21 | often, presenting more demonstrations has little effect.
22 | Some models steadily improve with more demonstrations on a few tasks.
23 | We investigate the effect of encoding observations as text or images and the
24 | impact of chain-of-thought prompting.
25 | To help quantify the impact of other approaches and future innovations, we open
26 | source our benchmark that covers the zero-, few-, and many-shot regimes in a
27 | unified evaluation.
28 |
29 | ## Contents
30 |
31 | ```
32 | .
33 | |
34 | ├── crafter - Crafter (needs to be downloaded)
35 | |
36 | ├── data - Expert demonstrations (need to be downloaded)
37 | |
38 | ├── src
39 | | ├── agents
40 | | │ ├── chess.py - Stockfish agent (chess expert)
41 | | │ ├── crossword.py - Oracle agent (crossword expert)
42 | | │ ├── grid_world.py - Shortest path agent (grid world expert)
43 | | │ ├── random.py - Random action agent
44 | | │ └── tic_tac_toe.py - Minimax agent (tic-tac-toe expert)
45 | | |
46 | | ├── bagz.py - Readers for our .bag data files
47 | | ├── config.py - Experiment configurations
48 | | ├── constants.py - Project constants
49 | | |
50 | | ├── environments
51 | | │ ├── chess.py - Chess environment
52 | | │ ├── crossword.py - Crossword environment
53 | | │ ├── dm_control.py - DM Control environment
54 | | │ ├── grid_world.py - Grid world environment
55 | | │ └── tic_tac_toe.py - Tic-tac-toe environment
56 | | |
57 | | ├── evaluate.py - Evaluation loop
58 | | ├── interfaces.py - Project interfaces
59 | | ├── main.py - Experiment launch script
60 | | └── prompts.py - Prompt-building functionality
61 | |
62 | ├── Stockfish - Stockfish (needs to be installed)
63 | |
64 | ├── README.md
65 | └── requirements.txt - Dependencies
66 | ```
67 |
68 | ## Installation
69 |
70 | Clone the source code into a local directory:
71 |
72 | ```bash
73 | git clone https://github.com/google-deepmind/lm_act.git
74 | cd lm_act
75 | ```
76 |
77 | This repository requires Python 3.11.
78 | `pip install -r requirements.txt` will install all required dependencies.
79 | This is best done inside a [conda environment](https://www.anaconda.com/).
80 | To that end, install [Anaconda](https://www.anaconda.com/download#downloads).
81 | Then, create and activate the conda environment:
82 |
83 | ```bash
84 | conda create --name lm_act python=3.11
85 | conda activate lm_act
86 | ```
87 |
88 | Install `pip` and use it to install all the dependencies:
89 |
90 | ```bash
91 | conda install pip
92 | pip install -r requirements.txt
93 | ```
94 |
95 | ### Installing Crafter
96 |
97 | Download the crafter repository:
98 |
99 | ```bash
100 | git clone https://github.com/danijar/crafter.git
101 | ```
102 |
103 | ### Installing Stockfish
104 |
105 | Download and compile the latest version of Stockfish (for Unix-like systems):
106 |
107 | ```bash
108 | git clone https://github.com/official-stockfish/Stockfish.git
109 | cd Stockfish/src
110 | make -j profile-build ARCH=x86-64-avx2
111 | cd ../..
112 | ```
113 |
114 | ### Downloading the Expert Demonstrations
115 |
116 | To download our expert demonstrations to the correct locations, run the
117 | following command:
118 |
119 | ```bash
120 | cd data
121 | ./download.sh
122 | cd ..
123 | ```
124 |
125 | ## Usage
126 |
127 | Before running any code, make sure to activate the conda environment and set the
128 | `PYTHONPATH`:
129 |
130 | ```bash
131 | conda activate lm_act
132 | export PYTHONPATH=$(pwd)/..
133 | ```
134 |
135 | To evaluate an agent, run the following command:
136 | ```bash
137 | python src/main.py \
138 | --environment=tic_tac_toe \
139 | --observation_type=txt \
140 | --agent=random \
141 | --num_demonstrations=0
142 | ```
143 |
144 | ## Citing this work
145 |
146 | ```latex
147 | @inproceedings{ruoss2025lmact,
148 | author = {Anian Ruoss and
149 | Fabio Pardo and
150 | Harris Chan and
151 | Bonnie Li and
152 | Volodymyr Mnih and
153 | Tim Genewein},
154 | title = {{LMAct}: A Benchmark for In-Context Imitation Learning with
155 |                  Long Multimodal Demonstrations},
156 | booktitle = {{ICML}},
157 | year = {2025},
158 | }
159 | ```
160 |
161 | ## License and disclaimer
162 |
163 | Copyright 2024 DeepMind Technologies Limited
164 |
165 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
166 | you may not use this file except in compliance with the Apache 2.0 license.
167 | You may obtain a copy of the Apache 2.0 license at:
168 | https://www.apache.org/licenses/LICENSE-2.0
169 |
170 | All other materials are licensed under the Creative Commons Attribution 4.0
171 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
172 | https://creativecommons.org/licenses/by/4.0/legalcode
173 |
174 | Unless required by applicable law or agreed to in writing, all software and
175 | materials distributed here under the Apache 2.0 or CC-BY licenses are
176 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
177 | either express or implied. See the licenses for the specific language governing
178 | permissions and limitations under those licenses.
179 |
180 | This is not an official Google product.
181 |
--------------------------------------------------------------------------------
/src/environments/tic_tac_toe.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tic-Tac-Toe environment."""
17 |
18 | import copy
19 | import dataclasses
20 | import io
21 | import pathlib
22 | from typing import Literal
23 |
24 | import dm_env
25 | from matplotlib import pyplot as plt
26 | import numpy as np
27 | from PIL import Image
28 | from PIL import ImageDraw
29 |
30 | from lm_act.src import bagz
31 | from lm_act.src import config as config_lib
32 | from lm_act.src import interfaces
33 |
34 |
# Row labels (letters) and column labels (digits) used in action strings.
ROWS = ['a', 'b', 'c']
COLS = ['1', '2', '3']

# Every winning line of the 3x3 board as (name, flat-index triple), row-major.
AXES = [
    # The three rows (contiguous runs) and three columns (stride-3 runs).
    *[(f'row_{i}', np.arange(3 * i, 3 * i + 3)) for i in range(3)],
    *[(f'col_{i}', np.arange(i, i + 7, 3)) for i in range(3)],
    # Negative diagonal.
    ('neg_diag', np.array([0, 4, 8])),
    # Positive diagonal.
    ('pos_diag', np.array([6, 4, 2])),
]
51 |
52 |
def _draw_board(
    board: np.ndarray,
    size: int = 768,
    color: str = 'black',
    background: str = 'white',
    render_coordinates: bool = True,
) -> tuple[bytes, np.ndarray]:
  """Returns the board as a PNG image and RGB array."""
  image = Image.new('RGB', size=(size, size), color=background)
  draw = ImageDraw.Draw(image)
  stroke = size // 25

  # Draw the grid: one vertical and one horizontal line at each third.
  for k in (1, 2):
    offset = k * size // 3
    draw.line(((offset, 0), (offset, size)), fill=color, width=stroke)
    draw.line(((0, offset), (size, offset)), fill=color, width=stroke)

  # Draw the symbols, cell by cell.
  for row_idx, row in enumerate(board):
    for col_idx, symbol in enumerate(row):

      def _to_coord(idx: int, offset: int) -> int:
        # Maps a cell index to a pixel position inside the cell, leaving a
        # 1/12-of-canvas margin around each symbol.
        return (4 * idx + offset) * size // 12

      # NOTE: 'top_right' and 'bottom_left' appear swapped relative to screen
      # coordinates, but both are only used for the two diagonals of an 'x',
      # so the rendered output is unaffected.
      corners = {
          'top_left': (_to_coord(col_idx, 1), _to_coord(row_idx, 1)),
          'top_right': (_to_coord(col_idx, 1), _to_coord(row_idx, 3)),
          'bottom_left': (_to_coord(col_idx, 3), _to_coord(row_idx, 1)),
          'bottom_right': (_to_coord(col_idx, 3), _to_coord(row_idx, 3)),
      }
      if symbol == 'o':
        draw.ellipse(
            [corners['top_left'], corners['bottom_right']],
            outline=color,
            width=stroke,
        )
      elif symbol == 'x':
        draw.line(
            (corners['top_left'], corners['bottom_right']),
            fill=color,
            width=stroke,
        )
        draw.line(
            (corners['top_right'], corners['bottom_left']),
            fill=color,
            width=stroke,
        )

  if render_coordinates:
    image = Image.fromarray(_add_coordinates(np.array(image)))

  with io.BytesIO() as buffer:
    image.save(buffer, format='PNG')
    return buffer.getvalue(), np.array(image)
125 |
126 |
def _add_coordinates(rgb_image: np.ndarray) -> np.ndarray:
  """Adds coordinates to the image.

  Renders the image through a matplotlib figure with row letters and column
  digits as tick labels on all four sides, then reads the figure back as an
  RGB array. Note the returned array's dimensions are the figure canvas's,
  not the input image's.

  Args:
    rgb_image: The RGB image array to add coordinates to.

  Returns:
    The RGB image array with coordinates added.
  """
  rgb_image = rgb_image.astype(np.uint8)
  height, width, _ = rgb_image.shape

  # Tick positions at the centers of the three cells.
  # NOTE(review): x uses `height` and y uses `width` — harmless for the
  # square boards produced here, but presumably swapped; verify if
  # non-square images are ever passed in.
  x_ticks = np.linspace(height / 6, 5 * height / 6, 3)
  y_ticks = np.linspace(width / 6, 5 * width / 6, 3)

  x_tick_labels = ['1', '2', '3']
  y_tick_labels = ['a', 'b', 'c']

  # Scale the label font with the image size (8pt at 256px).
  font_size = round(8 / 256 * height)

  # Matplotlib figsize is in inches at the default 100 dpi.
  fig = plt.figure(figsize=(height / 100, width / 100))
  plt.imshow(rgb_image)

  plt.xticks(x_ticks, x_tick_labels, fontsize=font_size)
  plt.yticks(y_ticks, y_tick_labels, fontsize=font_size)

  # Show labels on all four sides and hide the tick marks themselves.
  plt.tick_params(
      axis='both',
      which='both',
      labeltop=True,
      labelright=True,
      labelbottom=True,
      labelleft=True,
      length=0,
  )

  fig.tight_layout()
  fig.canvas.draw()
  new_width, new_height = fig.canvas.get_width_height()
  # NOTE(review): `tostring_rgb` is deprecated in matplotlib >= 3.8 and
  # removed in 3.10 (use `buffer_rgba` instead) — confirm the pinned
  # matplotlib version still provides it.
  rgb_buffer = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)

  plt.close(fig)

  return rgb_buffer.reshape((new_height, new_width, 3))
171 |
172 |
def legal_actions(board: np.ndarray) -> list[str]:
  """Returns the coordinate labels (e.g. 'a1') of all empty squares."""
  empty_squares = np.argwhere(board == ' ')
  return [ROWS[row] + COLS[col] for row, col in empty_squares]
177 |
178 |
@dataclasses.dataclass(frozen=True, kw_only=True)
class EnvironmentConfig(config_lib.Environment):
  """Configuration for the environment."""

  name: str = 'tic_tac_toe'
  observation_type: Literal['txt', 'png'] = 'png'
  render_coordinates: bool = False
  seed: int = 0

  def __post_init__(self):
    # Coordinates are drawn onto the rendered image, so they can only exist
    # for image observations.
    if not self.render_coordinates:
      return
    if self.observation_type == 'txt':
      raise ValueError(
          'Rendering coordinates is only supported for `png` observations.'
      )
    # The dataclass is frozen, so bypass immutability to adjust the name.
    object.__setattr__(self, 'name', 'tic_tac_toe_with_coordinates')
195 |
196 |
class TicTacToe(interfaces.Environment):
  """A simple tic-tac-toe environment to play against a random policy."""

  def __init__(
      self,
      config: EnvironmentConfig,
      opening_paths: list[pathlib.Path] | None = None,
      openings: list[tuple[np.ndarray, bool]] | None = None,
  ) -> None:
    """Initializes the environment.

    Args:
      config: The environment configuration.
      opening_paths: Directories from which opening boards are loaded (each
        must contain an 'observations_board.bag' file).
      openings: Pre-loaded (board, player_is_x) openings; takes precedence
        over `opening_paths`.
    """
    if openings is not None:
      self._openings = openings
    elif opening_paths is not None:
      self._openings = list()
      for opening_path in opening_paths:
        boards = bagz.BagReader(
            (opening_path / 'observations_board.bag').as_posix()
        )
        # NOTE(review): the following line looks garbled/truncated — the
        # np.dtype string literal is cut off and the remainder of __init__
        # (plus the `reset` method header) appears to be missing. Restore
        # this region from version control before relying on it.
        board = np.frombuffer(boards[0], dtype=np.dtype(' dm_env.TimeStep:
    # Pop the next opening; a deep copy keeps the stored opening intact
    # while the episode mutates the board in place.
    self._board, self._player_is_x = self._openings.pop(0)
    self._board = copy.deepcopy(self._board)
    return dm_env.restart(observation=self._observation)

  @property
  def _observation(self):
    """Returns the current observation in all supported formats."""
    return {
        'board': copy.deepcopy(self._board),
        'txt': '\n----------\n'.join(' | '.join(row) for row in self._board),
        # Gemini resizes all images to 768x768, so we might as well do it here.
        'png': _draw_board(
            board=self._board,
            size=768,
            render_coordinates=self._render_coordinates,
        )[0],
        # In contrast, Sequence Storage can only render images up to 256x256.
        'rgb': _draw_board(
            board=self._board,
            size=256,
            render_coordinates=self._render_coordinates,
        )[1],
        'symbol': self.symbol(is_player=True),
    }

  def symbol(self, is_player: bool) -> str:
    """Returns the symbol of the player (if `is_player`) or the adversary."""
    if is_player:
      return 'x' if self._player_is_x else 'o'
    return 'o' if self._player_is_x else 'x'

  def _turn(
      self,
      action: str,
      is_player: bool,
  ) -> None | dm_env.TimeStep:
    """Applies one move and returns a terminal step if the game ended.

    Args:
      action: Coordinate such as 'a1' (row letter followed by column digit).
      is_player: Whether the move is made by the player or the adversary.

    Returns:
      A termination step (+1 player win, -1 player loss, 0 draw), or None if
      the game continues.
    """
    symbol = self.symbol(is_player=is_player)
    self._board[ROWS.index(action[0]), COLS.index(action[1])] = symbol

    # Check every winning line for three-in-a-row of the placed symbol.
    for _, axis in AXES:
      line = np.take_along_axis(self._board.flatten(), axis, axis=None)
      if (line == symbol).all():
        return dm_env.termination(
            observation=self._observation,
            reward=1 if is_player else -1,
        )

    # No empty square left and no winner: the game is a draw.
    if ' ' not in self._board:
      return dm_env.termination(observation=self._observation, reward=0)

    return None

  def step(self, action: str) -> dm_env.TimeStep:
    """Plays the player's move, then the adversary's random reply.

    Args:
      action: The player's move, e.g. 'b2'.

    Returns:
      The resulting timestep; terminal if either move ended the game.

    Raises:
      ValueError: If `action` is not a legal move.
    """
    if action not in self.legal_actions:
      raise ValueError(f'Action {action} is illegal.')

    if (outcome := self._turn(action=action, is_player=True)) is not None:
      return outcome

    # The adversary randomly chooses a legal action.
    move = self._rng.choice(self.legal_actions)

    if (outcome := self._turn(action=move, is_player=False)) is not None:
      return outcome

    return dm_env.transition(observation=self._observation, reward=0)

  @property
  def legal_actions(self) -> list[str]:
    """Returns the legal actions for the current board."""
    return legal_actions(self._board)
294 |
--------------------------------------------------------------------------------
/src/bagz.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Bagz file reader/writer and PyGrain-compatible data source for POSIX systems.
17 |
18 | Bagz is a file format for storing a sequence of string records, typically
19 | serialised protocol buffers. It supports fast index based look-up.
20 | """
21 |
22 | import bisect
23 | from collections.abc import Sequence
24 | import itertools
25 | import mmap
26 | import os
27 | import re
28 | import shutil
29 | import struct
30 | from typing import Any, SupportsIndex
31 |
32 | from etils import epath
33 | from typing_extensions import Self
34 | import zstandard as zstd
35 |
36 |
class BagFileReader(Sequence[bytes]):
  """Reader for single Bagz files."""

  def __init__(
      self,
      filename: str,
      *,
      separate_limits: bool = False,
      decompress: bool | None = None,
  ) -> None:
    """Creates a BagFileReader.

    Args:
      filename: The name of the single Bagz file to read.
      separate_limits: Whether the limits are stored in a separate file.
      decompress: Whether to decompress the records. If None, uses the file
        extension to determine whether to decompress.
    """
    if decompress or (decompress is None and filename.endswith('.bagz')):
      # Empty records are stored as empty strings and must bypass zstd.
      self._process = lambda x: zstd.decompress(x) if x else x
    else:
      self._process = lambda x: x
    self._filename = filename
    fd = os.open(filename, os.O_RDONLY)
    try:
      self._records = mmap.mmap(fd, 0, access=mmap.ACCESS_READ)
      file_size = self._records.size()
    except ValueError:
      # mmap raises ValueError for empty files; treat them as zero records.
      self._records = b''
      file_size = 0
    finally:
      os.close(fd)
    if separate_limits:
      # The record limits live in a sidecar 'limits.<name>' file.
      directory, name = os.path.split(filename)
      fd = os.open(os.path.join(directory, 'limits.' + name), os.O_RDONLY)
      try:
        self._limits = mmap.mmap(fd, 0, access=mmap.ACCESS_READ)
        index_size = self._limits.size()
      except ValueError:
        self._limits = b''
        index_size = 0
      finally:
        os.close(fd)
      index_start = 0
    else:
      if 0 < file_size < 8:
        raise ValueError('Bagz file too small')
      self._limits = self._records
      if file_size:
        # NOTE(review): the following line looks garbled/truncated — the
        # struct format string literal and several subsequent source lines
        # appear to be missing. Restore this region from version control
        # before relying on it.
        (index_start,) = struct.unpack('= index_start
      index_size = file_size - index_start
      assert index_size % 8 == 0
    # Each limit entry is an 8-byte little-endian offset.
    self._num_records = index_size // 8
    self._limits_start = index_start

  def __len__(self) -> int:
    """Returns the number of records in the Bagz file."""
    return self._num_records

  def __getitem__(self, index: SupportsIndex) -> bytes:
    """Returns a record from the Bagz file."""
    i = index.__index__()
    if not 0 <= i < self._num_records:
      # NOTE(review): 'BragReader' is presumably a typo for 'BagReader'.
      raise IndexError('bagz.BragReader index out of range')
    end = i * 8 + self._limits_start
    if i:
      rec_range = struct.unpack('<2q', self._limits[end - 8 : end + 8])
    else:
      # NOTE(review): the following line looks garbled/truncated — it fuses
      # the end of __getitem__ with the BagShardReader class header and
      # __init__ signature. Restore from version control.
      rec_range = (0, *struct.unpack(' None:
    """Creates a BagShardReader.

    Args:
      filename: The name of the sharded Bagz file to read.
      separate_limits: Whether the limits are stored in a separate file.
      decompress: Whether to decompress the records. If None, uses the file
        extension to determine whether to decompress.
    """
    # The '@N' suffix encodes the shard count, e.g. 'data.bag@4'.
    matches = re.findall(r'@(\d+)', filename)
    assert len(matches) == 1
    num_files = int(matches[0])
    assert num_files < 100_000
    # One BagFileReader per shard file, named with a '-NNNNN-of-NNNNN' suffix.
    self._bags = tuple(
        BagFileReader(
            filename=re.sub(
                r'@(\d+)', f'-{idx:05d}-of-{num_files:05d}', filename
            ),
            separate_limits=separate_limits,
            decompress=decompress,
        )
        for idx in range(num_files)
    )
    # Cumulative record counts, used to map a global index to a shard.
    self._accum = tuple(itertools.accumulate(map(len, self._bags)))

  def __len__(self) -> int:
    """Returns the number of records in the Bagz file."""
    return self._accum[-1]

  def __getitem__(self, index: int) -> bytes:
    """Returns the record at the (possibly negative) global `index`."""
    if index < 0:
      index += self._accum[-1]
    # Locate the shard containing `index`, then rebase into that shard.
    if seqn := bisect.bisect_left(self._accum, index + 1):
      index -= self._accum[seqn - 1]
    return self._bags[seqn][index]
156 |
157 |
class BagReader(Sequence[bytes]):
  """Reader for Bagz files."""

  def __init__(
      self,
      filename: str,
      *,
      separate_limits: bool = False,
      decompress: bool | None = None,
  ) -> None:
    """Creates a BagReader.

    Args:
      filename: The name of the Bagz file to read. Supports the @N shard
        syntax (where @0 corresponds to the single file case). If the shard
        syntax does not parse, then `filename` is treated as a single file.
      separate_limits: Whether the limits are stored in a separate file.
      decompress: Whether to decompress the records. If None, uses the file
        extension to determine whether to decompress.
    """
    if matches := re.findall(r'@(\d+)', filename):
      assert len(matches) == 1
      # Bug fix: the shard count (an int) was compared against the string
      # '0', which is always unequal, so the '@0' single-file branch was
      # unreachable. Compare against the integer 0 instead.
      if int(matches[0]) != 0:
        reader_class = BagShardReader
      else:
        # Bug fix: strip the full '@0' suffix rather than every occurrence
        # of the bare digits (replacing matches[0] alone would also mangle
        # zeros elsewhere in the path).
        filename = filename.replace(f'@{matches[0]}', '')
        reader_class = BagFileReader
    else:
      reader_class = BagFileReader

    self._reader = reader_class(
        filename=filename,
        separate_limits=separate_limits,
        decompress=decompress,
    )

  def __len__(self) -> int:
    """Returns the number of records in the Bagz file."""
    return len(self._reader)

  def __getitem__(self, index: SupportsIndex) -> bytes:
    """Returns a record from the Bagz file."""
    return self._reader[index]
201 |
202 |
class BagWriter:
  """Writer for Bagz files."""

  def __init__(
      self,
      filename: str,
      *,
      separate_limits: bool = False,
      compress: bool | None = None,
      compression_level: int = 0,
  ) -> None:
    """Creates a BagWriter.

    Args:
      filename: The name of the Bagz file to write.
      separate_limits: Whether to keep the limits in a separate file.
      compress: Whether to compress the records. If None, uses the file
        extension to determine whether to compress.
      compression_level: The compression level to use when compressing.
    """
    if compress or (compress is None and filename.endswith('.bagz')):
      self._process = zstd.ZstdCompressor(level=compression_level).compress
    else:
      self._process = lambda x: x
    self._separate_limits = separate_limits
    directory, name = os.path.split(filename)
    self._records = open(filename, 'wb')
    # Record limits are buffered in a 'limits.<name>' sidecar file and,
    # unless separate_limits is set, appended to the data file on close().
    self._limits = open(os.path.join(directory, 'limits.' + name), 'wb+')

  def write(self, data: bytes) -> None:
    """Writes a record to the Bagz file."""
    if data:
      self._records.write(self._process(data))
    # NOTE(review): the following line looks garbled/truncated — the struct
    # format string literal and the `flush` method header appear to be fused
    # here. Restore this region from version control before relying on it.
    self._limits.write(struct.pack(' None:
    """Flushes the Bagz file."""
    self._records.flush()
    self._limits.flush()

  def __enter__(self) -> Self:
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    """Ensures the Bagz file is closed when exiting a context."""
    self.close()

  def close(self) -> None:
    """Concatenates the limits file to the end of the data file."""
    if self._separate_limits:
      self._records.close()
      self._limits.close()
    else:
      # Append the limits index to the data file, then delete the sidecar.
      self._limits.seek(0)
      shutil.copyfileobj(self._limits, self._records)
      self._records.close()
      os.unlink(self._limits.name)
      self._limits.close()
261 |
262 |
class BagDataSource:
  """PyGrain-compatible data source for bagz files."""

  def __init__(self, path: 'epath.PathLike') -> None:
    """Creates a new BagDataSource object.

    Args:
      path: The path to the bag file.
    """
    self._path = os.fspath(path)
    self._reader = BagReader(self._path)
    self._num_records = len(self._reader)

  def __len__(self) -> int:
    """Returns the number of records in the underlying bag file."""
    return self._num_records

  def __getitem__(self, record_key: SupportsIndex) -> bytes:
    """Returns the record at `record_key`."""
    return self._reader[record_key]

  def __getstate__(self) -> dict[str, Any]:
    """Drops the mmap-backed reader before pickling (it is not picklable)."""
    state = self.__dict__.copy()
    del state['_reader']
    return state

  def __setstate__(self, state) -> None:
    """Re-opens the reader after unpickling."""
    self.__dict__.update(state)
    self._reader = BagReader(self._path)

  def __repr__(self) -> str:
    # Bug fix: the original f-string was missing the closing parenthesis.
    return f'BagDataSource(path={self._path!r})'
293 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluates a single episode."""
17 |
18 | import copy
19 | import os
20 | import pathlib
21 | from typing import Any
22 |
23 | from absl import logging
24 | import numpy as np
25 |
26 | from lm_act.src import bagz
27 | from lm_act.src import config as config_lib
28 | from lm_act.src import constants
29 | from lm_act.src import prompts
30 |
31 |
# Root directory holding the expert-demonstration data, resolved relative to
# the process working directory.
_BASE_DIR_PATH = pathlib.Path.cwd() / 'data' / 'lm_act'
38 |
39 |
40 | def _load_demonstrations_and_opening_path(
41 | rng: np.random.Generator,
42 | config: config_lib.Experiment,
43 | ) -> tuple[list[list[Any]], list[list[Any]], pathlib.Path]:
44 | """Loads the demonstrations (observations & actions) and the opening path.
45 |
46 | Args:
47 | rng: The random number generator.
48 | config: The experiment configuration.
49 |
50 | Returns:
51 | - The demonstrations episodes, consisting of observations and actions.
52 | - The opening path, i.e., the path to the opening (which is used to set the
53 | initial state of the environment for the evaluation episode).
54 |
55 | Raises:
56 | ValueError: If there are insufficient demonstrations in the directory.
57 | """
58 | base_dir_path = _BASE_DIR_PATH / config.environment.name
59 | demonstration_names = [
60 | file_name
61 | for file_name in os.listdir(base_dir_path)
62 | if file_name.startswith('demonstration')
63 | ]
64 | if len(demonstration_names) < config.num_demonstrations + 1:
65 | raise ValueError(
66 | f'Insufficient demonstrations in {base_dir_path}: Need at least'
67 | f' {config.num_demonstrations + 1} but only found'
68 | f' {len(demonstration_names)}.'
69 | )
70 |
71 | if config.replay_episode:
72 | assert config.num_demonstrations == 1
73 | num_openings = config.num_demonstrations
74 | else:
75 | # We need to add 1 to account for the opening that that will be evaluated.
76 | num_openings = config.num_demonstrations + 1
77 | demonstration_names = rng.choice(
78 | demonstration_names,
79 | size=num_openings,
80 | replace=False,
81 | shuffle=False,
82 | )
83 | opening_name = demonstration_names[-1]
84 | demonstration_names = demonstration_names[: config.num_demonstrations]
85 |
86 | demo_observations = list()
87 | demo_actions = list()
88 |
89 | match config.environment.observation_type:
90 | case 'rgb':
91 | rgb_shape = constants.get_rgb_shape(config.environment.name)
92 | observation_decode_fn = lambda x: np.frombuffer(
93 | x,
94 | dtype=np.uint8,
95 | ).reshape(rgb_shape)
96 | case 'png':
97 | # PNG data does not need to be decoded.
98 | observation_decode_fn = lambda x: x
99 | case _:
100 | observation_decode_fn = lambda x: x.decode('utf-8')
101 | action_decode_fn = lambda x: x.decode('utf-8')
102 |
103 | for demonstration_name in demonstration_names:
104 | demo_dir_path = base_dir_path / demonstration_name
105 | observations_path = (
106 | demo_dir_path
107 | / f'observations_{config.environment.observation_type}.bag'
108 | )
109 | actions_path = (
110 | demo_dir_path / f'actions_{config.environment.action_type}.bag'
111 | )
112 | observations = bagz.BagReader(observations_path.as_posix())
113 | actions = bagz.BagReader(actions_path.as_posix())
114 | assert len(observations) == len(actions)
115 | demo_observations.append(list(map(observation_decode_fn, observations)))
116 | demo_actions.append(list(map(action_decode_fn, actions)))
117 |
118 | return demo_observations, demo_actions, base_dir_path / opening_name
119 |
120 |
121 | def _create_demonstration_prompt(
122 | config: config_lib.Experiment,
123 | demo_observations: list[list[Any]],
124 | demo_actions: list[list[Any]],
125 | ) -> tuple[str, dict[str, Any]]:
126 | """Returns the demonstration prompt and content for the given config."""
127 | content_by_tag = dict()
128 | demo_prompts = list()
129 |
130 | for demo_idx, (observations, actions) in enumerate(
131 | zip(demo_observations, demo_actions)
132 | ):
133 | for step_idx, (observation, action) in enumerate(
134 | zip(observations, actions)
135 | ):
136 | match config.environment.observation_type:
137 | case 'fen' | 'coords':
138 | demo_prompt = f'Observation: {observation} '
139 | case 'dict':
140 | demo_prompt = f'Observation: {observation}\n'
141 | case 'pgn' | 'txt':
142 | demo_prompt = f'Observation:\n{observation}\n'
143 | case 'rgb' | 'png':
144 | tag = f''
145 | content_by_tag[tag] = observation
146 | demo_prompt = f'Observation: {tag} '
147 | case _:
148 | raise ValueError(
149 | 'Unsupported observation type:'
150 | f' {config.environment.observation_type}'
151 | )
152 |
153 | demo_prompt += f'Action: {action}'
154 | demo_prompts.append(demo_prompt)
155 | demo_prompts.append('\n')
156 | demo_prompts.append('\n')
157 |
158 | demonstration_prompt = prompts.build_demonstration_prompt(
159 | demonstrations=''.join(demo_prompts),
160 | )
161 | logging.info('Demonstration prompt: %s', demonstration_prompt)
162 | return demonstration_prompt, content_by_tag
163 |
164 |
165 | def _create_trajectory_prompt(
166 | config: config_lib.Experiment,
167 | observations: list[Any],
168 | actions: list[Any],
169 | legal_actions: list[str],
170 | ) -> tuple[str, dict[str, Any]]:
171 | """Returns the trajectory prompt and content for the given config."""
172 | content_by_tag = dict()
173 | trajectory_prompts = list()
174 |
175 | # The first action is a dummy action so we place it at the end of the list.
176 | actions = np.roll(copy.deepcopy(actions), -1)
177 |
178 | for step_idx, (observation, action) in enumerate(zip(observations, actions)):
179 | match config.environment.observation_type:
180 | case 'fen' | 'coords':
181 | trajectory_prompt = f'Observation: {observation} '
182 | case 'dict':
183 | trajectory_prompt = f'Observation: {observation}\n'
184 | case 'pgn' | 'txt':
185 | trajectory_prompt = f'Observation:\n{observation}\n'
186 | case 'rgb' | 'png':
187 | tag = f''
188 | content_by_tag[tag] = observation
189 | trajectory_prompt = f'Observation: {tag} '
190 | case _:
191 | raise ValueError(
192 | 'Unsupported observation type:'
193 | f' {config.environment.observation_type}'
194 | )
195 |
196 | if config.prompt.include_past_actions and step_idx < len(actions) - 1:
197 | trajectory_prompt += f'Action: {action}'
198 |
199 | if trajectory_prompt:
200 | trajectory_prompts.append(trajectory_prompt)
201 | trajectory_prompts.append('\n')
202 |
203 | trajectory_prompt = prompts.build_trajectory_prompt(
204 | config=config,
205 | trajectory=''.join(trajectory_prompts),
206 | legal_actions=legal_actions,
207 | )
208 | logging.info('Current trajectory prompt: %s', trajectory_prompt)
209 | return trajectory_prompt, content_by_tag
210 |
211 |
def evaluate_episode_replay(
    episode_idx: int,
    config: config_lib.Experiment,
) -> int:
  """Returns the number of correctly replayed actions for a single episode.

  The agent is shown a single demonstration episode and then, step by step,
  the growing prefix of that same episode; it scores a point whenever the
  action it samples equals the demonstrated action at that step.

  Args:
    episode_idx: The index of the episode; also used to seed the RNG, so each
      episode draws different demonstrations.
    config: The experiment configuration; must request exactly one
      demonstration (enforced by the loader in replay mode).
  """

  # Every episode has to initialize the RNG with a different seed.
  rng = np.random.default_rng(seed=episode_idx)

  logging.info('Setting up the agent: %s.', config.agent.name)
  agent = constants.get_agent_builder(config.agent.name)(config=config.agent)

  logging.info('Loading the demonstrations and the evaluation opening name.')
  demo_observations, demo_actions, opening_path = (
      _load_demonstrations_and_opening_path(rng=rng, config=config)
  )
  # Replay mode evaluates against a single demonstration episode.
  assert len(demo_observations) == 1
  assert len(demo_actions) == 1

  logging.info('Replaying episode %d (opening %s).', episode_idx, opening_path)

  logging.info('Creating the demonstration chunks.')
  demonstration_prompt, demonstration_prompt_data = (
      _create_demonstration_prompt(
          config=config,
          demo_observations=demo_observations,
          demo_actions=demo_actions,
      )
  )

  num_correctly_replayed_actions = 0

  # Walk through the demonstration: at step `step`, the agent sees the
  # observations up to and including step `step` and the actions before it.
  for step, (demo_observation, demo_action) in enumerate(
      zip(demo_observations[0], demo_actions[0])
  ):
    trajectory_prompt, trajectory_prompt_data = _create_trajectory_prompt(
        config=config,
        observations=demo_observations[0][: step + 1],
        actions=[None] + demo_actions[0][:step],  # Dummy initial action.
        legal_actions=list(),  # We cannot compute the legal actions.
    )
    # No live environment in replay mode; the agent acts from the prompt only.
    sample = agent.step(
        observation={
            'prompt': demonstration_prompt + trajectory_prompt,
            'prompt_data': demonstration_prompt_data | trajectory_prompt_data,
        },
        environment=None,
        rng=rng,
    )
    # Exact-match comparison with the demonstrated action; the boolean adds
    # 0 or 1 to the running count below.
    replayed_action_is_correct = sample == demo_action
    num_correctly_replayed_actions += replayed_action_is_correct

    logging.info({
        'demo_observation': demo_observation,
        'demo_action': demo_action,
        'sample': sample,
        'replayed_action_is_correct': replayed_action_is_correct,
    })

  return num_correctly_replayed_actions
272 |
273 |
def evaluate_episode(
    episode_idx: int,
    config: config_lib.Experiment,
) -> tuple[float, int, int, int, int]:
  """Evaluates a single episode.

  Args:
    episode_idx: The index of the episode; also used to seed the RNG, so each
      episode draws different demonstrations and openings.
    config: The experiment configuration.

  Returns:
    A tuple of:
    - The episode score (sum of the rewards after the first time step).
    - The number of environment steps taken.
    - The number of invalid actions sampled by the agent.
    - The number of illegal actions sampled by the agent (every invalid
      action is also counted as illegal).
    - The number of empty actions sampled by the agent.
  """

  # Every episode has to initialize the RNG with a different seed.
  rng = np.random.default_rng(seed=episode_idx)

  logging.info('Setting up the agent: %s.', config.agent.name)
  agent = constants.get_agent_builder(config.agent.name)(config=config.agent)

  logging.info('Loading the demonstrations and the evaluation opening name.')
  demo_observations, demo_actions, opening_path = (
      _load_demonstrations_and_opening_path(rng=rng, config=config)
  )

  logging.info(
      'Evaluating episode %d with opening %s.', episode_idx, opening_path
  )

  logging.info('Creating the demonstration chunks.')
  demonstration_prompt, demonstration_prompt_data = (
      _create_demonstration_prompt(
          config=config,
          demo_observations=demo_observations,
          demo_actions=demo_actions,
      )
  )

  logging.info('Setting up the environment: %s.', config.environment.name)
  # The opening sets the initial environment state for this episode.
  env = constants.get_environment_builder(config.environment.name)(
      config=config.environment,
      opening_paths=[opening_path],
  )
  time_step = env.reset()

  # Running trajectory; `actions` is kept aligned with `observations` via a
  # leading dummy entry for the initial observation.
  observations = [time_step.observation[config.environment.observation_type]]
  rewards = [time_step.reward]
  actions = [None]  # Dummy action for the initial observation.

  num_illegal_actions = num_invalid_actions = num_empty_actions = 0

  # Roll out at most `config.num_evaluation_steps` environment steps.
  for _ in range(config.num_evaluation_steps):
    if time_step.last():
      break

    trajectory_prompt, trajectory_prompt_data = _create_trajectory_prompt(
        config=config,
        observations=observations,
        actions=actions,
        legal_actions=env.legal_actions,
    )
    sample = agent.step(
        observation=time_step.observation
        | {
            'prompt': demonstration_prompt + trajectory_prompt,
            'prompt_data': demonstration_prompt_data | trajectory_prompt_data,
        },
        environment=env,
        rng=rng,
    )

    sample_is_empty = not sample
    num_empty_actions += sample_is_empty

    # Invalid actions are a superset of illegal ones: an invalid sample is
    # counted in both tallies and the legality check is skipped.
    if sample_is_invalid := env.action_is_invalid(sample):
      num_invalid_actions += 1
      # If the sample is invalid, we also always consider it illegal.
      sample_is_illegal = True
      num_illegal_actions += 1
    elif sample_is_illegal := env.action_is_illegal(sample):
      num_illegal_actions += 1

    # Fall back to a random legal action whenever the sample cannot be played.
    action = env.sample_legal_action(rng) if sample_is_illegal else sample

    logging.info({
        'observation': time_step.observation[
            config.environment.observation_type
        ],
        'reward': time_step.reward,
        'action': action,
        'sample': sample,
        'sample_is_invalid': sample_is_invalid,
        'sample_is_illegal': sample_is_illegal,
        'sample_is_empty': sample_is_empty,
        'step_type': int(time_step.step_type),
    })

    time_step = env.step(action)
    observations.append(
        time_step.observation[config.environment.observation_type]
    )
    rewards.append(time_step.reward)
    actions.append(action)

  # Log the final time step; no action is taken in it.
  logging.info({
      'rgb': (
          time_step.observation['rgb']
          if 'rgb' in time_step.observation
          else None
      ),
      'observation': time_step.observation[config.environment.observation_type],
      'reward': time_step.reward,
      'action': None,
      'sample': None,
      'sample_is_invalid': None,
      'sample_is_illegal': None,
      'sample_is_empty': None,
      'step_type': int(time_step.step_type),
  })

  score = sum(rewards[1:])  # Skip the first reward since it is always None.
  num_steps = len(rewards) - 1

  return (
      score,
      num_steps,
      num_invalid_actions,
      num_illegal_actions,
      num_empty_actions,
  )
396 |
--------------------------------------------------------------------------------