├── multigrid
│   ├── utils
│   │   ├── __init__.py
│   │   ├── misc.py
│   │   ├── enum.py
│   │   ├── random.py
│   │   ├── minigrid_interface.py
│   │   ├── rendering.py
│   │   └── obs.py
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── actions.py
│   │   ├── constants.py
│   │   ├── mission.py
│   │   ├── grid.py
│   │   ├── agent.py
│   │   ├── roomgrid.py
│   │   └── world_object.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── playground.py
│   │   ├── blockedunlockpickup.py
│   │   ├── empty.py
│   │   ├── redbluedoors.py
│   │   └── locked_hallway.py
│   ├── pettingzoo
│   │   └── __init__.py
│   ├── rllib
│   │   └── __init__.py
│   ├── wrappers.py
│   └── base.py
├── .gitignore
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── scripts
│   ├── README.md
│   ├── visualize.py
│   └── train.py
├── pyproject.toml
├── README.md
└── LICENSE
/multigrid/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/multigrid/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import MultiGridEnv
2 | from .core import *
3 |
4 | __version__ = '0.1.0'
5 |
--------------------------------------------------------------------------------
/multigrid/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .actions import Action
2 | from .agent import Agent, AgentState
3 | from .constants import *
4 | from .grid import Grid
5 | from .mission import MissionSpace
6 | from .world_object import Ball, Box, Door, Floor, Goal, Key, Lava, Wall, WorldObj
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *__pycache__
3 | *egg-info
4 | trained_models
5 |
6 | # PyPI
7 | build/*
8 | dist/*
9 | .idea/
10 |
11 | # Docs
12 | .DS_Store
13 | _site
14 | .jekyll-cache
15 | __pycache__
16 | .vscode/
17 | docs/_build/
18 | /docs/environments/**/*.md
19 | /docs/environments/**/*.html
20 | !/docs/environments/**/index.md
21 |
22 | # Virtual environments
23 | .env
24 | .venv
25 | venv
26 |
--------------------------------------------------------------------------------
/multigrid/core/actions.py:
--------------------------------------------------------------------------------
1 | import enum
2 |
3 |
4 |
5 | class Action(enum.IntEnum):
6 | """
7 | Enumeration of possible actions.
8 | """
9 | left = 0 #: Turn left
10 | right = enum.auto() #: Turn right
11 | forward = enum.auto() #: Move forward
12 | pickup = enum.auto() #: Pick up an object
13 | drop = enum.auto() #: Drop an object
14 | toggle = enum.auto() #: Toggle / activate an object
15 | done = enum.auto() #: Done completing task
16 |
--------------------------------------------------------------------------------
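
A minimal sketch of how `Action` values are used (the 2-agent action dict mirrors the README example further below):

```python
from multigrid.core.actions import Action

# IntEnum members compare and convert like plain integers
assert Action.left == 0
assert int(Action.forward) == 2

# Multi-agent actions are passed to env.step() as a dict: agent index -> action
actions = {0: Action.forward, 1: Action.pickup}
# observations, rewards, terminations, truncations, infos = env.step(actions)
```
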
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import setuptools
3 |
4 |
5 |
6 | PACKAGE_DIR = pathlib.Path(__file__).absolute().parent
7 |
8 |
9 |
10 | def get_version():
11 | """
12 | Gets the multigrid version.
13 | """
14 | path = PACKAGE_DIR / 'multigrid' / '__init__.py'
15 | content = path.read_text()
16 |
17 | for line in content.splitlines():
18 | if line.startswith('__version__'):
19 | return line.strip().split()[-1].strip().strip("'")
20 |
21 | raise RuntimeError("could not find __version__ in multigrid/__init__.py")
22 |
23 | def get_description():
24 | """
25 | Gets the description from the readme.
26 | """
27 | with open("README.md") as fh:
28 | long_description = ""
29 | header_count = 0
30 | for line in fh:
31 | if line.startswith('##'):
32 | header_count += 1
33 | if header_count < 2:
34 | long_description += line
35 | else:
36 | break
37 | return long_description
38 |
39 | setuptools.setup(
40 | name='multigrid',
41 | version=get_version(),
42 | long_description=get_description(),
43 | )
44 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/multigrid/utils/misc.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Any
3 | from ..core.constants import Direction
4 |
5 |
6 |
7 | @functools.cache
8 | def front_pos(agent_x: int, agent_y: int, agent_dir: int):
9 | """
10 | Get the position in front of an agent.
11 | """
12 | dx, dy = Direction(agent_dir).to_vec()
13 | return (agent_x + dx, agent_y + dy)
14 |
15 |
16 |
17 | class PropertyAlias(property):
18 | """
19 | A class property that is an alias for an attribute property.
20 |
21 | Instead of::
22 |
23 | @property
24 | def x(self):
25 | return self.attr.x
26 |
27 | @x.setter
28 | def x(self, value):
29 | self.attr.x = value
30 |
31 | we can simply declare::
32 |
33 | x = PropertyAlias('attr', 'x')
34 | """
35 |
36 | def __init__(self, attr_name: str, attr_property_name: str, doc: str = None) -> None:
37 | """
38 | Parameters
39 | ----------
40 | attr_name : str
41 | Name of the base attribute
42 | attr_property_name : str
43 | Name of the property on the base attribute's class
44 | doc : str
45 | Docstring to append to the property's original docstring
46 | """
47 | prop = lambda obj: getattr(type(getattr(obj, attr_name)), attr_property_name)
48 | fget = lambda obj: prop(obj).fget(getattr(obj, attr_name))
49 | fset = lambda obj, value: prop(obj).fset(getattr(obj, attr_name), value)
50 | fdel = lambda obj: prop(obj).fdel(getattr(obj, attr_name))
51 | super().__init__(fget, fset, fdel, doc=doc)
52 | self.__doc__ = doc
53 |
--------------------------------------------------------------------------------
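
A self-contained sketch of the `PropertyAlias` pattern; the `Engine`/`Car` classes are purely illustrative and not part of this package:

```python
from multigrid.utils.misc import PropertyAlias

class Engine:
    def __init__(self):
        self._rpm = 0

    @property
    def rpm(self):
        return self._rpm

    @rpm.setter
    def rpm(self, value):
        self._rpm = value

class Car:
    # Car.rpm reads and writes Car.engine.rpm
    rpm = PropertyAlias('engine', 'rpm', doc="Alias for Engine.rpm")

    def __init__(self):
        self.engine = Engine()

car = Car()
car.rpm = 3000          # delegates to car.engine.rpm
assert car.engine.rpm == 3000
```
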
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Training MultiGrid agents with RLlib
2 |
3 | MultiGrid is compatible with RLlib's multi-agent API.
4 |
5 | This folder provides scripts to train and visualize agents over MultiGrid environments.
6 |
7 | ## Requirements
8 |
9 | Using MultiGrid environments with RLlib requires installation of [rllib](https://docs.ray.io/en/latest/rllib/index.html), and one of [PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/).
10 |
11 | ## Getting Started
12 |
13 | Train 2 agents on the `MultiGrid-Empty-8x8-v0` environment using the PPO algorithm:
14 |
15 | python train.py --algo PPO --env MultiGrid-Empty-8x8-v0 --num-agents 2 --save-dir ~/saved/empty8x8/
16 |
17 | Visualize behavior from trained agents policies:
18 |
19 | python visualize.py --algo PPO --env MultiGrid-Empty-8x8-v0 --num-agents 2 --load-dir ~/saved/empty8x8/
20 |
21 | For more options, run ``python train.py --help`` and ``python visualize.py --help``.
22 |
23 | ## Environments
24 |
25 | All of the environment configurations registered in [`multigrid.envs`](../multigrid/envs/__init__.py) can also be used with RLlib, and are registered via `import multigrid.rllib`.
26 |
27 | To use a specific MultiGrid environment configuration by name:
28 |
29 | >>> import multigrid.rllib
30 | >>> from ray.rllib.algorithms.ppo import PPOConfig
31 | >>> algorithm_config = PPOConfig().environment(env='MultiGrid-Empty-8x8-v0')
32 |
33 | To convert a custom `MultiGridEnv` to an RLlib `MultiAgentEnv`:
34 |
35 | >>> from multigrid.rllib import to_rllib_env
36 | >>> MyRLLibEnvClass = to_rllib_env(MyEnvClass)
37 | >>> algorithm_config = PPOConfig().environment(env=MyRLLibEnvClass)
38 |
--------------------------------------------------------------------------------
/multigrid/utils/enum.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import aenum as enum
4 | import functools
5 | import numpy as np
6 |
7 | from numpy.typing import ArrayLike, NDArray as ndarray
8 | from typing import Any
9 |
10 |
11 |
12 | ### Helper Functions
13 |
14 | @functools.cache
15 | def _enum_array(enum_cls: enum.EnumMeta):
16 | """
17 | Return an array of all values of the given enum.
18 |
19 | Parameters
20 | ----------
21 | enum_cls : enum.EnumMeta
22 | Enum class
23 | """
24 | return np.array([item.value for item in enum_cls])
25 |
26 | @functools.cache
27 | def _enum_index(enum_item: enum.Enum):
28 | """
29 | Return the index of the given enum item.
30 |
31 | Parameters
32 | ----------
33 | enum_item : enum.Enum
34 | Enum item
35 | """
36 | return list(enum_item.__class__).index(enum_item)
37 |
38 |
39 |
40 | ### Enumeration
41 |
42 | class IndexedEnum(enum.Enum):
43 | """
44 | Enum where each member has a corresponding integer index.
45 | """
46 |
47 | def __int__(self):
48 | return self.to_index()
49 |
50 | @classmethod
51 | def add_item(cls, name: str, value: Any):
52 | """
53 | Add a new item to the enumeration.
54 |
55 | Parameters
56 | ----------
57 | name : str
58 | Name of the new enum item
59 | value : Any
60 | Value of the new enum item
61 | """
62 | enum.extend_enum(cls, name, value)
63 | _enum_array.cache_clear()
64 | _enum_index.cache_clear()
65 |
66 | @classmethod
67 | def from_index(cls, index: int | ArrayLike[int]) -> enum.Enum | ndarray:
68 | """
69 | Return the enum item corresponding to the given index.
70 | Also supports vector inputs.
71 |
72 | Parameters
73 | ----------
74 | index : int or ArrayLike[int]
75 | Enum index (or array of indices)
76 |
77 | Returns
78 | -------
79 | enum.Enum or ndarray
80 | Enum item (or array of enum item values)
81 | """
82 | out = _enum_array(cls)[index]
83 | return cls(out) if out.ndim == 0 else out
84 |
85 | def to_index(self) -> int:
86 | """
87 | Return the integer index of this enum item.
88 | """
89 | return _enum_index(self)
90 |
--------------------------------------------------------------------------------
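
A short sketch of the indexing behavior, using the `Color` enum defined in `multigrid/core/constants.py`:

```python
import numpy as np
from multigrid.core.constants import Color

assert Color.red.to_index() == 0          # index = position in declaration order
assert int(Color.green) == 1              # __int__() delegates to to_index()
assert Color.from_index(2) == Color.blue  # scalar index -> enum member

# Vectorized lookup: an index array returns an array of enum *values*
values = Color.from_index(np.array([0, 2]))
assert list(values) == ['red', 'blue']
```
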
/multigrid/envs/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ************
3 | Environments
4 | ************
5 |
6 | This package contains implementations of several MultiGrid environments.
7 |
8 | **************
9 | Configurations
10 | **************
11 |
12 | * `Blocked Unlock Pickup <./multigrid.envs.blockedunlockpickup>`_
13 | * ``MultiGrid-BlockedUnlockPickup-v0``
14 | * `Empty <./multigrid.envs.empty>`_
15 | * ``MultiGrid-Empty-5x5-v0``
16 | * ``MultiGrid-Empty-Random-5x5-v0``
17 | * ``MultiGrid-Empty-6x6-v0``
18 | * ``MultiGrid-Empty-Random-6x6-v0``
19 | * ``MultiGrid-Empty-8x8-v0``
20 | * ``MultiGrid-Empty-16x16-v0``
21 | * `Locked Hallway <./multigrid.envs.locked_hallway>`_
22 | * ``MultiGrid-LockedHallway-2Rooms-v0``
23 | * ``MultiGrid-LockedHallway-4Rooms-v0``
24 | * ``MultiGrid-LockedHallway-6Rooms-v0``
25 | * `Playground <./multigrid.envs.playground>`_
26 | * ``MultiGrid-Playground-v0``
27 | * `Red Blue Doors <./multigrid.envs.redbluedoors>`_
28 | * ``MultiGrid-RedBlueDoors-6x6-v0``
29 | * ``MultiGrid-RedBlueDoors-8x8-v0``
30 | """
31 |
32 | from .blockedunlockpickup import BlockedUnlockPickupEnv
33 | from .empty import EmptyEnv
34 | from .locked_hallway import LockedHallwayEnv
35 | from .playground import PlaygroundEnv
36 | from .redbluedoors import RedBlueDoorsEnv
37 |
38 | CONFIGURATIONS = {
39 | 'MultiGrid-BlockedUnlockPickup-v0': (BlockedUnlockPickupEnv, {}),
40 | 'MultiGrid-Empty-5x5-v0': (EmptyEnv, {'size': 5}),
41 | 'MultiGrid-Empty-Random-5x5-v0': (EmptyEnv, {'size': 5, 'agent_start_pos': None}),
42 | 'MultiGrid-Empty-6x6-v0': (EmptyEnv, {'size': 6}),
43 | 'MultiGrid-Empty-Random-6x6-v0': (EmptyEnv, {'size': 6, 'agent_start_pos': None}),
44 | 'MultiGrid-Empty-8x8-v0': (EmptyEnv, {}),
45 | 'MultiGrid-Empty-16x16-v0': (EmptyEnv, {'size': 16}),
46 | 'MultiGrid-LockedHallway-2Rooms-v0': (LockedHallwayEnv, {'num_rooms': 2}),
47 | 'MultiGrid-LockedHallway-4Rooms-v0': (LockedHallwayEnv, {'num_rooms': 4}),
48 | 'MultiGrid-LockedHallway-6Rooms-v0': (LockedHallwayEnv, {'num_rooms': 6}),
49 | 'MultiGrid-Playground-v0': (PlaygroundEnv, {}),
50 | 'MultiGrid-RedBlueDoors-6x6-v0': (RedBlueDoorsEnv, {'size': 6}),
51 | 'MultiGrid-RedBlueDoors-8x8-v0': (RedBlueDoorsEnv, {'size': 8}),
52 | }
53 |
54 | # Register environments with gymnasium
55 | from gymnasium.envs.registration import register
56 | for name, (env_cls, config) in CONFIGURATIONS.items():
57 | register(id=name, entry_point=env_cls, kwargs=config)
58 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Package ######################################################################
2 |
3 | [build-system]
4 | requires = ["setuptools >= 61.0.0"]
5 | build-backend = "setuptools.build_meta"
6 |
7 | [project]
8 | name = "multigrid"
9 | description = "Fast multi-agent gridworld reinforcement learning environments."
10 | readme = "README.md"
11 | requires-python = ">= 3.9"
12 | authors = [{ name = "Ini Oguntola", email = "ini@ini.io" }]
13 | license = { text = "Apache License" }
14 | keywords = ["Memory", "Environment", "Agent", "Multi-Agent", "RL", "Gymnasium", "Cooperative", "Competitive"]
15 | classifiers = [
16 | "Development Status :: 4 - Beta", # change to `5 - Production/Stable` when ready
17 | "License :: OSI Approved :: Apache Software License",
18 | "Programming Language :: Python :: 3",
19 | "Programming Language :: Python :: 3.9",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Intended Audience :: Science/Research",
23 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
24 | ]
25 | dependencies = [
26 | "aenum>=1.3.0",
27 | "numba>=0.53.0",
28 | "numpy>=1.18.0",
29 | "gymnasium>=0.26",
30 | "pygame>=2.2.0",
31 | ]
32 | dynamic = ["version"]
33 |
34 | [tool.setuptools]
35 | include-package-data = true
36 |
37 | [tool.setuptools.packages.find]
38 | include = ["multigrid*"]
39 |
40 | # Linters and Test tools #######################################################
41 |
42 | [tool.black]
43 | safe = true
44 |
45 | [tool.isort]
46 | atomic = true
47 | profile = "black"
48 | append_only = true
49 | src_paths = ["multigrid"]
50 | add_imports = [ "from __future__ import annotations" ]
51 |
52 | [tool.pyright]
53 | include = [
54 | "multigrid/**",
55 | ]
56 |
57 | exclude = [
58 | "**/node_modules",
59 | "**/__pycache__",
60 | ]
61 |
62 | strict = []
63 |
64 | typeCheckingMode = "basic"
65 | pythonVersion = "3.9"
66 | typeshedPath = "typeshed"
67 | enableTypeIgnoreComments = true
68 |
69 | # This is required as the CI pre-commit does not download the module (i.e. numpy)
70 | # Therefore, we have to ignore missing imports
71 | reportMissingImports = "none"
72 |
73 | reportUnknownMemberType = "none"
74 | reportUnknownParameterType = "none"
75 | reportUnknownVariableType = "none"
76 | reportUnknownArgumentType = "none"
77 | reportPrivateUsage = "warning"
78 | reportUntypedFunctionDecorator = "none"
79 | reportMissingTypeStubs = false
80 | reportUnboundVariable = "warning"
81 | reportGeneralTypeIssues = "none"
82 | reportPrivateImportUsage = "none"
83 |
--------------------------------------------------------------------------------
/multigrid/utils/random.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Iterable, TypeVar
3 | from ..core.constants import Color
4 |
5 | T = TypeVar('T')
6 |
7 |
8 |
9 | class RandomMixin:
10 | """
11 | Mixin class for random number generation.
12 | """
13 |
14 | def __init__(self, random_generator: np.random.Generator):
15 | """
16 | Parameters
17 | ----------
18 | random_generator : np.random.Generator
19 | Random number generator
20 | """
21 | self.__np_random = random_generator
22 |
23 | def _rand_int(self, low: int, high: int) -> int:
24 | """
25 | Generate random integer in range [low, high).
26 |
27 | :meta public:
28 | """
29 | return self.__np_random.integers(low, high)
30 |
31 | def _rand_float(self, low: float, high: float) -> float:
32 | """
33 | Generate random float in range [low, high).
34 |
35 | :meta public:
36 | """
37 | return self.__np_random.uniform(low, high)
38 |
39 | def _rand_bool(self) -> bool:
40 | """
41 | Generate random boolean value.
42 |
43 | :meta public:
44 | """
45 | return self.__np_random.integers(0, 2) == 0
46 |
47 | def _rand_elem(self, iterable: Iterable[T]) -> T:
48 | """
49 | Pick a random element in a list.
50 |
51 | :meta public:
52 | """
53 | lst = list(iterable)
54 | idx = self._rand_int(0, len(lst))
55 | return lst[idx]
56 |
57 | def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]:
58 | """
59 | Sample a random subset of distinct elements of a list.
60 |
61 | :meta public:
62 | """
63 | lst = list(iterable)
64 | assert num_elems <= len(lst)
65 |
66 | out: list[T] = []
67 |
68 | while len(out) < num_elems:
69 | elem = self._rand_elem(lst)
70 | lst.remove(elem)
71 | out.append(elem)
72 |
73 | return out
74 |
75 | def _rand_perm(self, iterable: Iterable[T]) -> list[T]:
76 | """
77 | Randomly permute a list.
78 |
79 | :meta public:
80 | """
81 | lst = list(iterable)
82 | self.__np_random.shuffle(lst)
83 | return lst
84 |
85 | def _rand_color(self) -> Color:
86 | """
87 | Generate a random color.
88 |
89 | :meta public:
90 | """
91 | return self._rand_elem(Color)
92 |
93 | def _rand_pos(
94 | self, x_low: int, x_high: int, y_low: int, y_high: int) -> tuple[int, int]:
95 | """
96 | Generate a random (x, y) position tuple.
97 |
98 | :meta public:
99 | """
100 | return (
101 | self.__np_random.integers(x_low, x_high),
102 | self.__np_random.integers(y_low, y_high),
103 | )
104 |
--------------------------------------------------------------------------------
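
A minimal sketch of plugging `RandomMixin` into a class; the `Spawner` class here is purely illustrative:

```python
import numpy as np
from multigrid.utils.random import RandomMixin

class Spawner(RandomMixin):
    def __init__(self, seed=None):
        # RandomMixin expects a numpy Generator
        super().__init__(np.random.default_rng(seed))

    def random_cell(self, width, height):
        return self._rand_pos(0, width, 0, height)

spawner = Spawner(seed=42)
print(spawner.random_cell(8, 8))  # deterministic given the seed
print(spawner._rand_color())      # a random Color member
```
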
/README.md:
--------------------------------------------------------------------------------
1 | # MultiGrid
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | The **MultiGrid** library provides a collection of fast multi-agent discrete gridworld environments for reinforcement learning in [Gymnasium](https://github.com/Farama-Foundation/Gymnasium). This is a multi-agent extension of the [minigrid](https://github.com/Farama-Foundation/Minigrid) library, and the interface is designed to be as similar as possible.
10 |
11 | The environments are designed to be fast and easily customizable. Compared to minigrid, the underlying gridworld logic is **significantly optimized**, with environment simulation 10x to 20x faster in our benchmarks.
12 |
13 | Documentation for this library can be found at [ini.github.io/docs/multigrid](https://ini.github.io/docs/multigrid).
14 |
15 | ## Installation
16 |
17 | pip install multigrid
18 |
19 | Or alternatively, for an editable install:
20 |
21 | git clone https://github.com/ini/multigrid
22 | cd multigrid
23 | pip install -e .
24 |
25 | This package requires Python 3.9 or later.
26 |
27 | ## Environments
28 |
29 | The `multigrid.envs` package provides implementations of several multi-agent environments. [You can find the full list here](https://ini.github.io/docs/multigrid/multigrid/multigrid.envs).
30 |
31 | ## API
32 |
33 | MultiGrid follows the same pattern as RLlib's [MultiAgentEnv API](https://docs.ray.io/en/latest/rllib/rllib-env.html#multi-agent-and-hierarchical) and PettingZoo's [ParallelEnv API](https://pettingzoo.farama.org/api/parallel/).
34 |
35 | ```python
36 | import gymnasium as gym
37 | import multigrid.envs
38 |
39 | env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human')
40 |
41 | observations, infos = env.reset()
42 | while not env.unwrapped.is_done():
43 | # this is where you would insert your policy / policies
44 | actions = {agent.index: agent.action_space.sample() for agent in env.unwrapped.agents}
45 | observations, rewards, terminations, truncations, infos = env.step(actions)
46 |
47 | env.close()
48 | ```
49 |
50 | More information about using MultiGrid directly with other APIs:
51 | * [PettingZoo](https://ini.github.io/docs/multigrid/multigrid/multigrid.pettingzoo)
52 | * [RLlib](https://ini.github.io/docs/multigrid/multigrid/multigrid.rllib)
53 |
54 | ## Training Agents
55 |
56 | See the [scripts folder](./scripts) for an example of training with RLlib.
57 |
58 | ## Documentation
59 |
60 | Documentation for this package can be found at [ini.github.io/docs/multigrid](https://ini.github.io/docs/multigrid).
61 |
62 | ## Citation
63 |
64 | To cite this project please use:
65 |
66 | ```
67 | @article{oguntola2023theory,
68 | title={Theory of mind as intrinsic motivation for multi-agent reinforcement learning},
69 | author={Oguntola, Ini and Campbell, Joseph and Stepputtis, Simon and Sycara, Katia},
70 | journal={arXiv preprint arXiv:2307.01158},
71 | year={2023}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/multigrid/core/constants.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import numpy as np
3 |
4 | from numpy.typing import NDArray as ndarray
5 | from ..utils.enum import IndexedEnum
6 |
7 |
8 |
9 | #: Tile size for rendering grid cell
10 | TILE_PIXELS = 32
11 |
12 | COLORS = {
13 | 'red': np.array([255, 0, 0]),
14 | 'green': np.array([0, 255, 0]),
15 | 'blue': np.array([0, 0, 255]),
16 | 'purple': np.array([112, 39, 195]),
17 | 'yellow': np.array([255, 255, 0]),
18 | 'grey': np.array([100, 100, 100]),
19 | }
20 |
21 | DIR_TO_VEC = [
22 | # Pointing right (positive X)
23 | np.array((1, 0)),
24 | # Down (positive Y)
25 | np.array((0, 1)),
26 | # Pointing left (negative X)
27 | np.array((-1, 0)),
28 | # Up (negative Y)
29 | np.array((0, -1)),
30 | ]
31 |
32 |
33 |
34 | class Type(str, IndexedEnum):
35 | """
36 | Enumeration of object types.
37 | """
38 | unseen = 'unseen'
39 | empty = 'empty'
40 | wall = 'wall'
41 | floor = 'floor'
42 | door = 'door'
43 | key = 'key'
44 | ball = 'ball'
45 | box = 'box'
46 | goal = 'goal'
47 | lava = 'lava'
48 | agent = 'agent'
49 |
50 |
51 | class Color(str, IndexedEnum):
52 | """
53 | Enumeration of object colors.
54 | """
55 | red = 'red'
56 | green = 'green'
57 | blue = 'blue'
58 | purple = 'purple'
59 | yellow = 'yellow'
60 | grey = 'grey'
61 |
62 | @classmethod
63 | def add_color(cls, name: str, rgb: ndarray[np.uint8]):
64 | """
65 | Add a new color to the ``Color`` enumeration.
66 |
67 | Parameters
68 | ----------
69 | name : str
70 | Name of the new color
71 | rgb : ndarray[np.uint8] of shape (3,)
72 | RGB value of the new color
73 | """
74 | cls.add_item(name, name)
75 | COLORS[name] = np.asarray(rgb, dtype=np.uint8)
76 |
77 | @staticmethod
78 | def cycle(n: int) -> tuple['Color', ...]:
79 | """
80 | Return a cycle of ``n`` colors.
81 | """
82 | return tuple(Color.from_index(i % len(Color)) for i in range(int(n)))
83 |
84 | def rgb(self) -> ndarray[np.uint8]:
85 | """
86 | Return the RGB value of this ``Color``.
87 | """
88 | return COLORS[self]
89 |
90 |
91 | class State(str, IndexedEnum):
92 | """
93 | Enumeration of object states.
94 | """
95 | open = 'open'
96 | closed = 'closed'
97 | locked = 'locked'
98 |
99 |
100 | class Direction(enum.IntEnum):
101 | """
102 | Enumeration of agent directions.
103 | """
104 | right = 0
105 | down = 1
106 | left = 2
107 | up = 3
108 |
109 | def to_vec(self) -> ndarray[np.int8]:
110 | """
111 | Return the vector corresponding to this ``Direction``.
112 | """
113 | return DIR_TO_VEC[self]
114 |
115 |
116 |
117 | ### Minigrid Compatibility
118 |
119 | OBJECT_TO_IDX = {t: t.to_index() for t in Type}
120 | IDX_TO_OBJECT = {t.to_index(): t for t in Type}
121 | COLOR_TO_IDX = {c: c.to_index() for c in Color}
122 | IDX_TO_COLOR = {c.to_index(): c for c in Color}
123 | STATE_TO_IDX = {s: s.to_index() for s in State}
124 | COLOR_NAMES = sorted(list(Color))
125 |
--------------------------------------------------------------------------------
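
A short sketch of the enum helpers above (note that `add_color` mutates the global `COLORS` table at runtime):

```python
import numpy as np
from multigrid.core.constants import Color, Direction

assert Color.cycle(3) == (Color.red, Color.green, Color.blue)
assert (Color.red.rgb() == np.array([255, 0, 0])).all()
assert (Direction.right.to_vec() == np.array([1, 0])).all()

# Register a new color at runtime
Color.add_color('orange', np.array([255, 165, 0], dtype=np.uint8))
assert Color('orange').to_index() == 6
```
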
/multigrid/pettingzoo/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package provides tools for using MultiGrid environments with
3 | the PettingZoo ParallelEnv API.
4 |
5 | *****
6 | Usage
7 | *****
8 |
9 | Wrap an environment instance with :class:`.PettingZooWrapper`:
10 |
11 | >>> import gymnasium as gym
12 | >>> import multigrid.envs
13 | >>> env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human')
14 |
15 | >>> from multigrid.pettingzoo import PettingZooWrapper
16 | >>> env = PettingZooWrapper(env)
17 |
18 | Wrap an environment class with :func:`.to_pettingzoo_env()`:
19 |
20 | >>> from multigrid.envs import EmptyEnv
21 | >>> from multigrid.pettingzoo import to_pettingzoo_env
22 | >>> PZEnv = to_pettingzoo_env(EmptyEnv, metadata={'name': 'empty_v0'})
23 | >>> env = PZEnv(agents=2, render_mode='human')
24 | """
25 |
26 | from __future__ import annotations
27 |
28 | import gymnasium as gym
29 |
30 | from gymnasium import spaces
31 | from pettingzoo import ParallelEnv
32 | from typing import Any
33 |
34 | from ..base import AgentID, MultiGridEnv
35 |
36 |
37 |
38 | class PettingZooWrapper(ParallelEnv):
39 | """
40 | Wrapper for a ``MultiGridEnv`` environment that implements the
41 | PettingZoo ``ParallelEnv`` interface.
42 | """
43 |
44 | def __init__(self, env: MultiGridEnv):
45 | self.env = env
46 | self.reset = self.env.reset
47 | self.step = self.env.step
48 | self.render = self.env.render
49 | self.close = self.env.close
50 | self.metadata = {}
51 |
52 | @property
53 | def agents(self) -> list[AgentID]:
54 | if self.env.is_done():
55 | return []
56 | return [agent.index for agent in self.env.agents if not agent.terminated]
57 |
58 | @property
59 | def possible_agents(self) -> list[AgentID]:
60 | return [agent.index for agent in self.env.agents]
61 |
62 | @property
63 | def observation_spaces(self) -> dict[AgentID, spaces.Space]:
64 | return dict(self.env.observation_space)
65 |
66 | @property
67 | def action_spaces(self) -> dict[AgentID, spaces.Space]:
68 | return dict(self.env.action_space)
69 |
70 | @property
71 | def render_mode(self) -> str | None:
72 | return self.env.render_mode
73 |
74 | def observation_space(self, agent_id: AgentID) -> spaces.Space:
75 | return self.env.observation_space[agent_id]
76 |
77 | def action_space(self, agent_id: AgentID) -> spaces.Space:
78 | return self.env.action_space[agent_id]
79 |
80 |
81 |
82 | def to_pettingzoo_env(
83 | env_cls: type[MultiGridEnv],
84 | *wrappers: gym.Wrapper,
85 | metadata: dict[str, Any] = {}) -> type[ParallelEnv]:
86 | """
87 | Convert a ``MultiGridEnv`` environment class to a PettingZoo ``ParallelEnv`` class.
88 |
89 | Note that this is a wrapper around the environment **class**,
90 | not environment instances.
91 |
92 | Parameters
93 | ----------
94 | env_cls : type[MultiGridEnv]
95 | ``MultiGridEnv`` environment class
96 | wrappers : gym.Wrapper
97 | Gym wrappers to apply to the environment
98 | metadata : dict[str, Any]
99 | Environment metadata
100 |
101 | Returns
102 | -------
103 | pettingzoo_env_cls : type[ParallelEnv]
104 | PettingZoo ``ParallelEnv`` environment class
105 | """
106 | class PettingZooEnv(PettingZooWrapper):
107 | def __init__(self, *args, **kwargs):
108 | env = env_cls(*args, **kwargs)
109 | for wrapper in wrappers:
110 | env = wrapper(env)
111 | super().__init__(env)
112 |
113 | PettingZooEnv.__name__ = f"PettingZoo_{env_cls.__name__}"
114 | PettingZooEnv.metadata = metadata
115 | return PettingZooEnv
116 |
--------------------------------------------------------------------------------
/multigrid/rllib/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package provides tools for using MultiGrid environments with
3 | the RLlib MultiAgentEnv API.
4 |
5 | *****
6 | Usage
7 | *****
8 |
9 | Use a specific environment configuration from :mod:`multigrid.envs` by name:
10 |
11 | >>> import multigrid.rllib # registers environment configurations with RLlib
12 | >>> from ray.rllib.algorithms.ppo import PPOConfig
13 | >>> algorithm_config = PPOConfig().environment(env='MultiGrid-Empty-8x8-v0')
14 |
15 | Wrap an environment instance with :class:`.RLlibWrapper`:
16 |
17 | >>> import gymnasium as gym
18 | >>> import multigrid.envs
19 | >>> env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human')
20 |
21 | >>> from multigrid.rllib import RLlibWrapper
22 | >>> env = RLlibWrapper(env)
23 |
24 | Wrap an environment class with :func:`.to_rllib_env()`:
25 |
26 | >>> from multigrid.envs import EmptyEnv
27 | >>> from multigrid.rllib import to_rllib_env
28 | >>> MyEnv = to_rllib_env(EmptyEnv, default_config={'size': 8})
29 | >>> config = {'agents': 2, 'render_mode': 'human'}
30 | >>> env = MyEnv(config)
31 | """
32 |
33 | import gymnasium as gym
34 |
35 | from ray.rllib.env import MultiAgentEnv
36 | from ray.tune.registry import register_env
37 |
38 | from ..base import MultiGridEnv
39 | from ..envs import CONFIGURATIONS
40 | from ..wrappers import OneHotObsWrapper
41 |
42 |
43 |
44 | class RLlibWrapper(MultiAgentEnv):
45 | """
46 | Wrapper for a ``MultiGridEnv`` environment that implements the
47 | RLlib ``MultiAgentEnv`` interface.
48 | """
49 |
50 | def __init__(self, env: MultiGridEnv):
51 | super().__init__()
52 | self.env = env
53 | self.agents = list(range(len(env.unwrapped.agents)))
54 | self.possible_agents = self.agents[:]
55 |
56 | def reset(self, *args, **kwargs):
57 | return self.env.reset(*args, **kwargs)
58 |
59 | def step(self, *args, **kwargs):
60 | obs, rewards, terminations, truncations, infos = self.env.step(*args, **kwargs)
61 | terminations['__all__'] = all(terminations.values())
62 | truncations['__all__'] = all(truncations.values())
63 | return obs, rewards, terminations, truncations, infos
64 |
65 | def get_observation_space(self, agent_index: int):
66 | return self.env.unwrapped.agents[agent_index].observation_space
67 |
68 | def get_action_space(self, agent_index: int):
69 | return self.env.unwrapped.agents[agent_index].action_space
70 |
71 |
72 | def to_rllib_env(
73 | env_cls: type[MultiGridEnv],
74 | *wrappers: gym.Wrapper,
75 | default_config: dict = {}) -> type[MultiAgentEnv]:
76 | """
77 | Convert a ``MultiGridEnv`` environment class to an RLLib ``MultiAgentEnv`` class.
78 |
79 | Note that this is a wrapper around the environment **class**,
80 | not environment instances.
81 |
82 | Parameters
83 | ----------
84 | env_cls : type[MultiGridEnv]
85 | ``MultiGridEnv`` environment class
86 | wrappers : gym.Wrapper
87 | Gym wrappers to apply to the environment
88 | default_config : dict
89 | Default configuration for the environment
90 |
91 | Returns
92 | -------
93 | rllib_env_cls : type[MultiAgentEnv]
94 | RLlib ``MultiAgentEnv`` environment class
95 | """
96 | class RLlibEnv(RLlibWrapper):
97 | def __init__(self, config: dict = {}):
98 | config = {**default_config, **config}
99 | env = env_cls(**config)
100 | for wrapper in wrappers:
101 | env = wrapper(env)
102 | super().__init__(env)
103 |
104 | RLlibEnv.__name__ = f"RLlib_{env_cls.__name__}"
105 | return RLlibEnv
106 |
107 |
108 |
109 | # Register environments with RLlib
110 | for name, (env_cls, config) in CONFIGURATIONS.items():
111 | register_env(name, to_rllib_env(env_cls, OneHotObsWrapper, default_config=config))
112 |
--------------------------------------------------------------------------------
/multigrid/core/mission.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | from gymnasium import spaces
5 | from typing import Any, Callable, Iterable, Sequence
6 |
7 |
8 |
9 | class Mission(np.ndarray):
10 | """
11 | Class representing an agent mission.
12 | """
13 |
14 | def __new__(cls, string: str, index: Iterable[int] | None = None):
15 | """
16 | Parameters
17 | ----------
18 | string : str
19 | Mission string
20 | index : Iterable[int]
21 | Index of mission string in :class:`MissionSpace`
22 | """
23 | mission = np.array(0 if index is None else index)
24 | mission = mission.view(cls)
25 | mission.string = string
26 | return mission.view(cls)
27 |
28 | def __array_finalize__(self, mission):
29 | if mission is None: return
30 | self.string = getattr(mission, 'string', None)
31 |
32 | def __str__(self) -> str:
33 | return self.string
34 |
35 | def __repr__(self) -> str:
36 | return f'{self.__class__.__name__}("{self.string}")'
37 |
38 | def __eq__(self, value: object) -> bool:
39 | return self.string == str(value)
40 |
41 | def __hash__(self) -> int:
42 | return hash(self.string)
43 |
44 |
45 | class MissionSpace(spaces.MultiDiscrete):
46 | """
47 | Class representing a space over agent missions.
48 |
49 | Examples
50 | --------
51 | >>> observation_space = MissionSpace(
52 | ... mission_func=lambda color: f"Get the {color} ball.",
53 | ... ordered_placeholders=[["green", "blue"]])
54 | >>> observation_space.seed(123)
55 | >>> observation_space.sample()
56 | Mission("Get the blue ball.")
57 |
58 | >>> observation_space = MissionSpace.from_string("Get the ball.")
59 | >>> observation_space.sample()
60 | Mission("Get the ball.")
61 | """
62 |
63 | def __init__(
64 | self,
65 | mission_func: Callable[..., str],
66 | ordered_placeholders: Sequence[Sequence[str]] = [],
67 | seed : int | np.random.Generator | None = None):
68 | """
69 | Parameters
70 | ----------
71 | mission_func : Callable(*args) -> str
72 | Deterministic function that generates a mission string
73 | ordered_placeholders : Sequence[Sequence[str]]
74 | Sequence of argument groups, ordered by placing order in ``mission_func()``
75 | seed : int or np.random.Generator or None
76 | Seed for random sampling from the space
77 | """
78 | self.mission_func = mission_func
79 | self.arg_groups = ordered_placeholders
80 | nvec = tuple(len(group) for group in self.arg_groups)
81 | super().__init__(nvec=nvec if nvec else (1,))
82 |
83 | def __repr__(self) -> str:
84 | """
85 | Get a string representation of this space.
86 | """
87 | if self.arg_groups:
88 | return f'MissionSpace({self.mission_func.__name__}, {self.arg_groups})'
89 | return f"MissionSpace('{self.mission_func()}')"
90 |
91 | def get(self, idx: Iterable[int]) -> Mission:
92 | """
93 | Get the mission string corresponding to the given index.
94 |
95 | Parameters
96 | ----------
97 | idx : Iterable[int]
98 | Index of desired argument in each argument group
99 | """
100 | if self.arg_groups:
101 | args = (self.arg_groups[axis][index] for axis, index in enumerate(idx))
102 | return Mission(string=self.mission_func(*args), index=idx)
103 | return Mission(string=self.mission_func())
104 |
105 | def sample(self) -> Mission:
106 | """
107 | Sample a random mission string.
108 | """
109 | idx = super().sample()
110 | return self.get(idx)
111 |
112 | def contains(self, x: Any) -> bool:
113 | """
114 | Check if an item is a valid member of this mission space.
115 |
116 | Parameters
117 | ----------
118 | x : Any
119 | Item to check
120 | """
121 | for idx in np.ndindex(tuple(self.nvec)):
122 | if self.get(idx) == x:
123 | return True
124 | return False
125 |
126 | @staticmethod
127 | def from_string(string: str) -> MissionSpace:
128 | """
129 | Create a mission space containing a single mission string.
130 |
131 | Parameters
132 | ----------
133 | string : str
134 | Mission string
135 | """
136 | return MissionSpace(mission_func=lambda: string)
137 |
--------------------------------------------------------------------------------
/multigrid/envs/playground.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from multigrid.core.mission import MissionSpace
4 | from multigrid.core.roomgrid import RoomGrid
5 |
6 |
7 |
8 | class PlaygroundEnv(RoomGrid):
9 | """
10 | .. image:: https://i.imgur.com/QBz99Vh.gif
11 | :width: 380
12 |
13 | ***********
14 | Description
15 | ***********
16 |
17 | Environment with multiple rooms and random objects.
18 | This environment has no specific goals or rewards.
19 |
20 | *************
21 | Mission Space
22 | *************
23 |
24 | None
25 |
26 | *****************
27 | Observation Space
28 | *****************
29 |
30 | The multi-agent observation space is a Dict mapping from agent index to
31 | corresponding agent observation space.
32 |
33 | Each agent observation is a dictionary with the following entries:
34 |
35 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
36 | Encoding of the agent's partially observable view of the environment,
37 | where the object at each grid cell is encoded as a vector:
38 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
39 | * direction : int
40 | Agent's direction (0: right, 1: down, 2: left, 3: up)
41 | * mission : Mission
42 | Task string corresponding to the current environment configuration
43 |
44 | ************
45 | Action Space
46 | ************
47 |
48 | The multi-agent action space is a Dict mapping from agent index to
49 | corresponding agent action space.
50 |
51 | Agent actions are discrete integer values, given by:
52 |
53 | +-----+--------------+-----------------------------+
54 | | Num | Name | Action |
55 | +=====+==============+=============================+
56 | | 0 | left | Turn left |
57 | +-----+--------------+-----------------------------+
58 | | 1 | right | Turn right |
59 | +-----+--------------+-----------------------------+
60 | | 2 | forward | Move forward |
61 | +-----+--------------+-----------------------------+
62 | | 3 | pickup | Pick up an object |
63 | +-----+--------------+-----------------------------+
64 | | 4 | drop | Drop an object |
65 | +-----+--------------+-----------------------------+
66 | | 5 | toggle | Toggle / activate an object |
67 | +-----+--------------+-----------------------------+
68 | | 6 | done | Done completing task |
69 | +-----+--------------+-----------------------------+
70 |
71 | *******
72 | Rewards
73 | *******
74 |
75 | None
76 |
77 | ***********
78 | Termination
79 | ***********
80 |
81 | The episode ends when the following condition is met:
82 |
83 | * Timeout (see ``max_steps``)
84 |
85 | *************************
86 | Registered Configurations
87 | *************************
88 |
89 | * ``MultiGrid-Playground-v0``
90 | """
91 |
92 | def __init__(
93 | self,
94 | room_size: int = 7,
95 | num_rows: int = 3,
96 | num_cols: int = 3,
97 | max_steps: int = 100,
98 | **kwargs):
99 | """
100 | Parameters
101 | ----------
102 | room_size : int, default=7
103 | Width and height for each of the rooms
104 | num_rows : int, default=3
105 | Number of rows of rooms
106 | num_cols : int, default=3
107 | Number of columns of rooms
108 | max_steps : int, default=100
109 | Maximum number of steps per episode
110 | **kwargs
111 | See :attr:`multigrid.base.MultiGridEnv.__init__`
112 | """
113 | super().__init__(
114 | mission_space=MissionSpace.from_string(""),
115 | num_rows=num_rows,
116 | num_cols=num_cols,
117 | room_size=room_size,
118 | max_steps=max_steps,
119 | **kwargs,
120 | )
121 |
122 | def _gen_grid(self, width, height):
123 | """
124 | :meta private:
125 | """
126 | super()._gen_grid(width, height)
127 | self.connect_all()
128 |
129 | # Place random objects in the world
130 | for i in range(0, 12):
131 | col = self._rand_int(0, self.num_cols)
132 | row = self._rand_int(0, self.num_rows)
133 | self.add_object(col, row)
134 |
135 | # Place agents
136 | for agent in self.agents:
137 | self.place_agent(agent)
138 |
--------------------------------------------------------------------------------
/scripts/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import numpy as np
4 |
5 | from ray.rllib.utils.typing import AgentID
6 | from ray.rllib.utils.torch_utils import convert_to_torch_tensor
7 | from typing import Callable
8 |
9 | from train import get_algorithm_config, find_checkpoint_dir, get_policy_mapping_fn
10 | from ray.rllib.core.rl_module import RLModule
11 |
12 |
13 |
14 | def visualize(
15 | modules: dict[str, RLModule],
16 | policy_mapping_fn: Callable[[AgentID], str],
17 | num_episodes: int = 10) -> list[np.ndarray]:
18 | """
19 | Visualize trajectories from trained agents.
20 |
21 | Parameters
22 | ----------
23 | modules : dict[str, RLModule]
24 | Mapping from policy IDs to trained ``RLModule`` instances
25 | policy_mapping_fn : Callable(AgentID) -> str
26 | Function mapping agent IDs to policy IDs
27 | num_episodes : int, default=10
28 | Number of episodes to visualize
29 | """
30 | frames = []
31 | env = algorithm.env_creator(algorithm.config.env_config)  # global `algorithm` built in __main__
32 |
33 | for episode in range(num_episodes):
34 | print()
35 | print('-' * 32, '\n', 'Episode', episode, '\n', '-' * 32)
36 |
37 | episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}
38 | terminations, truncations = {'__all__': False}, {'__all__': False}
39 | observations, infos = env.reset()
40 |
41 | # Get initial states for each agent
42 | states = {
43 | agent_id: modules[policy_mapping_fn(agent_id)].get_initial_state()
44 | for agent_id in env.agents
45 | }
46 |
47 | while not terminations['__all__'] and not truncations['__all__']:
48 | # Store current frame
49 | frames.append(env.env.unwrapped.get_frame())
50 |
51 | # Compute actions for each agent
52 | actions = {}
53 | observations = convert_to_torch_tensor(observations)
54 | for agent_id in env.agents:
55 | agent_module = modules[policy_mapping_fn(agent_id)]
56 | out = agent_module.forward_inference({'obs': observations[agent_id]})
57 | logits = out['action_dist_inputs']
58 | action_dist_class = agent_module.get_inference_action_dist_cls()
59 | action_dist = action_dist_class.from_logits(logits)
60 | actions[agent_id] = action_dist.sample().item()
61 |
62 | # Take actions in environment and accumulate rewards
63 | observations, rewards, terminations, truncations, infos = env.step(actions)
64 | for agent_id in rewards:
65 | episode_rewards[agent_id] += rewards[agent_id]
66 |
67 | frames.append(env.env.unwrapped.get_frame())
68 | print('Rewards:', episode_rewards)
69 |
70 | env.close()
71 | return frames
72 |
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser()
77 | parser.add_argument(
78 | '--algo', type=str, default='PPO',
79 | help="The name of the RLlib-registered algorithm to use.")
80 | parser.add_argument(
81 | '--framework', type=str, choices=['torch', 'tf', 'tf2'], default='torch',
82 | help="Deep learning framework to use.")
83 | parser.add_argument(
84 | '--lstm', action='store_true', help="Use LSTM model.")
85 | parser.add_argument(
86 | '--env', type=str, default='MultiGrid-Empty-8x8-v0',
87 | help="MultiGrid environment to use.")
88 | parser.add_argument(
89 | '--env-config', type=json.loads, default={},
90 | help="Environment config dict, given as a JSON string (e.g. '{\"size\": 8}')")
91 | parser.add_argument(
92 | '--num-agents', type=int, default=2, help="Number of agents in environment.")
93 | parser.add_argument(
94 | '--num-episodes', type=int, default=10, help="Number of episodes to visualize.")
95 | parser.add_argument(
96 | '--load-dir', type=str,
97 | help="Checkpoint directory for loading pre-trained policies.")
98 | parser.add_argument(
99 | '--gif', type=str, help="Store output as GIF at given path.")
100 |
101 | args = parser.parse_args()
102 | args.env_config.update(render_mode='human')
103 | config = get_algorithm_config(
104 | **vars(args),
105 | num_workers=0,
106 | num_gpus=0,
107 | )
108 | algorithm = config.build()
109 | checkpoint = find_checkpoint_dir(args.load_dir)
110 | policy_mapping_fn = lambda agent_id, *args, **kwargs: f'policy_{agent_id}'
111 |
112 | if checkpoint:
113 | print(f"Loading checkpoint from {checkpoint}")
114 | path = checkpoint / 'learner_group' / 'learner' / 'rl_module'
115 | modules = RLModule.from_checkpoint(path)
116 | policy_mapping_fn = get_policy_mapping_fn(checkpoint, args.num_agents)
117 |
118 | frames = visualize(modules, policy_mapping_fn, num_episodes=args.num_episodes)
119 | if args.gif:
120 | from array2gif import write_gif
121 | filename = args.gif if args.gif.endswith('.gif') else f'{args.gif}.gif'
122 | print(f"Saving GIF to {filename}")
123 | write_gif(np.array(frames), filename, fps=10)
124 |
--------------------------------------------------------------------------------
/multigrid/envs/blockedunlockpickup.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from multigrid.core.constants import Color, Direction, Type
4 | from multigrid.core.mission import MissionSpace
5 | from multigrid.core.roomgrid import RoomGrid
6 | from multigrid.core.world_object import Ball
7 |
8 |
9 |
10 | class BlockedUnlockPickupEnv(RoomGrid):
11 | """
12 | .. image:: https://i.imgur.com/uSFi059.gif
13 | :width: 275
14 |
15 | ***********
16 | Description
17 | ***********
18 |
19 | The objective is to pick up a box which is placed in another room, behind a
20 | locked door. The door is also blocked by a ball which must be moved before
21 | the door can be unlocked. Hence, agents must learn to move the ball,
22 | pick up the key, open the door and pick up the object in the other
23 | room.
24 |
25 | The standard setting is cooperative, where all agents receive the reward
26 | when the task is completed.
27 |
28 | *************
29 | Mission Space
30 | *************
31 |
32 | "pick up the ``{color}`` box"
33 |
34 | ``{color}`` is the color of the box. Can be any :class:`.Color`.
35 |
36 | *****************
37 | Observation Space
38 | *****************
39 |
40 | The multi-agent observation space is a Dict mapping from agent index to
41 | corresponding agent observation space.
42 |
43 | Each agent observation is a dictionary with the following entries:
44 |
45 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
46 | Encoding of the agent's partially observable view of the environment,
47 | where the object at each grid cell is encoded as a vector:
48 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
49 | * direction : int
50 | Agent's direction (0: right, 1: down, 2: left, 3: up)
51 | * mission : Mission
52 | Task string corresponding to the current environment configuration
53 |
54 | ************
55 | Action Space
56 | ************
57 |
58 | The multi-agent action space is a Dict mapping from agent index to
59 | corresponding agent action space.
60 |
61 | Agent actions are discrete integer values, given by:
62 |
63 | +-----+--------------+-----------------------------+
64 | | Num | Name | Action |
65 | +=====+==============+=============================+
66 | | 0 | left | Turn left |
67 | +-----+--------------+-----------------------------+
68 | | 1 | right | Turn right |
69 | +-----+--------------+-----------------------------+
70 | | 2 | forward | Move forward |
71 | +-----+--------------+-----------------------------+
72 | | 3 | pickup | Pick up an object |
73 | +-----+--------------+-----------------------------+
74 | | 4 | drop | Drop an object |
75 | +-----+--------------+-----------------------------+
76 | | 5 | toggle | Toggle / activate an object |
77 | +-----+--------------+-----------------------------+
78 | | 6 | done | Done completing task |
79 | +-----+--------------+-----------------------------+
80 |
81 | *******
82 | Rewards
83 | *******
84 |
85 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success,
86 | and ``0`` for failure.
87 |
88 | ***********
89 | Termination
90 | ***********
91 |
92 | The episode ends if any one of the following conditions is met:
93 |
94 | * Any agent picks up the correct box
95 | * Timeout (see ``max_steps``)
96 |
97 | *************************
98 | Registered Configurations
99 | *************************
100 |
101 | * ``MultiGrid-BlockedUnlockPickup-v0``
102 | """
103 |
104 | def __init__(
105 | self,
106 | room_size: int = 6,
107 | max_steps: int | None = None,
108 | joint_reward: bool = True,
109 | **kwargs):
110 | """
111 | Parameters
112 | ----------
113 | room_size : int, default=6
114 | Width and height for each of the two rooms
115 | max_steps : int, optional
116 | Maximum number of steps per episode
117 | joint_reward : bool, default=True
118 | Whether all agents receive the reward when the task is completed
119 | **kwargs
120 | See :attr:`multigrid.base.MultiGridEnv.__init__`
121 | """
122 | assert room_size >= 4
123 | mission_space = MissionSpace(
124 | mission_func=self._gen_mission,
125 | ordered_placeholders=[list(Color), [Type.box, Type.key]],
126 | )
127 | super().__init__(
128 | mission_space=mission_space,
129 | num_rows=1,
130 | num_cols=2,
131 | room_size=room_size,
132 | max_steps=max_steps or (16 * room_size**2),
133 | joint_reward=joint_reward,
134 | success_termination_mode='any',
135 | **kwargs,
136 | )
137 |
138 | @staticmethod
139 | def _gen_mission(color: str, obj_type: str):
140 | return f"pick up the {color} {obj_type}"
141 |
142 | def _gen_grid(self, width, height):
143 | """
144 | :meta private:
145 | """
146 | super()._gen_grid(width, height)
147 |
148 | # Add a box to the room on the right
149 | self.obj, _ = self.add_object(1, 0, kind=Type.box)
150 |
151 | # Make sure the two rooms are directly connected by a locked door
152 | door, pos = self.add_door(0, 0, Direction.right, locked=True)
153 |
154 | # Block the door with a ball
155 | self.grid.set(pos[0] - 1, pos[1], Ball(color=self._rand_color()))
156 |
157 | # Add a key to unlock the door
158 | self.add_object(0, 0, Type.key, door.color)
159 |
160 | # Place agents in the left room
161 | for agent in self.agents:
162 | self.place_agent(agent, 0, 0)
163 |
164 | self.mission = f"pick up the {self.obj.color} {self.obj.type}"
165 |
166 | def step(self, actions):
167 | """
168 | :meta private:
169 | """
170 | obs, reward, terminated, truncated, info = super().step(actions)
171 | for agent in self.agents:
172 | if agent.state.carrying == self.obj:
173 | self.on_success(agent, reward, terminated)
174 |
175 | return obs, reward, terminated, truncated, info
176 |
--------------------------------------------------------------------------------
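
A quick sketch of instantiating the registered configuration; per the observation space documented above, each agent's observation carries the mission string:

```python
import gymnasium as gym
import multigrid.envs  # registers the MultiGrid-* configurations

env = gym.make('MultiGrid-BlockedUnlockPickup-v0', agents=2)
observations, infos = env.reset()
print(observations[0]['mission'])  # e.g. "pick up the blue box"
```
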
/multigrid/utils/minigrid_interface.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gymnasium import spaces
4 | from gymnasium.core import ActType, ObsType
5 | from typing import Any, Sequence, SupportsFloat
6 |
7 | from ..core.world_object import WorldObj
8 | from ..base import MultiGridEnv
9 |
10 |
11 |
12 | class MiniGridInterface(MultiGridEnv):
13 | """
14 | MultiGridEnv interface for compatibility with single-agent MiniGrid environments.
15 |
16 | Most environment implementations deriving from `minigrid.MiniGridEnv` can be
17 | converted to a single-agent `MultiGridEnv` by simply inheriting from
18 | `MiniGridInterface` instead (along with using the multigrid `Grid` and world object classes).
19 |
20 | Examples
21 | --------
22 | Start with a single-agent minigrid environment:
23 |
24 | >>> from minigrid.core.world_object import Ball, Key, Door
25 | >>> from minigrid.core.grid import Grid
26 | >>> from minigrid import MiniGridEnv
27 |
28 | >>> class MyEnv(MiniGridEnv):
29 | >>> ... # existing class definition
30 |
31 | Now use multigrid imports, keeping the environment class definition the same:
32 |
33 | >>> from multigrid.core.world_object import Ball, Key, Door
34 | >>> from multigrid.core.grid import Grid
35 | >>> from multigrid.utils.minigrid_interface import MiniGridInterface as MiniGridEnv
36 |
37 | >>> class MyEnv(MiniGridEnv):
38 | >>> ... # same class definition
39 | """
40 |
41 | def reset(self, *args, **kwargs) -> tuple[ObsType, dict[str, Any]]:
42 | """
43 | Reset the environment.
44 | """
45 | result = super().reset(*args, **kwargs)
46 | return tuple(item[0] for item in result)
47 |
48 | def step(self, action: ActType) -> tuple[
49 | ObsType,
50 | SupportsFloat,
51 | bool,
52 | bool,
53 | dict[str, Any]]:
54 | """
55 | Run one timestep of the environment’s dynamics
56 | using the provided agent action.
57 | """
58 | result = super().step({0: action})
59 | return tuple(item[0] for item in result)
60 |
61 | @property
62 | def action_space(self) -> spaces.Space:
63 | """
64 | Get action space.
65 | """
66 | assert len(self.agents) == 1, (
67 | "This property is not supported for multi-agent envs. "
68 | "Use `env.agents[i].action_space` instead."
69 | )
70 | return self.agents[0].action_space
71 |
72 | @action_space.setter
73 | def action_space(self, space: spaces.Space):
74 | """
75 | Set action space.
76 | """
77 | assert len(self.agents) == 1, (
78 | "This property is not supported for multi-agent envs. "
79 | "Use `env.agents[i].action_space` instead."
80 | )
81 | self.agents[0].action_space = space
82 |
83 | @property
84 | def observation_space(self) -> spaces.Space:
85 | """
86 | Get observation space.
87 | """
88 | assert len(self.agents) == 1, (
89 | "This property is not supported for multi-agent envs. "
90 | "Use `env.agents[i].observation_space` instead."
91 | )
92 | return self.agents[0].observation_space
93 |
94 | @observation_space.setter
95 | def observation_space(self, space: spaces.Space):
96 | """
97 | Set observation space.
98 | """
99 | assert len(self.agents) == 1, (
100 | "This property is not supported for multi-agent envs. "
101 | "Use `env.agents[i].observation_space` instead."
102 | )
103 | self.agents[0].observation_space = space
104 |
105 | @property
106 | def agent_pos(self) -> np.ndarray[int]:
107 | """
108 | Get agent position.
109 | """
110 | assert len(self.agents) == 1, (
111 | "This property is not supported for multi-agent envs. "
112 | "Use `env.agents[i].pos` instead."
113 | )
114 | return self.agents[0].pos
115 |
116 | @agent_pos.setter
117 | def agent_pos(self, value: Sequence[int]):
118 | """
119 | Set agent position.
120 | """
121 | assert len(self.agents) == 1, (
122 | "This property is not supported for multi-agent envs. "
123 | "Use `env.agents[i].pos` instead."
124 | )
125 | if value is not None:
126 | self.agents[0].pos = value
127 |
128 | @property
129 | def agent_dir(self) -> int:
130 | """
131 | Get agent direction.
132 | """
133 | assert len(self.agents) == 1, (
134 | "This property is not supported for multi-agent envs. "
135 | "Use `env.agents[i].dir` instead."
136 | )
137 | return self.agents[0].dir
138 |
139 | @agent_dir.setter
140 | def agent_dir(self, value: Sequence[int]):
141 | """
142 | Set agent direction.
143 | """
144 | assert len(self.agents) == 1, (
145 | "This property is not supported for multi-agent envs. "
146 | "Use `env.agents[i].dir` instead."
147 | )
148 | self.agents[0].dir = value
149 |
150 | @property
151 | def carrying(self) -> WorldObj:
152 | """
153 | Get object carried by agent.
154 | """
155 | assert len(self.agents) == 1, (
156 | "This property is not supported for multi-agent envs. "
157 | "Use `env.agents[i].carrying` instead."
158 | )
159 | return self.agents[0].carrying
160 |
161 | @property
162 | def dir_vec(self):
163 | """
164 | Get the direction vector for the agent, pointing in the direction
165 | of forward movement.
166 | """
167 | assert len(self.agents) == 1, (
168 | "This property is not supported for multi-agent envs. "
169 | "Use `env.agents[i].dir.to_vec()` instead."
170 | )
171 | return self.agents[0].dir.to_vec()
172 |
173 | @property
174 | def front_pos(self):
175 | """
176 | Get the position of the cell that is right in front of the agent.
177 | """
178 | assert len(self.agents) == 1, (
179 | "This property is not supported for multi-agent envs. "
180 | "Use `env.agents[i].front_pos` instead."
181 | )
182 | return self.agents[0].front_pos
183 |
184 | def place_agent(self, *args, **kwargs) -> tuple[int, int]:
185 | """
186 | Set agent starting point at an empty position in the grid.
187 | """
188 | return super().place_agent(self.agents[0], *args, **kwargs)
189 |
--------------------------------------------------------------------------------
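The single-agent aliases above let a one-agent MultiGridEnv be used like a standard MiniGrid environment, while multi-agent users are pointed to the per-agent attributes. A minimal sketch of both access patterns (assuming the 'agents' keyword that scripts/train.py passes through env_config):

    import gymnasium as gym
    import multigrid.envs

    # With exactly one agent, the aliases defer to agents[0]
    env = gym.make('MultiGrid-Empty-8x8-v0', agents=1)
    env.reset(seed=0)
    print(env.unwrapped.agent_pos, env.unwrapped.agent_dir)

    # With multiple agents, the aliases raise AssertionError;
    # use the per-agent attributes instead
    env = gym.make('MultiGrid-Empty-8x8-v0', agents=2)
    env.reset(seed=0)
    print([tuple(agent.pos) for agent in env.unwrapped.agents])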
/multigrid/envs/empty.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from multigrid import MultiGridEnv
4 | from multigrid.core import Grid
5 | from multigrid.core.constants import Direction
6 | from multigrid.core.world_object import Goal
7 |
8 |
9 |
10 | class EmptyEnv(MultiGridEnv):
11 | """
12 | .. image:: https://i.imgur.com/wY0tT7R.gif
13 | :width: 200
14 |
15 | ***********
16 | Description
17 | ***********
18 |
19 | This environment is an empty room, and the goal for each agent is to reach the
20 | green goal square, which provides a sparse reward. A small penalty is subtracted
21 | for the number of steps to reach the goal.
22 |
23 | The standard setting is competitive, where agents race to the goal, and
24 | only the winner receives a reward.
25 |
26 | This environment is useful with small rooms, to validate that your RL algorithm
27 | works correctly, and with large rooms to experiment with sparse rewards and
28 | exploration. The random variants of the environment have the agents starting
29 | at a random position for each episode, while the regular variants have the
30 |     agents always starting in the corner opposite to the goal.
31 |
32 | *************
33 | Mission Space
34 | *************
35 |
36 | "get to the green goal square"
37 |
38 | *****************
39 | Observation Space
40 | *****************
41 |
42 | The multi-agent observation space is a Dict mapping from agent index to
43 | corresponding agent observation space.
44 |
45 | Each agent observation is a dictionary with the following entries:
46 |
47 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
48 | Encoding of the agent's partially observable view of the environment,
49 | where the object at each grid cell is encoded as a vector:
50 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
51 | * direction : int
52 | Agent's direction (0: right, 1: down, 2: left, 3: up)
53 | * mission : Mission
54 | Task string corresponding to the current environment configuration
55 |
56 | ************
57 | Action Space
58 | ************
59 |
60 | The multi-agent action space is a Dict mapping from agent index to
61 | corresponding agent action space.
62 |
63 | Agent actions are discrete integer values, given by:
64 |
65 | +-----+--------------+-----------------------------+
66 | | Num | Name | Action |
67 | +=====+==============+=============================+
68 | | 0 | left | Turn left |
69 | +-----+--------------+-----------------------------+
70 | | 1 | right | Turn right |
71 | +-----+--------------+-----------------------------+
72 | | 2 | forward | Move forward |
73 | +-----+--------------+-----------------------------+
74 | | 3 | pickup | Pick up an object |
75 | +-----+--------------+-----------------------------+
76 | | 4 | drop | Drop an object |
77 | +-----+--------------+-----------------------------+
78 | | 5 | toggle | Toggle / activate an object |
79 | +-----+--------------+-----------------------------+
80 | | 6 | done | Done completing task |
81 | +-----+--------------+-----------------------------+
82 |
83 | *******
84 | Rewards
85 | *******
86 |
87 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success,
88 | and ``0`` for failure.
89 |
90 | ***********
91 | Termination
92 | ***********
93 |
94 | The episode ends if any one of the following conditions is met:
95 |
96 | * Any agent reaches the goal
97 | * Timeout (see ``max_steps``)
98 |
99 | *************************
100 | Registered Configurations
101 | *************************
102 |
103 | * ``MultiGrid-Empty-5x5-v0``
104 | * ``MultiGrid-Empty-Random-5x5-v0``
105 | * ``MultiGrid-Empty-6x6-v0``
106 | * ``MultiGrid-Empty-Random-6x6-v0``
107 | * ``MultiGrid-Empty-8x8-v0``
108 | * ``MultiGrid-Empty-16x16-v0``
109 | """
110 |
111 | def __init__(
112 | self,
113 | size: int = 8,
114 | agent_start_pos: tuple[int, int] | None = (1, 1),
115 | agent_start_dir: Direction | None = Direction.right,
116 | max_steps: int | None = None,
117 | joint_reward: bool = False,
118 | success_termination_mode: str = 'any',
119 | **kwargs):
120 | """
121 | Parameters
122 | ----------
123 | size : int, default=8
124 | Width and height of the grid
125 | agent_start_pos : tuple[int, int], default=(1, 1)
126 | Starting position of the agents (random if None)
127 | agent_start_dir : Direction, default=Direction.right
128 | Starting direction of the agents (random if None)
129 | max_steps : int, optional
130 | Maximum number of steps per episode
131 |         joint_reward : bool, default=False
132 | Whether all agents receive the reward when the task is completed
133 | success_termination_mode : 'any' or 'all', default='any'
134 | Whether to terminate the environment when any agent reaches the goal
135 | or after all agents reach the goal
136 | **kwargs
137 | See :attr:`multigrid.base.MultiGridEnv.__init__`
138 | """
139 | self.agent_start_pos = agent_start_pos
140 | self.agent_start_dir = agent_start_dir
141 |
142 | super().__init__(
143 | mission_space="get to the green goal square",
144 | grid_size=size,
145 | max_steps=max_steps or (4 * size**2),
146 | joint_reward=joint_reward,
147 | success_termination_mode=success_termination_mode,
148 | **kwargs,
149 | )
150 |
151 | def _gen_grid(self, width, height):
152 | """
153 | :meta private:
154 | """
155 | # Create an empty grid
156 | self.grid = Grid(width, height)
157 |
158 | # Generate the surrounding walls
159 | self.grid.wall_rect(0, 0, width, height)
160 |
161 | # Place a goal square in the bottom-right corner
162 | self.put_obj(Goal(), width - 2, height - 2)
163 |
164 |         # Place the agents
165 | for agent in self.agents:
166 | if self.agent_start_pos is not None and self.agent_start_dir is not None:
167 | agent.state.pos = self.agent_start_pos
168 | agent.state.dir = self.agent_start_dir
169 | else:
170 | self.place_agent(agent)
171 |
--------------------------------------------------------------------------------
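To make the sparse reward concrete, here is the arithmetic for the default 8x8 configuration; the numbers follow directly from the reward formula in the docstring and the max_steps default in __init__:

    # Success reward: 1 - 0.9 * (step_count / max_steps)
    size = 8
    max_steps = 4 * size**2                   # 256, the default set in __init__
    for step_count in (16, 128, 256):
        reward = 1 - 0.9 * (step_count / max_steps)
        print(step_count, reward)             # 0.94375, 0.55, 0.1 (up to float rounding)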
/multigrid/envs/redbluedoors.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from multigrid import MultiGridEnv
4 | from multigrid.core import Action, Grid, MissionSpace
5 | from multigrid.core.constants import Color
6 | from multigrid.core.world_object import Door
7 |
8 |
9 |
10 | class RedBlueDoorsEnv(MultiGridEnv):
11 | """
12 | .. image:: https://i.imgur.com/usbavAh.gif
13 | :width: 400
14 |
15 | ***********
16 | Description
17 | ***********
18 |
19 | This environment is a room with one red and one blue door facing
20 | opposite directions. Agents must open the red door and then open the blue door,
21 | in that order.
22 |
23 | The standard setting is cooperative, where all agents receive the reward
24 | upon completion of the task.
25 |
26 | *************
27 | Mission Space
28 | *************
29 |
30 | "open the red door then the blue door"
31 |
32 | *****************
33 | Observation Space
34 | *****************
35 |
36 | The multi-agent observation space is a Dict mapping from agent index to
37 | corresponding agent observation space.
38 |
39 | Each agent observation is a dictionary with the following entries:
40 |
41 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
42 | Encoding of the agent's partially observable view of the environment,
43 | where the object at each grid cell is encoded as a vector:
44 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
45 | * direction : int
46 | Agent's direction (0: right, 1: down, 2: left, 3: up)
47 | * mission : Mission
48 | Task string corresponding to the current environment configuration
49 |
50 | ************
51 | Action Space
52 | ************
53 |
54 | The multi-agent action space is a Dict mapping from agent index to
55 | corresponding agent action space.
56 |
57 | Agent actions are discrete integer values, given by:
58 |
59 | +-----+--------------+-----------------------------+
60 | | Num | Name | Action |
61 | +=====+==============+=============================+
62 | | 0 | left | Turn left |
63 | +-----+--------------+-----------------------------+
64 | | 1 | right | Turn right |
65 | +-----+--------------+-----------------------------+
66 | | 2 | forward | Move forward |
67 | +-----+--------------+-----------------------------+
68 | | 3 | pickup | Pick up an object |
69 | +-----+--------------+-----------------------------+
70 | | 4 | drop | Drop an object |
71 | +-----+--------------+-----------------------------+
72 | | 5 | toggle | Toggle / activate an object |
73 | +-----+--------------+-----------------------------+
74 | | 6 | done | Done completing task |
75 | +-----+--------------+-----------------------------+
76 |
77 | *******
78 | Rewards
79 | *******
80 |
81 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success,
82 | and ``0`` for failure.
83 |
84 | ***********
85 | Termination
86 | ***********
87 |
88 | The episode ends if any one of the following conditions is met:
89 |
90 | * Any agent opens the blue door while the red door is open (success)
91 | * Any agent opens the blue door while the red door is not open (failure)
92 | * Timeout (see ``max_steps``)
93 |
94 | *************************
95 | Registered Configurations
96 | *************************
97 |
98 | * ``MultiGrid-RedBlueDoors-6x6-v0``
99 | * ``MultiGrid-RedBlueDoors-8x8-v0``
100 | """
101 |
102 | def __init__(
103 | self,
104 | size: int = 8,
105 | max_steps: int | None = None,
106 | joint_reward: bool = True,
107 | success_termination_mode: str = 'any',
108 | failure_termination_mode: str = 'any',
109 | **kwargs):
110 | """
111 | Parameters
112 | ----------
113 | size : int, default=8
114 | Width and height of the grid
115 | max_steps : int, optional
116 | Maximum number of steps per episode
117 | joint_reward : bool, default=True
118 | Whether all agents receive the reward when the task is completed
119 | success_termination_mode : 'any' or 'all', default='any'
120 |             Whether to terminate the environment when any agent completes the task
121 |             or after all agents complete the task
122 | failure_termination_mode : 'any' or 'all', default='any'
123 | Whether to terminate the environment when any agent fails the task
124 | or after all agents fail the task
125 | **kwargs
126 | See :attr:`multigrid.base.MultiGridEnv.__init__`
127 | """
128 | self.size = size
129 | mission_space = MissionSpace.from_string("open the red door then the blue door")
130 | super().__init__(
131 | mission_space=mission_space,
132 | width=(2 * size),
133 | height=size,
134 | max_steps=max_steps or (20 * size**2),
135 | joint_reward=joint_reward,
136 | success_termination_mode=success_termination_mode,
137 | failure_termination_mode=failure_termination_mode,
138 | **kwargs,
139 | )
140 |
141 | def _gen_grid(self, width, height):
142 | """
143 | :meta private:
144 | """
145 | # Create an empty grid
146 | self.grid = Grid(width, height)
147 |
148 | # Generate the grid walls
149 | room_top = (width // 4, 0)
150 | room_size = (width // 2, height)
151 | self.grid.wall_rect(0, 0, width, height)
152 | self.grid.wall_rect(*room_top, *room_size)
153 |
154 |         # Place agents at random positions in the middle room
155 | for agent in self.agents:
156 | self.place_agent(agent, top=room_top, size=room_size)
157 |
158 | # Add a red door at a random position in the left wall
159 | x = room_top[0]
160 | y = self._rand_int(1, height - 1)
161 | self.red_door = Door(Color.red)
162 | self.grid.set(x, y, self.red_door)
163 |
164 | # Add a blue door at a random position in the right wall
165 | x = room_top[0] + room_size[0] - 1
166 | y = self._rand_int(1, height - 1)
167 | self.blue_door = Door(Color.blue)
168 | self.grid.set(x, y, self.blue_door)
169 |
170 | def step(self, actions):
171 | """
172 | :meta private:
173 | """
174 | obs, reward, terminated, truncated, info = super().step(actions)
175 |
176 | for agent_id, action in actions.items():
177 | if action == Action.toggle:
178 | agent = self.agents[agent_id]
179 | fwd_obj = self.grid.get(*agent.front_pos)
180 | if fwd_obj == self.blue_door and self.blue_door.is_open:
181 | if self.red_door.is_open:
182 | self.on_success(agent, reward, terminated)
183 | else:
184 | self.on_failure(agent, reward, terminated)
185 | self.blue_door.is_open = False # close the door again
186 |
187 | return obs, reward, terminated, truncated, info
188 |
--------------------------------------------------------------------------------
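The success/failure branch in step() above reduces to a single decision: toggling the blue door open succeeds only if the red door was opened first, and otherwise fails (with the blue door closed again). A distilled sketch, using a hypothetical helper rather than package code:

    def blue_door_outcome(red_door_open: bool) -> str:
        # Mirrors RedBlueDoorsEnv.step() when an agent toggles the blue door open
        return 'success' if red_door_open else 'failure'

    assert blue_door_outcome(red_door_open=True) == 'success'
    assert blue_door_outcome(red_door_open=False) == 'failure'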
/multigrid/wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gymnasium as gym
4 | import numba as nb
5 | import numpy as np
6 |
7 | from gymnasium import spaces
8 | from gymnasium.core import ObservationWrapper
9 | from numpy.typing import NDArray as ndarray
10 |
11 | from .base import MultiGridEnv, AgentID, ObsType
12 | from .core.constants import Color, Direction, State, Type
13 | from .core.world_object import WorldObj
14 |
15 |
16 |
17 | class FullyObsWrapper(ObservationWrapper):
18 | """
19 | Fully observable gridworld using a compact grid encoding instead of agent view.
20 |
21 | Examples
22 | --------
23 | >>> import gymnasium as gym
24 | >>> import multigrid.envs
25 | >>> env = gym.make('MultiGrid-Empty-16x16-v0')
26 | >>> obs, _ = env.reset()
27 | >>> obs[0]['image'].shape
28 | (7, 7, 3)
29 |
30 | >>> from multigrid.wrappers import FullyObsWrapper
31 | >>> env = FullyObsWrapper(env)
32 | >>> obs, _ = env.reset()
33 | >>> obs[0]['image'].shape
34 | (16, 16, 3)
35 | """
36 |
37 | def __init__(self, env: MultiGridEnv):
38 | """
39 | """
40 | super().__init__(env)
41 |
42 | # Update agent observation spaces
43 | for agent in self.env.agents:
44 | agent.observation_space['image'] = spaces.Box(
45 | low=0, high=255, shape=(env.height, env.width, WorldObj.dim), dtype=int)
46 |
47 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]:
48 | """
49 | :meta private:
50 | """
51 | img = self.env.grid.encode()
52 | for agent in self.env.agents:
53 |             img[tuple(agent.state.pos)] = agent.encode()  # tuple-index a single cell
54 |
55 | for agent_id in obs:
56 | obs[agent_id]['image'] = img
57 |
58 | return obs
59 |
60 |
61 | class ImgObsWrapper(ObservationWrapper):
62 | """
63 | Use the image as the only observation output for each agent.
64 |
65 | Examples
66 | --------
67 | >>> import gymnasium as gym
68 | >>> import multigrid.envs
69 | >>> env = gym.make('MultiGrid-Empty-8x8-v0')
70 | >>> obs, _ = env.reset()
71 | >>> obs[0].keys()
72 | dict_keys(['image', 'direction', 'mission'])
73 |
74 | >>> from multigrid.wrappers import ImgObsWrapper
75 | >>> env = ImgObsWrapper(env)
76 | >>> obs, _ = env.reset()
77 | >>> obs.shape
78 | (7, 7, 3)
79 | """
80 |
81 | def __init__(self, env: MultiGridEnv):
82 | """
83 | """
84 | super().__init__(env)
85 |
86 | # Update agent observation spaces
87 | for agent in self.env.agents:
88 | agent.observation_space = agent.observation_space['image']
89 | agent.observation_space.dtype = np.uint8
90 |
91 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]:
92 | """
93 | :meta private:
94 | """
95 | for agent_id in obs:
96 | obs[agent_id] = obs[agent_id]['image'].astype(np.uint8)
97 |
98 | return obs
99 |
100 |
101 | class OneHotObsWrapper(ObservationWrapper):
102 | """
103 | Wrapper to get a one-hot encoding of a partially observable
104 | agent view as observation.
105 |
106 | Examples
107 | --------
108 | >>> import gymnasium as gym
109 | >>> import multigrid.envs
110 | >>> env = gym.make('MultiGrid-Empty-5x5-v0')
111 | >>> obs, _ = env.reset()
112 | >>> obs[0]['image'][0, :, :]
113 | array([[2, 5, 0],
114 | [2, 5, 0],
115 | [2, 5, 0],
116 | [2, 5, 0],
117 | [2, 5, 0],
118 | [2, 5, 0],
119 | [2, 5, 0]])
120 |
121 | >>> from multigrid.wrappers import OneHotObsWrapper
122 | >>> env = OneHotObsWrapper(env)
123 | >>> obs, _ = env.reset()
124 | >>> obs[0]['image'][0, :, :]
125 | array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
126 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
127 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
128 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
129 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
130 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
131 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
132 | dtype=uint8)
133 | """
134 |
135 | def __init__(self, env: MultiGridEnv):
136 | """
137 | """
138 | super().__init__(env)
139 | self.dim_sizes = np.array([
140 | len(Type), len(Color), max(len(State), len(Direction))])
141 |
142 | # Update agent observation spaces
143 | dim = sum(self.dim_sizes)
144 | for agent in self.env.agents:
145 | view_height, view_width, _ = agent.observation_space['image'].shape
146 | agent.observation_space['image'] = spaces.Box(
147 | low=0, high=1, shape=(view_height, view_width, dim), dtype=np.uint8)
148 |
149 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]:
150 | """
151 | :meta private:
152 | """
153 | for agent_id in obs:
154 | obs[agent_id]['image'] = self.one_hot(obs[agent_id]['image'], self.dim_sizes)
155 |
156 | return obs
157 |
158 | @staticmethod
159 | @nb.njit(cache=True)
160 |     def one_hot(x: ndarray[np.int_], dim_sizes: ndarray[np.int_]) -> ndarray[np.uint8]:
161 | """
162 | Return a one-hot encoding of a 3D integer array,
163 | where each 2D slice is encoded separately.
164 |
165 | Parameters
166 | ----------
167 | x : ndarray[int] of shape (view_height, view_width, dim)
168 | 3D array of integers to be one-hot encoded
169 | dim_sizes : ndarray[int] of shape (dim,)
170 | Number of possible values for each dimension
171 |
172 | Returns
173 | -------
174 | out : ndarray[uint8] of shape (view_height, view_width, sum(dim_sizes))
175 | One-hot encoding
176 |
177 | :meta private:
178 | """
179 | out = np.zeros((x.shape[0], x.shape[1], sum(dim_sizes)), dtype=np.uint8)
180 |
181 | dim_offset = 0
182 | for d in range(len(dim_sizes)):
183 | for i in range(x.shape[0]):
184 | for j in range(x.shape[1]):
185 | k = dim_offset + x[i, j, d]
186 | out[i, j, k] = 1
187 |
188 | dim_offset += dim_sizes[d]
189 |
190 | return out
191 |
192 |
193 | class SingleAgentWrapper(gym.Wrapper):
194 | """
195 | Wrapper to convert a multi-agent environment into a
196 | single-agent environment.
197 |
198 | Examples
199 | --------
200 | >>> import gymnasium as gym
201 | >>> import multigrid.envs
202 | >>> env = gym.make('MultiGrid-Empty-5x5-v0')
203 | >>> obs, _ = env.reset()
204 | >>> obs[0].keys()
205 | dict_keys(['image', 'direction', 'mission'])
206 |
207 | >>> from multigrid.wrappers import SingleAgentWrapper
208 | >>> env = SingleAgentWrapper(env)
209 | >>> obs, _ = env.reset()
210 | >>> obs.keys()
211 | dict_keys(['image', 'direction', 'mission'])
212 | """
213 |
214 | def __init__(self, env: MultiGridEnv):
215 | """
216 | """
217 | super().__init__(env)
218 | self.observation_space = env.agents[0].observation_space
219 | self.action_space = env.agents[0].action_space
220 |
221 | def reset(self, *args, **kwargs):
222 | """
223 | :meta private:
224 | """
225 | result = super().reset(*args, **kwargs)
226 | return tuple(item[0] for item in result)
227 |
228 | def step(self, action):
229 | """
230 | :meta private:
231 | """
232 | result = super().step({0: action})
233 | return tuple(item[0] for item in result)
234 |
--------------------------------------------------------------------------------
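The wrappers above compose in the usual Gymnasium fashion. A sketch of one common stack, fully observable / image-only / single-agent, assuming the 'agents' keyword and the observation shapes shown in the docstrings:

    import gymnasium as gym
    import multigrid.envs
    from multigrid.wrappers import FullyObsWrapper, ImgObsWrapper, SingleAgentWrapper

    env = gym.make('MultiGrid-Empty-8x8-v0', agents=1)
    env = FullyObsWrapper(env)        # image covers the whole 8x8 grid
    env = ImgObsWrapper(env)          # observation becomes a bare uint8 array
    env = SingleAgentWrapper(env)     # drop the {agent_id: ...} dict layer
    obs, _ = env.reset(seed=0)
    print(obs.shape)                  # expected: (8, 8, 3)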
/multigrid/utils/rendering.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | import numpy as np
5 | from numpy.typing import NDArray as ndarray
6 | from typing import Callable
7 |
8 |
9 |
10 | # Constants
11 |
12 | FilterFunction = Callable[[float, float], bool]
13 | White = np.array([255, 255, 255])
14 |
15 |
16 |
17 | # Functions
18 |
19 | def downsample(img: ndarray[np.uint8], factor: int) -> ndarray[np.uint8]:
20 | """
21 | Downsample an image along both dimensions by some factor.
22 |
23 | Parameters
24 | ----------
25 | img : ndarray[uint8] of shape (height, width, 3)
26 | The image to downsample
27 | factor : int
28 | The factor by which to downsample the image
29 |
30 | Returns
31 | -------
32 | img : ndarray[uint8] of shape (height/factor, width/factor, 3)
33 | The downsampled image
34 | """
35 | assert img.shape[0] % factor == 0
36 | assert img.shape[1] % factor == 0
37 |
38 | img = img.reshape(
39 | [img.shape[0] // factor, factor, img.shape[1] // factor, factor, 3]
40 | )
41 | img = img.mean(axis=3)
42 |     img = img.mean(axis=1).astype(np.uint8)  # match the documented uint8 return type
43 |
44 | return img
45 |
46 | def fill_coords(
47 | img: ndarray[np.uint8],
48 | fn: FilterFunction,
49 | color: ndarray[np.uint8]) -> ndarray[np.uint8]:
50 | """
51 | Fill pixels of an image with coordinates matching a filter function.
52 |
53 | Parameters
54 | ----------
55 | img : ndarray[uint8] of shape (height, width, 3)
56 | The image to fill
57 | fn : Callable(float, float) -> bool
58 | The filter function to use for coordinates
59 | color : ndarray[uint8] of shape (3,)
60 | RGB color to fill matching coordinates
61 |
62 | Returns
63 | -------
64 | img : ndarray[np.uint8] of shape (height, width, 3)
65 | The updated image
66 | """
67 | for y in range(img.shape[0]):
68 | for x in range(img.shape[1]):
69 | yf = (y + 0.5) / img.shape[0]
70 | xf = (x + 0.5) / img.shape[1]
71 | if fn(xf, yf):
72 | img[y, x] = color
73 |
74 | return img
75 |
76 | def rotate_fn(fin: FilterFunction, cx: float, cy: float, theta: float) -> FilterFunction:
77 | """
78 | Rotate a coordinate filter function around a center point by some angle.
79 |
80 | Parameters
81 | ----------
82 | fin : Callable(float, float) -> bool
83 | The filter function to rotate
84 | cx : float
85 | The x-coordinate of the center of rotation
86 | cy : float
87 | The y-coordinate of the center of rotation
88 | theta : float
89 | The angle by which to rotate the filter function (in radians)
90 |
91 | Returns
92 | -------
93 | fout : Callable(float, float) -> bool
94 | The rotated filter function
95 | """
96 | def fout(x, y):
97 | x = x - cx
98 | y = y - cy
99 |
100 | x2 = cx + x * math.cos(-theta) - y * math.sin(-theta)
101 | y2 = cy + y * math.cos(-theta) + x * math.sin(-theta)
102 |
103 | return fin(x2, y2)
104 |
105 | return fout
106 |
107 | def point_in_line(
108 | x0: float, y0: float, x1: float, y1: float, r: float) -> FilterFunction:
109 | """
110 | Return a filter function that returns True for points within distance r
111 | from the line between (x0, y0) and (x1, y1).
112 |
113 | Parameters
114 | ----------
115 | x0 : float
116 | The x-coordinate of the line start point
117 | y0 : float
118 | The y-coordinate of the line start point
119 | x1 : float
120 | The x-coordinate of the line end point
121 | y1 : float
122 | The y-coordinate of the line end point
123 | r : float
124 | Maximum distance from the line
125 |
126 | Returns
127 | -------
128 | fn : Callable(float, float) -> bool
129 | Filter function
130 | """
131 | p0 = np.array([x0, y0], dtype=np.float32)
132 | p1 = np.array([x1, y1], dtype=np.float32)
133 | dir = p1 - p0
134 | dist = np.linalg.norm(dir)
135 | dir = dir / dist
136 |
137 | xmin = min(x0, x1) - r
138 | xmax = max(x0, x1) + r
139 | ymin = min(y0, y1) - r
140 | ymax = max(y0, y1) + r
141 |
142 | def fn(x, y):
143 | # Fast, early escape test
144 | if x < xmin or x > xmax or y < ymin or y > ymax:
145 | return False
146 |
147 | q = np.array([x, y])
148 | pq = q - p0
149 |
150 | # Closest point on line
151 | a = np.dot(pq, dir)
152 | a = np.clip(a, 0, dist)
153 | p = p0 + a * dir
154 |
155 | dist_to_line = np.linalg.norm(q - p)
156 | return dist_to_line <= r
157 |
158 | return fn
159 |
160 | def point_in_circle(cx: float, cy: float, r: float) -> FilterFunction:
161 | """
162 | Return a filter function that returns True for points within radius r
163 | from a given point.
164 |
165 | Parameters
166 | ----------
167 | cx : float
168 | The x-coordinate of the circle center
169 | cy : float
170 | The y-coordinate of the circle center
171 | r : float
172 | The radius of the circle
173 |
174 | Returns
175 | -------
176 | fn : Callable(float, float) -> bool
177 | Filter function
178 | """
179 | def fn(x, y):
180 | return (x - cx) * (x - cx) + (y - cy) * (y - cy) <= r * r
181 |
182 | return fn
183 |
184 | def point_in_rect(xmin: float, xmax: float, ymin: float, ymax: float) -> FilterFunction:
185 | """
186 | Return a filter function that returns True for points within a rectangle.
187 |
188 | Parameters
189 | ----------
190 | xmin : float
191 | The minimum x-coordinate of the rectangle
192 | xmax : float
193 | The maximum x-coordinate of the rectangle
194 | ymin : float
195 | The minimum y-coordinate of the rectangle
196 | ymax : float
197 | The maximum y-coordinate of the rectangle
198 |
199 | Returns
200 | -------
201 | fn : Callable(float, float) -> bool
202 | Filter function
203 | """
204 | def fn(x, y):
205 | return x >= xmin and x <= xmax and y >= ymin and y <= ymax
206 |
207 | return fn
208 |
209 | def point_in_triangle(
210 | a: tuple[float, float],
211 | b: tuple[float, float],
212 | c: tuple[float, float]) -> FilterFunction:
213 | """
214 | Return a filter function that returns True for points within a triangle.
215 |
216 | Parameters
217 | ----------
218 | a : tuple[float, float]
219 | The first vertex of the triangle
220 | b : tuple[float, float]
221 | The second vertex of the triangle
222 | c : tuple[float, float]
223 | The third vertex of the triangle
224 |
225 | Returns
226 | -------
227 | fn : Callable(float, float) -> bool
228 | Filter function
229 | """
230 | a = np.array(a, dtype=np.float32)
231 | b = np.array(b, dtype=np.float32)
232 | c = np.array(c, dtype=np.float32)
233 |
234 | def fn(x, y):
235 | v0 = c - a
236 | v1 = b - a
237 | v2 = np.array((x, y)) - a
238 |
239 | # Compute dot products
240 | dot00 = np.dot(v0, v0)
241 | dot01 = np.dot(v0, v1)
242 | dot02 = np.dot(v0, v2)
243 | dot11 = np.dot(v1, v1)
244 | dot12 = np.dot(v1, v2)
245 |
246 | # Compute barycentric coordinates
247 | inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)
248 | u = (dot11 * dot02 - dot01 * dot12) * inv_denom
249 | v = (dot00 * dot12 - dot01 * dot02) * inv_denom
250 |
251 | # Check if point is in triangle
252 | return (u >= 0) and (v >= 0) and (u + v) < 1
253 |
254 | return fn
255 |
256 | def highlight_img(
257 | img: ndarray[np.uint8],
258 | color: ndarray[np.uint8] = White,
259 | alpha=0.30) -> ndarray[np.uint8]:
260 | """
261 | Add highlighting to an image.
262 |
263 | Parameters
264 | ----------
265 | img : ndarray[uint8] of shape (height, width, 3)
266 | The image to highlight
267 | color : ndarray[uint8] of shape (3,)
268 | RGB color to use for highlighting
269 | alpha : float
270 | The alpha value to use for blending
271 |
272 | Returns
273 | -------
274 | img : ndarray[uint8] of shape (height, width, 3)
275 | The highlighted image
276 | """
277 |     blend_img = img + alpha * (np.array(color, dtype=int) - img)  # int math avoids uint8 wraparound
278 | blend_img = blend_img.clip(0, 255).astype(np.uint8)
279 | img[:, :, :] = blend_img
280 |     return img
--------------------------------------------------------------------------------
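These filter functions are meant to be composed, which is exactly how Agent.render() in multigrid/core/agent.py draws the agent triangle. A self-contained sketch using only the helpers defined above:

    import numpy as np
    from multigrid.utils.rendering import fill_coords, point_in_triangle, rotate_fn

    img = np.zeros((96, 96, 3), dtype=np.uint8)
    tri = point_in_triangle((0.12, 0.19), (0.87, 0.50), (0.12, 0.81))
    tri = rotate_fn(tri, cx=0.5, cy=0.5, theta=0.5 * np.pi)   # rotate 90 degrees
    fill_coords(img, tri, np.array([255, 0, 0]))              # fill matching pixels red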
/multigrid/envs/locked_hallway.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from math import ceil
4 | from multigrid import MultiGridEnv
5 | from multigrid.core.actions import Action
6 | from multigrid.core.constants import Color, Direction
7 | from multigrid.core.mission import MissionSpace
8 | from multigrid.core.roomgrid import Room, RoomGrid
9 | from multigrid.core.world_object import Door, Key
10 |
11 |
12 |
13 | class LockedHallwayEnv(RoomGrid):
14 | """
15 | .. image:: https://i.imgur.com/VylPtnn.gif
16 | :width: 325
17 |
18 | ***********
19 | Description
20 | ***********
21 |
22 | This environment consists of a hallway with multiple locked rooms on either side.
23 | To unlock each door, agents must first find the corresponding key,
24 | which may be in another locked room. Agents are rewarded for each door they unlock.
25 |
26 | The standard setting is cooperative, where all agents receive a reward
27 | for each door that is opened.
28 |
29 | *************
30 | Mission Space
31 | *************
32 |
33 | "unlock all the doors"
34 |
35 | *****************
36 | Observation Space
37 | *****************
38 |
39 | The multi-agent observation space is a Dict mapping from agent index to
40 | corresponding agent observation space.
41 |
42 | Each agent observation is a dictionary with the following entries:
43 |
44 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
45 | Encoding of the agent's partially observable view of the environment,
46 | where each grid cell is encoded as a 3 dimensional tuple:
47 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
48 | * direction : int
49 | Agent's direction (0: right, 1: down, 2: left, 3: up)
50 | * mission : Mission
51 | Task string corresponding to the current environment configuration
52 |
53 | ************
54 | Action Space
55 | ************
56 |
57 | The multi-agent action space is a Dict mapping from agent index to
58 | corresponding agent action space.
59 |
60 | Agent actions are discrete integer values, given by:
61 |
62 | +-----+--------------+-----------------------------+
63 | | Num | Name | Action |
64 | +=====+==============+=============================+
65 | | 0 | left | Turn left |
66 | +-----+--------------+-----------------------------+
67 | | 1 | right | Turn right |
68 | +-----+--------------+-----------------------------+
69 | | 2 | forward | Move forward |
70 | +-----+--------------+-----------------------------+
71 | | 3 | pickup | Pick up an object |
72 | +-----+--------------+-----------------------------+
73 | | 4 | drop | Drop an object |
74 | +-----+--------------+-----------------------------+
75 | | 5 | toggle | Toggle / activate an object |
76 | +-----+--------------+-----------------------------+
77 | | 6 | done | Done completing task |
78 | +-----+--------------+-----------------------------+
79 |
80 | *******
81 | Rewards
82 | *******
83 |
84 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given
85 | when a door is unlocked.
86 |
87 | ***********
88 | Termination
89 | ***********
90 |
91 | The episode ends if any one of the following conditions is met:
92 |
93 | * All doors are unlocked
94 | * Timeout (see ``max_steps``)
95 |
96 | *************************
97 | Registered Configurations
98 | *************************
99 |
100 | * ``MultiGrid-LockedHallway-2Rooms-v0``
101 | * ``MultiGrid-LockedHallway-4Rooms-v0``
102 | * ``MultiGrid-LockedHallway-6Rooms-v0``
103 | """
104 |
105 | def __init__(
106 | self,
107 | num_rooms: int = 6,
108 | room_size: int = 5,
109 | max_hallway_keys: int = 1,
110 | max_keys_per_room: int = 2,
111 | max_steps: int | None = None,
112 | joint_reward: bool = True,
113 | **kwargs):
114 | """
115 | Parameters
116 | ----------
117 | num_rooms : int, default=6
118 | Number of rooms in the environment
119 | room_size : int, default=5
120 | Width and height for each of the rooms
121 | max_hallway_keys : int, default=1
122 | Maximum number of keys in the hallway
123 | max_keys_per_room : int, default=2
124 | Maximum number of keys in each room
125 | max_steps : int, optional
126 | Maximum number of steps per episode
127 | joint_reward : bool, default=True
128 | Whether all agents receive the same reward
129 | **kwargs
130 | See :attr:`multigrid.base.MultiGridEnv.__init__`
131 | """
132 | assert room_size >= 4
133 | assert num_rooms % 2 == 0
134 |
135 | self.num_rooms = num_rooms
136 | self.max_hallway_keys = max_hallway_keys
137 | self.max_keys_per_room = max_keys_per_room
138 |
139 | if max_steps is None:
140 | max_steps = 8 * num_rooms * room_size**2
141 |
142 | super().__init__(
143 | mission_space=MissionSpace.from_string("unlock all the doors"),
144 | room_size=room_size,
145 | num_rows=(num_rooms // 2),
146 | num_cols=3,
147 | max_steps=max_steps,
148 | joint_reward=joint_reward,
149 | **kwargs,
150 | )
151 |
152 | def _gen_grid(self, width, height):
153 | """
154 | :meta private:
155 | """
156 | super()._gen_grid(width, height)
157 |
158 | LEFT, HALLWAY, RIGHT = range(3) # columns
159 | color_sequence = list(Color) * ceil(self.num_rooms / len(Color))
160 | color_sequence = self._rand_perm(color_sequence)[:self.num_rooms]
161 |
162 | # Create hallway
163 | for row in range(self.num_rows - 1):
164 | self.remove_wall(HALLWAY, row, Direction.down)
165 |
166 | # Add doors
167 | self.rooms: dict[Color, Room] = {}
168 | door_colors = self._rand_perm(color_sequence)
169 | for row in range(self.num_rows):
170 | for col, dir in ((LEFT, Direction.right), (RIGHT, Direction.left)):
171 | color = door_colors.pop()
172 | self.rooms[color] = self.get_room(col, row)
173 | self.add_door(
174 | col, row, dir=dir, color=color, locked=True, rand_pos=False)
175 |
176 | # Place keys in hallway
177 | num_hallway_keys = self._rand_int(1, self.max_hallway_keys + 1)
178 | hallway_top = self.get_room(HALLWAY, 0).top
179 | hallway_size = (self.get_room(HALLWAY, 0).size[0], self.height)
180 | for key_color in color_sequence[:num_hallway_keys]:
181 | self.place_obj(Key(color=key_color), top=hallway_top, size=hallway_size)
182 |
183 | # Place keys in rooms
184 | key_index = num_hallway_keys
185 | while key_index < len(color_sequence):
186 | room = self.rooms[color_sequence[key_index - 1]]
187 | num_room_keys = self._rand_int(1, self.max_keys_per_room + 1)
188 | for key_color in color_sequence[key_index : key_index + num_room_keys]:
189 | self.place_obj(Key(color=key_color), top=room.top, size=room.size)
190 | key_index += 1
191 |
192 | # Place agents in hallway
193 | for agent in self.agents:
194 | MultiGridEnv.place_agent(self, agent, top=hallway_top, size=hallway_size)
195 |
196 | def reset(self, **kwargs):
197 | """
198 | :meta private:
199 | """
200 | self.unlocked_doors = []
201 | return super().reset(**kwargs)
202 |
203 | def step(self, actions):
204 | """
205 | :meta private:
206 | """
207 | observations, rewards, terminations, truncations, infos = super().step(actions)
208 |
209 | # Reward for unlocking doors
210 | for agent_id, action in actions.items():
211 | if action == Action.toggle:
212 | fwd_obj = self.grid.get(*self.agents[agent_id].front_pos)
213 | if isinstance(fwd_obj, Door) and not fwd_obj.is_locked:
214 | if fwd_obj not in self.unlocked_doors:
215 | self.unlocked_doors.append(fwd_obj)
216 | if self.joint_reward:
217 | for k in rewards:
218 | rewards[k] += self._reward()
219 | else:
220 | rewards[agent_id] += self._reward()
221 |
222 | # Check if all doors are unlocked
223 | if len(self.unlocked_doors) == len(self.rooms):
224 | for agent in self.agents:
225 | terminations[agent.index] = True
226 |
227 | return observations, rewards, terminations, truncations, infos
228 |
--------------------------------------------------------------------------------
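A short usage sketch for the class above, instantiating it directly instead of through the registered configurations (this assumes reset() regenerates the grid via _gen_grid, as in the base class):

    from multigrid.envs.locked_hallway import LockedHallwayEnv

    env = LockedHallwayEnv(num_rooms=4, room_size=5, max_hallway_keys=1)
    obs, info = env.reset(seed=0)
    print(len(env.rooms))        # 4 -- one door color per room
    print(env.unlocked_doors)    # [] -- cleared at the start of every episode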
/scripts/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import json
5 | import multigrid.rllib
6 | import os
7 | import random
8 | import ray
9 | import ray.train
10 | import ray.tune
11 | import torch
12 | import torch.nn as nn
13 |
14 | from multigrid.core.constants import Direction
15 | from pathlib import Path
16 | from ray.rllib.algorithms import PPOConfig
17 | from ray.rllib.core.columns import Columns
18 | from ray.rllib.core.rl_module import MultiRLModuleSpec, RLModuleSpec
19 | from ray.rllib.core.rl_module.apis import ValueFunctionAPI
20 | from ray.rllib.core.rl_module.torch import TorchRLModule
21 | from ray.rllib.utils.from_config import NotProvided
22 | from typing import Callable
23 |
24 |
25 |
26 | ### Helper Methods
27 |
28 | def get_policy_mapping_fn(
29 | checkpoint_dir: Path | str | None,
30 | num_agents: int,
31 | ) -> Callable:
32 | try:
33 | policies = sorted([
34 | path for path in (checkpoint_dir / 'policies').iterdir() if path.is_dir()])
35 |
36 | def policy_mapping_fn(agent_id, *args, **kwargs):
37 | return policies[agent_id % len(policies)].name
38 |
39 | print('Loading policies from:', checkpoint_dir)
40 | for agent_id in range(num_agents):
41 | print('Agent ID:', agent_id, 'Policy ID:', policy_mapping_fn(agent_id))
42 |
43 | return policy_mapping_fn
44 |
45 |     except Exception:
46 | return lambda agent_id, *args, **kwargs: f'policy_{agent_id}'
47 |
48 | def find_checkpoint_dir(search_dir: Path | str | None) -> Path | None:
49 | try:
50 |         checkpoints = list(Path(search_dir).expanduser().glob('**/rllib_checkpoint.json'))
51 |         if checkpoints:  # a bare glob() generator would always be truthy
52 |             return sorted(checkpoints, key=os.path.getmtime)[-1].parent
53 |     except Exception:
54 | return None
55 |
56 | def preprocess_batch(batch: dict) -> torch.Tensor:
57 | image = batch['obs']['image']
58 | direction = batch['obs']['direction']
59 | direction = 2 * torch.pi * (direction / len(Direction))
60 | direction = torch.stack([torch.cos(direction), torch.sin(direction)], dim=-1)
61 | direction = direction[..., None, None, :].expand(*image.shape[:-1], 2)
62 | x = torch.cat([image, direction], dim=-1).float()
63 | return x
64 |
65 |
66 |
67 | ### Models
68 |
69 | class MultiGridEncoder(nn.Module):
70 |
71 | def __init__(self, in_channels: int = 23):
72 | super().__init__()
73 | self.model = nn.Sequential(
74 | nn.Conv2d(in_channels, 16, (3, 3)), nn.ReLU(),
75 | nn.Conv2d(16, 32, (3, 3)), nn.ReLU(),
76 | nn.Conv2d(32, 64, (3, 3)), nn.ReLU(),
77 | nn.Flatten(),
78 | )
79 |
80 | def forward(self, x):
81 | x = x[None] if x.ndim == 3 else x # add batch dimension
82 | x = x.permute(0, 3, 1, 2) # channels-first (NHWC -> NCHW)
83 | return self.model(x)
84 |
85 |
86 | class AgentModule(TorchRLModule, ValueFunctionAPI):
87 |
88 | def setup(self):
89 | self.base = nn.Identity()
90 | self.actor = nn.Sequential(
91 | MultiGridEncoder(in_channels=23),
92 | nn.Linear(64, 64), nn.ReLU(),
93 | nn.Linear(64, 7),
94 | )
95 | self.critic = nn.Sequential(
96 | MultiGridEncoder(in_channels=23),
97 | nn.Linear(64, 64), nn.ReLU(),
98 | nn.Linear(64, 1),
99 | )
100 |
101 | def _forward(self, batch, **kwargs):
102 | x = self.base(preprocess_batch(batch))
103 | return {Columns.ACTION_DIST_INPUTS: self.actor(x)}
104 |
105 | def _forward_train(self, batch, **kwargs):
106 | x = self.base(preprocess_batch(batch))
107 | return {
108 | Columns.ACTION_DIST_INPUTS: self.actor(x),
109 | Columns.EMBEDDINGS: self.critic(batch.get('value_inputs', x)),
110 | }
111 |
112 | def compute_values(self, batch, embeddings = None):
113 | if embeddings is None:
114 | x = self.base(preprocess_batch(batch))
115 | embeddings = self.critic(batch.get('value_inputs', x))
116 |
117 | return embeddings.squeeze(-1)
118 |
119 | def get_initial_state(self):
120 | return {}
121 |
122 |
123 |
124 | ### Training
125 |
126 | def get_algorithm_config(
127 | env: str = 'MultiGrid-Empty-8x8-v0',
128 | env_config: dict = {},
129 | num_agents: int = 2,
130 | # lstm: bool = False, TODO: implement LSTM model
131 | num_workers: int = 0,
132 | num_gpus: int = 0,
133 | lr: float = NotProvided,
134 | batch_size: int = NotProvided,
135 | # centralized_critic: bool = False, TODO: implement centralized critic
136 | **kwargs,
137 | ) -> PPOConfig:
138 | """
139 | Return the RL algorithm configuration dictionary.
140 | """
141 | config = PPOConfig()
142 | config = config.api_stack(
143 | enable_env_runner_and_connector_v2=True,
144 | enable_rl_module_and_learner=True,
145 | )
146 | config = config.debugging(seed=random.randint(0, 1000000))
147 | config = config.env_runners(
148 | num_env_runners=num_workers,
149 | num_envs_per_env_runner=1,
150 | num_gpus_per_env_runner=num_gpus if torch.cuda.is_available() else 0,
151 | )
152 | config = config.environment(env=env, env_config={**env_config, 'agents': num_agents})
153 | config = config.framework('torch')
154 | config = config.multi_agent(
155 | policies={f'policy_{i}' for i in range(num_agents)},
156 | policy_mapping_fn=get_policy_mapping_fn(None, num_agents),
157 | policies_to_train=[f'policy_{i}' for i in range(num_agents)],
158 | )
159 | config = config.training(lr=lr, train_batch_size=batch_size)
160 | config = config.rl_module(
161 | rl_module_spec=MultiRLModuleSpec(
162 | rl_module_specs={
163 | f'policy_{i}': RLModuleSpec(module_class=AgentModule)
164 | for i in range(num_agents)
165 | }
166 | )
167 | )
168 |
169 | return config
170 |
171 | def train(
172 | config: PPOConfig,
173 | stop_conditions: dict,
174 | save_dir: str,
175 |     load_dir: str | None = None,
176 | ):
177 | """
178 | Train an RLlib algorithm.
179 | """
180 | checkpoint = find_checkpoint_dir(load_dir)
181 | if checkpoint:
182 | tuner = ray.tune.Tuner.restore(checkpoint)
183 | else:
184 | tuner = ray.tune.Tuner(
185 | config.algo_class,
186 | param_space=config,
187 | run_config=ray.train.RunConfig(
188 | storage_path=save_dir,
189 | stop=stop_conditions,
190 | verbose=1,
191 | checkpoint_config=ray.train.CheckpointConfig(
192 | checkpoint_frequency=20,
193 | checkpoint_at_end=True,
194 | ),
195 | ),
196 | )
197 |
198 | results = tuner.fit()
199 | return results
200 |
201 |
202 |
203 | if __name__ == "__main__":
204 | parser = argparse.ArgumentParser()
205 |
206 | parser.add_argument(
207 | '--algo', type=str, default='PPO',
208 | help="The name of the RLlib-registered algorithm to use.")
209 | parser.add_argument(
210 | '--env', type=str, default='MultiGrid-Empty-8x8-v0',
211 | help="MultiGrid environment to use.")
212 | parser.add_argument(
213 | '--env-config', type=json.loads, default={},
214 | help="Environment config dict, given as a JSON string (e.g. '{\"size\": 8}')")
215 | parser.add_argument(
216 | '--num-agents', type=int, default=2,
217 | help="Number of agents in environment.")
218 | # parser.add_argument(
219 | # '--lstm', action='store_true',
220 | # help="Use LSTM model.")
221 | # parser.add_argument(
222 | # '--centralized-critic', action='store_true',
223 | # help="Use centralized critic for training.")
224 | parser.add_argument(
225 | '--num-workers', type=int, default=8,
226 | help="Number of rollout workers.")
227 | parser.add_argument(
228 | '--num-gpus', type=int, default=1,
229 | help="Number of GPUs to train on.")
230 | parser.add_argument(
231 |         '--num-timesteps', type=int, default=10_000_000,
232 | help="Total number of timesteps to train.")
233 | parser.add_argument(
234 | '--lr', type=float,
235 | help="Learning rate for training.")
236 | parser.add_argument(
237 | '--load-dir', type=str,
238 | help="Checkpoint directory for loading pre-trained policies.")
239 | parser.add_argument(
240 | '--save-dir', type=str, default='~/ray_results/',
241 | help="Directory for saving checkpoints, results, and trained policies.")
242 |
243 | args = parser.parse_args()
244 | config = get_algorithm_config(**vars(args))
245 |
246 | print()
247 | print(f"Running with following CLI options: {args}")
248 | print('\n', '-' * 64, '\n', "Training with following configuration:", '\n', '-' * 64)
249 | print()
250 |
251 | stop_conditions = {
252 | 'learners/__all_modules__/num_env_steps_trained_lifetime': args.num_timesteps}
253 | train(config, stop_conditions, args.save_dir, args.load_dir)
254 |
--------------------------------------------------------------------------------
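As a sanity check on preprocess_batch above: the direction index is mapped to a point on the unit circle and broadcast across the agent view, which is where MultiGridEncoder's 23 input channels come from. A sketch with preprocess_batch in scope; the 21-channel image assumes a one-hot observation encoding:

    import torch

    batch = {'obs': {
        'image': torch.zeros(32, 7, 7, 21),      # (batch, view, view, channels)
        'direction': torch.randint(4, (32,)),    # Direction indices 0..3
    }}
    x = preprocess_batch(batch)                  # defined in scripts/train.py above
    print(x.shape)                               # torch.Size([32, 7, 7, 23])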
/multigrid/core/grid.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 |
5 | from collections import defaultdict
6 | from functools import cached_property
7 | from numpy.typing import NDArray as ndarray
8 | from typing import Any, Callable, Iterable
9 |
10 | from .agent import Agent
11 | from .constants import Type, TILE_PIXELS
12 | from .world_object import Wall, WorldObj
13 |
14 | from ..utils.rendering import (
15 | downsample,
16 | fill_coords,
17 | highlight_img,
18 | point_in_rect,
19 | )
20 |
21 |
22 |
23 | class Grid:
24 | """
25 | Class representing a grid of :class:`.WorldObj` objects.
26 |
27 | Attributes
28 | ----------
29 | width : int
30 | Width of the grid
31 | height : int
32 | Height of the grid
33 | world_objects : dict[tuple[int, int], WorldObj]
34 | Dictionary of world objects in the grid, indexed by (x, y) location
35 | state : ndarray[int] of shape (width, height, WorldObj.dim)
36 | Grid state, where each (x, y) entry is a world object encoding
37 | """
38 |
39 |     # Static cache of pre-rendered tiles
40 | _tile_cache: dict[tuple[Any, ...], Any] = {}
41 |
42 | def __init__(self, width: int, height: int):
43 | """
44 | Parameters
45 | ----------
46 | width : int
47 | Width of the grid
48 | height : int
49 | Height of the grid
50 | """
51 | assert width >= 3
52 | assert height >= 3
53 | self.world_objects: dict[tuple[int, int], WorldObj] = {} # indexed by location
54 |         self.state: ndarray[np.int_] = np.zeros((width, height, WorldObj.dim), dtype=int)
55 | self.state[...] = WorldObj.empty()
56 |
57 | @cached_property
58 | def width(self) -> int:
59 | """
60 | Width of the grid.
61 | """
62 | return self.state.shape[0]
63 |
64 | @cached_property
65 | def height(self) -> int:
66 | """
67 | Height of the grid.
68 | """
69 | return self.state.shape[1]
70 |
71 | @property
72 | def grid(self) -> list[WorldObj | None]:
73 | """
74 | Return a list of all world objects in the grid.
75 | """
76 | return [self.get(i, j) for i in range(self.width) for j in range(self.height)]
77 |
78 | def set(self, x: int, y: int, obj: WorldObj | None):
79 | """
80 | Set a world object at the given coordinates.
81 |
82 | Parameters
83 | ----------
84 | x : int
85 | Grid x-coordinate
86 | y : int
87 | Grid y-coordinate
88 | obj : WorldObj or None
89 | Object to place
90 | """
91 | # Update world object dictionary
92 | self.world_objects[x, y] = obj
93 |
94 | # Update grid state
95 | if isinstance(obj, WorldObj):
96 | self.state[x, y] = obj
97 | elif obj is None:
98 | self.state[x, y] = WorldObj.empty()
99 | else:
100 | raise TypeError(f"cannot set grid value to {type(obj)}")
101 |
102 | def get(self, x: int, y: int) -> WorldObj | None:
103 | """
104 | Get the world object at the given coordinates.
105 |
106 | Parameters
107 | ----------
108 | x : int
109 | Grid x-coordinate
110 | y : int
111 | Grid y-coordinate
112 | """
113 | # Create WorldObj instance if none exists
114 | if (x, y) not in self.world_objects:
115 | self.world_objects[x, y] = WorldObj.from_array(self.state[x, y])
116 |
117 | return self.world_objects[x, y]
118 |
119 | def update(self, x: int, y: int):
120 | """
121 | Update the grid state from the world object at the given coordinates.
122 |
123 | Parameters
124 | ----------
125 | x : int
126 | Grid x-coordinate
127 | y : int
128 | Grid y-coordinate
129 | """
130 | if (x, y) in self.world_objects:
131 | self.state[x, y] = self.world_objects[x, y]
132 |
133 | def horz_wall(
134 | self,
135 | x: int, y: int,
136 | length: int | None = None,
137 | obj_type: Callable[[], WorldObj] = Wall):
138 | """
139 | Create a horizontal wall.
140 |
141 | Parameters
142 | ----------
143 | x : int
144 | Leftmost x-coordinate of wall
145 | y : int
146 | Y-coordinate of wall
147 | length : int or None
148 | Length of wall. If None, wall extends to the right edge of the grid.
149 | obj_type : Callable() -> WorldObj
150 | Function that returns a WorldObj instance to use for the wall
151 | """
152 | length = self.width - x if length is None else length
153 | self.state[x:x+length, y] = obj_type()
154 |
155 | def vert_wall(
156 | self,
157 | x: int, y: int,
158 | length: int | None = None,
159 | obj_type: Callable[[], WorldObj] = Wall):
160 | """
161 | Create a vertical wall.
162 |
163 | Parameters
164 | ----------
165 | x : int
166 | X-coordinate of wall
167 | y : int
168 | Topmost y-coordinate of wall
169 | length : int or None
170 | Length of wall. If None, wall extends to the bottom edge of the grid.
171 | obj_type : Callable() -> WorldObj
172 | Function that returns a WorldObj instance to use for the wall
173 | """
174 | length = self.height - y if length is None else length
175 | self.state[x, y:y+length] = obj_type()
176 |
177 | def wall_rect(self, x: int, y: int, w: int, h: int):
178 | """
179 | Create a walled rectangle.
180 |
181 | Parameters
182 | ----------
183 | x : int
184 | X-coordinate of top-left corner
185 | y : int
186 | Y-coordinate of top-left corner
187 | w : int
188 | Width of rectangle
189 | h : int
190 | Height of rectangle
191 | """
192 | self.horz_wall(x, y, w)
193 | self.horz_wall(x, y + h - 1, w)
194 | self.vert_wall(x, y, h)
195 | self.vert_wall(x + w - 1, y, h)
196 |
197 | @classmethod
198 | def render_tile(
199 | cls,
200 | obj: WorldObj | None = None,
201 | agent: Agent | None = None,
202 | highlight: bool = False,
203 | tile_size: int = TILE_PIXELS,
204 | subdivs: int = 3) -> ndarray[np.uint8]:
205 | """
206 | Render a tile and cache the result.
207 |
208 | Parameters
209 | ----------
210 | obj : WorldObj or None
211 | Object to render
212 | agent : Agent or None
213 | Agent to render
214 | highlight : bool
215 | Whether to highlight the tile
216 | tile_size : int
217 | Tile size (in pixels)
218 | subdivs : int
219 | Downsampling factor for supersampling / anti-aliasing
220 | """
221 | # Hash map lookup key for the cache
222 | key: tuple[Any, ...] = (highlight, tile_size)
223 | if agent:
224 | key += (agent.state.color, agent.state.dir)
225 | else:
226 | key += (None, None)
227 | key = obj.encode() + key if obj else key
228 |
229 | if key in cls._tile_cache:
230 | return cls._tile_cache[key]
231 |
232 | img = np.zeros(
233 | shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8)
234 |
235 | # Draw the grid lines (top and left edges)
236 | fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
237 | fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
238 |
239 | # Draw the object
240 | if obj is not None:
241 | obj.render(img)
242 |
243 | # Draw the agent
244 | if agent is not None and not agent.state.terminated:
245 | agent.render(img)
246 |
247 | # Highlight the cell if needed
248 | if highlight:
249 | highlight_img(img)
250 |
251 | # Downsample the image to perform supersampling/anti-aliasing
252 | img = downsample(img, subdivs)
253 |
254 | # Cache the rendered tile
255 | cls._tile_cache[key] = img
256 |
257 | return img
258 |
259 | def render(
260 | self,
261 | tile_size: int,
262 | agents: Iterable[Agent] = (),
263 |         highlight_mask: ndarray[np.bool_] | None = None) -> ndarray[np.uint8]:
264 | """
265 | Render this grid at a given scale.
266 |
267 | Parameters
268 | ----------
269 | tile_size: int
270 | Tile size (in pixels)
271 | agents: Iterable[Agent]
272 | Agents to render
273 | highlight_mask: ndarray
274 | Boolean mask indicating which grid locations to highlight
275 | """
276 | if highlight_mask is None:
277 | highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
278 |
279 | # Get agent locations
280 | # For overlapping agents, non-terminated agents get priority
281 | location_to_agent = defaultdict(type(None))
282 | for agent in sorted(agents, key=lambda a: not a.terminated):
283 | location_to_agent[tuple(agent.pos)] = agent
284 |
285 | # Initialize pixel array
286 | width_px = self.width * tile_size
287 | height_px = self.height * tile_size
288 | img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
289 |
290 | # Render the grid
291 | for j in range(0, self.height):
292 | for i in range(0, self.width):
293 | assert highlight_mask is not None
294 | cell = self.get(i, j)
295 | tile_img = Grid.render_tile(
296 | cell,
297 | agent=location_to_agent[i, j],
298 | highlight=highlight_mask[i, j],
299 | tile_size=tile_size,
300 | )
301 |
302 | ymin = j * tile_size
303 | ymax = (j + 1) * tile_size
304 | xmin = i * tile_size
305 | xmax = (i + 1) * tile_size
306 | img[ymin:ymax, xmin:xmax, :] = tile_img
307 |
308 | return img
309 |
310 |     def encode(self, vis_mask: ndarray[np.bool_] | None = None) -> ndarray[np.int_]:
311 | """
312 | Produce a compact numpy encoding of the grid.
313 |
314 | Parameters
315 | ----------
316 | vis_mask : ndarray[bool] of shape (width, height)
317 | Visibility mask
318 | """
319 | if vis_mask is None:
320 | vis_mask = np.ones((self.width, self.height), dtype=bool)
321 |
322 | encoding = self.state.copy()
323 |         encoding[~vis_mask, WorldObj.TYPE] = Type.unseen.to_index()  # chained indexing would assign to a copy
324 | return encoding
325 |
326 | @staticmethod
327 |     def decode(array: ndarray[np.int_]) -> tuple['Grid', ndarray[np.bool_]]:
328 | """
329 | Decode an array grid encoding back into a `Grid` instance.
330 |
331 | Parameters
332 | ----------
333 | array : ndarray[int] of shape (width, height, dim)
334 | Grid encoding
335 |
336 | Returns
337 | -------
338 | grid : Grid
339 | Decoded `Grid` instance
340 | vis_mask : ndarray[bool] of shape (width, height)
341 | Visibility mask
342 | """
343 | width, height, dim = array.shape
344 | assert dim == WorldObj.dim
345 |
346 | vis_mask = (array[..., WorldObj.TYPE] != Type.unseen.to_index())
347 | grid = Grid(width, height)
348 | grid.state[vis_mask] = array[vis_mask]
349 | return grid, vis_mask
350 |
--------------------------------------------------------------------------------
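A round-trip sketch for the encode/decode pair above, exercising only the Grid API defined in this file:

    import numpy as np
    from multigrid.core import Grid

    grid = Grid(5, 5)
    grid.wall_rect(0, 0, 5, 5)              # walls around the border
    encoding = grid.encode()                # shape (5, 5, WorldObj.dim)
    grid2, vis_mask = Grid.decode(encoding)
    assert vis_mask.all()                   # no cells were marked unseen
    assert np.array_equal(grid2.state, grid.state)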
/multigrid/core/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 |
5 | from gymnasium import spaces
6 | from numpy.typing import ArrayLike, NDArray as ndarray
7 |
8 | from .actions import Action
9 | from .constants import Color, Direction, Type
10 | from .mission import Mission, MissionSpace
11 | from .world_object import WorldObj
12 |
13 | from ..utils.misc import front_pos, PropertyAlias
14 | from ..utils.rendering import (
15 | fill_coords,
16 | point_in_triangle,
17 | rotate_fn,
18 | )
19 |
20 |
21 |
22 | class Agent:
23 | """
24 | Class representing an agent in the environment.
25 |
26 | :Observation Space:
27 |
28 | Observations are dictionaries with the following entries:
29 |
30 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
31 | Encoding of the agent's view of the environment
32 | * direction : int
33 | Agent's direction (0: right, 1: down, 2: left, 3: up)
34 | * mission : Mission
35 | Task string corresponding to the current environment configuration
36 |
37 | :Action Space:
38 |
39 | Actions are discrete integers, as enumerated in :class:`.Action`.
40 |
41 | Attributes
42 | ----------
43 | index : int
44 | Index of the agent in the environment
45 | state : AgentState
46 | State of the agent
47 | mission : Mission
48 | Current mission string for the agent
49 | action_space : gym.spaces.Discrete
50 | Action space for the agent
51 | observation_space : gym.spaces.Dict
52 | Observation space for the agent
53 | """
54 |
55 | def __init__(
56 | self,
57 | index: int,
58 | mission_space: MissionSpace = MissionSpace.from_string('maximize reward'),
59 | view_size: int = 7,
60 | see_through_walls: bool = False):
61 | """
62 | Parameters
63 | ----------
64 | index : int
65 | Index of the agent in the environment
66 | mission_space : MissionSpace
67 | The mission space for the agent
68 | view_size : int
69 | The size of the agent's view (must be odd)
70 | see_through_walls : bool
71 | Whether the agent can see through walls
72 | """
73 | self.index: int = index
74 | self.state: AgentState = AgentState()
75 |         self.mission: Mission | None = None
76 |
77 | # Number of cells (width and height) in the agent view
78 | assert view_size % 2 == 1
79 | assert view_size >= 3
80 | self.view_size = view_size
81 | self.see_through_walls = see_through_walls
82 |
83 | # Observations are dictionaries containing an
84 | # encoding of the grid and a textual 'mission' string
85 | self.observation_space = spaces.Dict({
86 | 'image': spaces.Box(
87 | low=0,
88 | high=255,
89 | shape=(view_size, view_size, WorldObj.dim),
90 | dtype=int,
91 | ),
92 | 'direction': spaces.Discrete(len(Direction)),
93 | 'mission': mission_space,
94 | })
95 |
96 | # Actions are discrete integer values
97 | self.action_space = spaces.Discrete(len(Action))
98 |
99 | # AgentState Properties
100 | color = PropertyAlias(
101 | 'state', 'color', doc='Alias for :attr:`AgentState.color`.')
102 | dir = PropertyAlias(
103 | 'state', 'dir', doc='Alias for :attr:`AgentState.dir`.')
104 | pos = PropertyAlias(
105 | 'state', 'pos', doc='Alias for :attr:`AgentState.pos`.')
106 | terminated = PropertyAlias(
107 | 'state', 'terminated', doc='Alias for :attr:`AgentState.terminated`.')
108 | carrying = PropertyAlias(
109 | 'state', 'carrying', doc='Alias for :attr:`AgentState.carrying`.')
110 |
111 | @property
112 | def front_pos(self) -> tuple[int, int]:
113 | """
114 | Get the position of the cell that is directly in front of the agent.
115 | """
116 | agent_dir = self.state._view[AgentState.DIR]
117 | agent_pos = self.state._view[AgentState.POS]
118 | return front_pos(*agent_pos, agent_dir)
119 |
120 | def reset(self, mission: Mission = Mission('maximize reward')):
121 | """
122 | Reset the agent to an initial state.
123 |
124 | Parameters
125 | ----------
126 | mission : Mission
127 | Mission string to use for the new episode
128 | """
129 | self.mission = mission
130 | self.state.pos = (-1, -1)
131 | self.state.dir = -1
132 | self.state.terminated = False
133 | self.state.carrying = None
134 |
135 | def encode(self) -> tuple[int, int, int]:
136 | """
137 | Encode a description of this agent as a 3-tuple of integers.
138 |
139 | Returns
140 | -------
141 | type_idx : int
142 | The index of the agent type
143 | color_idx : int
144 | The index of the agent color
145 | agent_dir : int
146 | The direction of the agent (0: right, 1: down, 2: left, 3: up)
147 | """
148 | return (Type.agent.to_index(), self.state.color.to_index(), self.state.dir)
149 |
150 | def render(self, img: ndarray[np.uint8]):
151 | """
152 | Draw the agent.
153 |
154 | Parameters
155 | ----------
156 | img : ndarray[int] of shape (width, height, 3)
157 | RGB image array to render agent on
158 | """
159 | tri_fn = point_in_triangle(
160 | (0.12, 0.19),
161 | (0.87, 0.50),
162 | (0.12, 0.81),
163 | )
164 |
165 | # Rotate the agent based on its direction
166 | tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * np.pi * self.state.dir)
167 | fill_coords(img, tri_fn, self.state.color.rgb())
168 |
169 |
170 | class AgentState(np.ndarray):
171 | """
172 | State for an :class:`.Agent` object.
173 |
174 | ``AgentState`` objects also support vectorized operations,
175 | in which case the ``AgentState`` object represents the states of multiple agents.
176 |
177 | Attributes
178 | ----------
179 | color : Color or ndarray[str]
180 | Agent color
181 | dir : Direction or ndarray[int]
182 | Agent direction (0: right, 1: down, 2: left, 3: up)
183 | pos : tuple[int, int] or ndarray[int]
184 | Agent (x, y) position
185 | terminated : bool or ndarray[bool]
186 | Whether the agent has terminated
187 | carrying : WorldObj or None or ndarray[object]
188 | Object the agent is carrying
189 |
190 | Examples
191 | --------
192 | Create a vectorized agent state for 3 agents:
193 |
194 | >>> agent_state = AgentState(3)
195 | >>> agent_state
196 | AgentState(3)
197 |
198 | Access and set state attributes for one agent at a time:
199 |
200 | >>> a = agent_state[0]
201 | >>> a
202 | AgentState()
203 | >>> a.color
204 | 'red'
205 | >>> a.color = 'yellow'
206 |
207 | The underlying vectorized state is automatically updated as well:
208 |
209 | >>> agent_state.color
210 | array(['yellow', 'green', 'blue'])
211 |
212 | Access and set state attributes all at once:
213 |
214 | >>> agent_state.dir
215 | array([-1, -1, -1])
216 | >>> agent_state.dir = np.random.randint(4, size=(len(agent_state)))
217 | >>> agent_state.dir
218 | array([2, 3, 0])
219 | >>> a.dir
220 | 2
221 | """
222 | # State vector indices
223 | TYPE = 0
224 | COLOR = 1
225 | DIR = 2
226 | ENCODING = slice(0, 3)
227 | POS = slice(3, 5)
228 | TERMINATED = 5
229 | CARRYING = slice(6, 6 + WorldObj.dim)
230 |
231 | # State vector dimension
232 | dim = 6 + WorldObj.dim
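 |     # Layout: [type, color, dir, x, y, terminated, *carried-object encoding]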
233 |
234 | def __new__(cls, *dims: int):
235 | """
236 | Parameters
237 | ----------
238 | dims : int, optional
239 | Shape of vectorized agent state
240 | """
241 | obj = np.zeros(dims + (cls.dim,), dtype=int).view(cls)
242 |
243 | # Set default values
244 | obj[..., AgentState.TYPE] = Type.agent
245 | obj[..., AgentState.COLOR].flat = Color.cycle(np.prod(dims))
246 | obj[..., AgentState.DIR] = -1
247 | obj[..., AgentState.POS] = (-1, -1)
248 |
249 | # Other attributes
250 | obj._carried_obj = np.empty(dims, dtype=object) # object references
251 | obj._terminated = np.zeros(dims, dtype=bool) # cache for faster access
252 | obj._view = obj.view(np.ndarray) # view of the underlying array (faster indexing)
253 |
254 | return obj
255 |
256 | def __repr__(self):
257 | shape = str(self.shape[:-1]).replace(",)", ")")
258 | return f'{self.__class__.__name__}{shape}'
259 |
260 | def __getitem__(self, idx):
261 | out = super().__getitem__(idx)
262 | if out.shape and out.shape[-1] == self.dim:
263 | out._view = self._view[idx, ...]
264 | out._carried_obj = self._carried_obj[idx, ...] # set carried object reference
265 | out._terminated = self._terminated[idx, ...] # set terminated cache
266 |
267 | return out
268 |
269 | @property
270 |     def color(self) -> Color | ndarray[np.str_]:
271 | """
272 | Return the agent color.
273 | """
274 | return Color.from_index(self._view[..., AgentState.COLOR])
275 |
276 | @color.setter
277 | def color(self, value: str | ArrayLike[str]):
278 | """
279 | Set the agent color.
280 | """
281 | self[..., AgentState.COLOR] = np.vectorize(lambda c: Color(c).to_index())(value)
282 |
283 | @property
284 |     def dir(self) -> Direction | ndarray[np.int_]:
285 | """
286 | Return the agent direction.
287 | """
288 | out = self._view[..., AgentState.DIR]
289 | return Direction(out.item()) if out.ndim == 0 else out
290 |
291 | @dir.setter
292 | def dir(self, value: int | ArrayLike[int]):
293 | """
294 | Set the agent direction.
295 | """
296 | self[..., AgentState.DIR] = value
297 |
298 | @property
299 |     def pos(self) -> tuple[int, int] | ndarray[np.int_]:
300 | """
301 | Return the agent's (x, y) position.
302 | """
303 | out = self._view[..., AgentState.POS]
304 | return tuple(out) if out.ndim == 1 else out
305 |
306 | @pos.setter
307 | def pos(self, value: ArrayLike[int] | ArrayLike[ArrayLike[int]]):
308 | """
309 | Set the agent's (x, y) position.
310 | """
311 | self[..., AgentState.POS] = value
312 |
313 | @property
314 |     def terminated(self) -> bool | ndarray[np.bool_]:
315 | """
316 | Return whether the agent has terminated.
317 | """
318 | out = self._terminated
319 | return out.item() if out.ndim == 0 else out
320 |
321 | @terminated.setter
322 | def terminated(self, value: bool | ArrayLike[bool]):
323 | """
324 | Set whether the agent has terminated.
325 | """
326 | self[..., AgentState.TERMINATED] = value
327 | self._terminated[...] = value
328 |
329 | @property
330 |     def carrying(self) -> WorldObj | None | ndarray[np.object_]:
331 | """
332 | Return the object the agent is carrying.
333 | """
334 | out = self._carried_obj
335 | return out.item() if out.ndim == 0 else out
336 |
337 | @carrying.setter
338 | def carrying(self, obj: WorldObj | None | ArrayLike[object]):
339 | """
340 | Set the object the agent is carrying.
341 | """
342 | self[..., AgentState.CARRYING] = WorldObj.empty() if obj is None else obj
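 |         # Keep the cached object references in sync with the integer encoding:
 |         # a single WorldObj (or None) is broadcast to all agents via fill(),
 |         # while an array-like of objects is assigned element-wise.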
343 | if isinstance(obj, (WorldObj, type(None))):
344 | self._carried_obj[...].fill(obj)
345 | else:
346 | self._carried_obj[...] = obj
347 |
--------------------------------------------------------------------------------
/multigrid/utils/obs.py:
--------------------------------------------------------------------------------
1 | import numba as nb
2 | import numpy as np
3 |
4 | from ..core.agent import AgentState
5 | from ..core.constants import Color, Direction, State, Type
6 | from ..core.world_object import Wall, WorldObj
7 |
8 | from numpy.typing import NDArray as ndarray
9 |
10 |
11 |
12 | ### Constants
13 |
14 | WALL_ENCODING = Wall().encode()
15 | UNSEEN_ENCODING = WorldObj(Type.unseen, Color.from_index(0)).encode()
16 | ENCODE_DIM = WorldObj.dim
17 |
18 | GRID_ENCODING_IDX = slice(None)
19 |
20 | AGENT_DIR_IDX = AgentState.DIR
21 | AGENT_POS_IDX = AgentState.POS
22 | AGENT_TERMINATED_IDX = AgentState.TERMINATED
23 | AGENT_CARRYING_IDX = AgentState.CARRYING
24 | AGENT_ENCODING_IDX = AgentState.ENCODING
25 |
26 | TYPE = WorldObj.TYPE
27 | STATE = WorldObj.STATE
28 |
29 | WALL = int(Type.wall)
30 | DOOR = int(Type.door)
31 |
32 | OPEN = int(State.open)
33 | CLOSED = int(State.closed)
34 | LOCKED = int(State.locked)
35 |
36 | RIGHT = int(Direction.right)
37 | LEFT = int(Direction.left)
38 | UP = int(Direction.up)
39 | DOWN = int(Direction.down)
40 |
41 |
42 |
43 |
44 | ### Observation Functions
45 |
46 | @nb.njit(cache=True)
47 | def see_behind(world_obj: ndarray[np.int_]) -> bool:
48 | """
49 | Can an agent see behind this object?
50 |
51 | Parameters
52 | ----------
53 | world_obj : ndarray[int] of shape (encode_dim,)
54 | World object encoding
55 | """
56 | if world_obj is None:
57 | return True
58 | if world_obj[TYPE] == WALL:
59 | return False
60 | elif world_obj[TYPE] == DOOR and world_obj[STATE] != OPEN:
61 | return False
62 |
63 | return True
64 |
65 | @nb.njit(cache=True)
66 | def gen_obs_grid_encoding(
67 | grid_state: ndarray[np.int_],
68 | agent_state: ndarray[np.int_],
69 | agent_view_size: int,
70 | see_through_walls: bool) -> ndarray[np.int_]:
71 | """
72 | Generate encoding for the sub-grid observed by an agent (including visibility mask).
73 |
74 | Parameters
75 | ----------
76 | grid_state : ndarray[int] of shape (width, height, grid_state_dim)
77 | Array representation for each grid object
78 | agent_state : ndarray[int] of shape (num_agents, agent_state_dim)
79 | Array representation for each agent
80 | agent_view_size : int
81 | Width and height of observation sub-grids
82 | see_through_walls : bool
83 | Whether the agent can see through walls
84 |
85 | Returns
86 | -------
87 | img : ndarray[int] of shape (num_agents, view_size, view_size, encode_dim)
88 | Encoding of observed sub-grid for each agent
89 | """
90 | obs_grid = gen_obs_grid(grid_state, agent_state, agent_view_size)
91 |
92 |     # Generate and apply visibility masks
93 |     if not see_through_walls:
94 |         vis_mask = get_vis_mask(obs_grid)
95 |         num_agents = len(agent_state)
96 |         for agent in range(num_agents):
97 |             for i in range(agent_view_size):
98 |                 for j in range(agent_view_size):
99 |                     if not vis_mask[agent, i, j]:
100 |                         obs_grid[agent, i, j] = UNSEEN_ENCODING
101 |
102 | return obs_grid
103 |
104 | @nb.njit(cache=True)
105 | def gen_obs_grid_vis_mask(
106 | grid_state: ndarray[np.int_],
107 | agent_state: ndarray[np.int_],
108 |     agent_view_size: int) -> ndarray[np.bool_]:
109 | """
110 | Generate visibility mask for the sub-grid observed by an agent.
111 |
112 | Parameters
113 | ----------
114 | grid_state : ndarray[int] of shape (width, height, grid_state_dim)
115 | Array representation for each grid object
116 | agent_state : ndarray[int] of shape (num_agents, agent_state_dim)
117 | Array representation for each agent
118 | agent_view_size : int
119 | Width and height of observation sub-grids
120 |
121 | Returns
122 | -------
123 |     mask : ndarray[bool] of shape (num_agents, view_size, view_size)
124 |         Boolean visibility mask for each agent
125 | """
126 | obs_grid = gen_obs_grid(grid_state, agent_state, agent_view_size)
127 | return get_vis_mask(obs_grid)
128 |
129 |
130 | @nb.njit(cache=True)
131 | def gen_obs_grid(
132 | grid_state: ndarray[np.int_],
133 | agent_state: ndarray[np.int_],
134 | agent_view_size: int) -> ndarray[np.int_]:
135 | """
136 | Generate the sub-grid observed by each agent (WITHOUT visibility mask).
137 |
138 | Parameters
139 | ----------
140 | grid_state : ndarray[int] of shape (width, height, grid_state_dim)
141 | Array representation for each grid object
142 | agent_state : ndarray[int] of shape (num_agents, agent_state_dim)
143 | Array representation for each agent
144 | agent_view_size : int
145 | Width and height of observation sub-grids
146 |
147 | Returns
148 | -------
149 |     obs_grid : ndarray[int] of shape (num_agents, view_size, view_size, encode_dim)
150 | Observed sub-grid for each agent
151 | """
152 | num_agents = len(agent_state)
153 | obs_width, obs_height = agent_view_size, agent_view_size
154 |
155 | # Process agent states
156 | agent_grid_encoding = agent_state[..., AGENT_ENCODING_IDX]
157 | agent_dir = agent_state[..., AGENT_DIR_IDX]
158 | agent_pos = agent_state[..., AGENT_POS_IDX]
159 | agent_terminated = agent_state[..., AGENT_TERMINATED_IDX]
160 | agent_carrying_encoding = agent_state[..., AGENT_CARRYING_IDX]
161 |
162 | # Get grid encoding
163 | if num_agents > 1:
164 | grid_encoding = np.empty((*grid_state.shape[:-1], ENCODE_DIM), dtype=np.int_)
165 | grid_encoding[...] = grid_state[..., GRID_ENCODING_IDX]
166 |
167 | # Insert agent grid encodings
168 | for agent in range(num_agents):
169 | if not agent_terminated[agent]:
170 | i, j = agent_pos[agent]
171 | grid_encoding[i, j, GRID_ENCODING_IDX] = agent_grid_encoding[agent]
172 | else:
173 | grid_encoding = grid_state[..., GRID_ENCODING_IDX]
174 |
175 | # Get top left corner of observation grids
176 | top_left = get_view_exts(agent_dir, agent_pos, agent_view_size)
177 | topX, topY = top_left[:, 0], top_left[:, 1]
178 |
179 | # Populate observation grids
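 |     # Each view is rotated so the agent sits at the bottom-center facing up;
 |     # an agent facing right (dir = 0) needs one left rotation, hence (dir + 1) % 4.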
180 | num_left_rotations = (agent_dir + 1) % 4
181 | obs_grid = np.empty((num_agents, obs_width, obs_height, ENCODE_DIM), dtype=np.int_)
182 | for agent in range(num_agents):
183 | for i in range(0, obs_width):
184 | for j in range(0, obs_height):
185 | # Absolute coordinates in world grid
186 | x, y = topX[agent] + i, topY[agent] + j
187 |
188 | # Rotated relative coordinates for observation grid
189 | if num_left_rotations[agent] == 0:
190 | i_rot, j_rot = i, j
191 | elif num_left_rotations[agent] == 1:
192 | i_rot, j_rot = j, obs_width - i - 1
193 | elif num_left_rotations[agent] == 2:
194 | i_rot, j_rot = obs_width - i - 1, obs_height - j - 1
195 | elif num_left_rotations[agent] == 3:
196 | i_rot, j_rot = obs_height - j - 1, i
197 |
198 | # Set observation grid
199 | if 0 <= x < grid_encoding.shape[0] and 0 <= y < grid_encoding.shape[1]:
200 | obs_grid[agent, i_rot, j_rot] = grid_encoding[x, y]
201 | else:
202 | obs_grid[agent, i_rot, j_rot] = WALL_ENCODING
203 |
204 | # Make it so each agent sees what it's carrying
205 | # We do this by placing the carried object at the agent position
206 | # in each agent's partially observable view
207 | obs_grid[:, obs_width // 2, obs_height - 1] = agent_carrying_encoding
208 |
209 | return obs_grid
210 |
211 | @nb.njit(cache=True)
212 | def get_see_behind_mask(grid_array: ndarray[np.int_]) -> ndarray[np.bool_]:
213 | """
214 | Return boolean mask indicating which grid locations can be seen through.
215 |
216 | Parameters
217 | ----------
218 | grid_array : ndarray[int] of shape (num_agents, width, height, dim)
219 | Grid object array for each agent
220 |
221 | Returns
222 | -------
223 |     see_behind_mask : ndarray[bool] of shape (num_agents, width, height)
224 | Boolean visibility mask
225 | """
226 | num_agents, width, height = grid_array.shape[:3]
227 | see_behind_mask = np.zeros((num_agents, width, height), dtype=np.bool_)
228 | for agent in range(num_agents):
229 | for i in range(width):
230 | for j in range(height):
231 | see_behind_mask[agent, i, j] = see_behind(grid_array[agent, i, j])
232 |
233 | return see_behind_mask
234 |
235 | @nb.njit(cache=True)
236 | def get_vis_mask(obs_grid: ndarray[np.int_]) -> ndarray[np.bool_]:
237 | """
238 | Generate a boolean mask indicating which grid locations are visible to each agent.
239 |
240 | Parameters
241 | ----------
242 | obs_grid : ndarray[int] of shape (num_agents, width, height, dim)
243 | Grid object array for each agent observation
244 |
245 | Returns
246 | -------
247 | vis_mask : ndarray[bool] of shape (num_agents, width, height)
248 | Boolean visibility mask for each agent
249 | """
250 | num_agents, width, height = obs_grid.shape[:3]
251 | see_behind_mask = get_see_behind_mask(obs_grid)
252 | vis_mask = np.zeros((num_agents, width, height), dtype=np.bool_)
253 | vis_mask[:, width // 2, height - 1] = True # agent relative position
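 |     # Flood visibility outward from the agent, row by row: a cell becomes
 |     # visible when it neighbors an already-visible cell that can be seen
 |     # through (akin to minigrid's ``Grid.process_vis`` shadow-casting).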
254 |
255 | for agent in range(num_agents):
256 | for j in range(height - 1, -1, -1):
257 | # Forward pass
258 | for i in range(0, width - 1):
259 | if vis_mask[agent, i, j] and see_behind_mask[agent, i, j]:
260 | vis_mask[agent, i + 1, j] = True
261 | if j > 0:
262 | vis_mask[agent, i + 1, j - 1] = True
263 | vis_mask[agent, i, j - 1] = True
264 |
265 | # Backward pass
266 | for i in range(width - 1, 0, -1):
267 | if vis_mask[agent, i, j] and see_behind_mask[agent, i, j]:
268 | vis_mask[agent, i - 1, j] = True
269 | if j > 0:
270 | vis_mask[agent, i - 1, j - 1] = True
271 | vis_mask[agent, i, j - 1] = True
272 |
273 | return vis_mask
274 |
275 | @nb.njit(cache=True)
276 | def get_view_exts(
277 | agent_dir: ndarray[np.int_],
278 | agent_pos: ndarray[np.int_],
279 | agent_view_size: int) -> ndarray[np.int_]:
280 | """
281 | Get the extents of the square set of grid cells visible to each agent.
282 |
283 | Parameters
284 | ----------
285 | agent_dir : ndarray[int] of shape (num_agents,)
286 | Direction of each agent
287 | agent_pos : ndarray[int] of shape (num_agents, 2)
288 | The (x, y) position of each agent
289 | agent_view_size : int
290 | Width and height of agent view
291 |
292 | Returns
293 | -------
294 | top_left : ndarray[int] of shape (num_agents, 2)
295 | The (x, y) coordinates of the top-left corner of each agent's observable view
296 | """
297 | agent_x, agent_y = agent_pos[:, 0], agent_pos[:, 1]
298 | top_left = np.zeros((agent_dir.shape[0], 2), dtype=np.int_)
299 |
300 | # Facing right
301 | top_left[agent_dir == RIGHT, 0] = agent_x[agent_dir == RIGHT]
302 | top_left[agent_dir == RIGHT, 1] = agent_y[agent_dir == RIGHT] - agent_view_size // 2
303 |
304 | # Facing down
305 | top_left[agent_dir == DOWN, 0] = agent_x[agent_dir == DOWN] - agent_view_size // 2
306 | top_left[agent_dir == DOWN, 1] = agent_y[agent_dir == DOWN]
307 |
308 | # Facing left
309 | top_left[agent_dir == LEFT, 0] = agent_x[agent_dir == LEFT] - agent_view_size + 1
310 | top_left[agent_dir == LEFT, 1] = agent_y[agent_dir == LEFT] - agent_view_size // 2
311 |
312 | # Facing up
313 | top_left[agent_dir == UP, 0] = agent_x[agent_dir == UP] - agent_view_size // 2
314 | top_left[agent_dir == UP, 1] = agent_y[agent_dir == UP] - agent_view_size + 1
315 |
316 | return top_left
317 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 Maxime Chevalier-Boisvert
190 | Copyright 2022 Farama Foundation
191 | Copyright 2023 Ini Oguntola
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
--------------------------------------------------------------------------------
/multigrid/core/roomgrid.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 |
5 | from collections import deque
6 | from typing import Callable, Iterable, TypeVar
7 |
8 | from .agent import Agent
9 | from .constants import Color, Direction, Type
10 | from .grid import Grid
11 | from .world_object import Door, WorldObj
12 | from ..base import MultiGridEnv
13 |
14 |
15 |
16 | T = TypeVar('T')
17 |
18 |
19 |
20 | def bfs(start_node: T, neighbor_fn: Callable[[T], Iterable[T]]) -> set[T]:
21 | """
22 | Run a breadth-first search from a starting node.
23 |
24 | Parameters
25 | ----------
26 | start_node : T
27 | Start node
28 | neighbor_fn : Callable(T) -> Iterable[T]
29 | Function that returns the neighbors of a node
30 |
31 | Returns
32 | -------
33 | visited : set[T]
34 | Set of nodes reachable from the start node
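 |
 |     Examples
 |     --------
 |     A toy chain graph (illustrative):
 |
 |     >>> bfs(0, lambda n: [n + 1] if n < 3 else [])
 |     {0, 1, 2, 3}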
35 | """
36 | visited, queue = set(), deque([start_node])
37 | while queue:
38 | node = queue.popleft()
39 | if node not in visited:
40 | visited.add(node)
41 | queue.extend(neighbor_fn(node))
42 |
43 | return visited
44 |
45 | def reject_next_to(env: MultiGridEnv, pos: tuple[int, int]) -> bool:
46 |     """
47 |     Reject positions that are directly adjacent to any agent's starting point
48 |     (used as a ``reject_fn`` for :meth:`MultiGridEnv.place_obj`).
49 |     """
50 | return any(np.linalg.norm(pos - env.agent_states.pos, axis=-1) <= 1)
51 |
52 |
53 | class Room:
54 | """
55 | Room as an area inside a grid.
56 | """
57 |
58 | def __init__(self, top: tuple[int, int], size: tuple[int, int]):
59 | """
60 | Parameters
61 | ----------
62 | top : tuple[int, int]
63 | Top-left position of the room
64 | size : tuple[int, int]
65 | Room size as (width, height)
66 | """
67 | self.top, self.size = top, size
68 | Point = tuple[int, int] # typing alias
69 |
70 | # Mapping of door objects and door positions
71 | self.doors: dict[Direction, Door | None] = {d: None for d in Direction}
72 | self.door_pos: dict[Direction, Point | None] = {d: None for d in Direction}
73 |
74 | # Mapping of rooms adjacent to this one
75 | self.neighbors: dict[Direction, Room | None] = {d: None for d in Direction}
76 |
77 | # List of objects contained in this room
78 | self.objs = []
79 |
80 | @property
81 | def locked(self) -> bool:
82 | """
83 | Return whether this room is behind a locked door.
84 | """
85 |         return any(isinstance(door, Door) and door.is_locked for door in self.doors.values())
86 |
87 | def set_door_pos(
88 | self,
89 | dir: Direction,
90 | random: np.random.Generator | None = None) -> tuple[int, int]:
91 | """
92 | Set door position in the given direction.
93 |
94 | Parameters
95 | ----------
96 | dir : Direction
97 | Direction of wall to place door
98 | random : np.random.Generator, optional
99 | Random number generator (if provided, door position will be random)
100 | """
101 | left, top = self.top
102 |         right, bottom = self.top[0] + self.size[0] - 1, self.top[1] + self.size[1] - 1
103 |
104 | if dir == Direction.right:
105 | if random:
106 | self.door_pos[dir] = (right, random.integers(top + 1, bottom))
107 | else:
108 | self.door_pos[dir] = (right, (top + bottom) // 2)
109 |
110 | elif dir == Direction.down:
111 | if random:
112 | self.door_pos[dir] = (random.integers(left + 1, right), bottom)
113 | else:
114 | self.door_pos[dir] = ((left + right) // 2, bottom)
115 |
116 | elif dir == Direction.left:
117 | if random:
118 | self.door_pos[dir] = (left, random.integers(top + 1, bottom))
119 | else:
120 | self.door_pos[dir] = (left, (top + bottom) // 2)
121 |
122 | elif dir == Direction.up:
123 | if random:
124 | self.door_pos[dir] = (random.integers(left + 1, right), top)
125 | else:
126 | self.door_pos[dir] = ((left + right) // 2, top)
127 |
128 | return self.door_pos[dir]
129 |
130 | def pos_inside(self, x: int, y: int) -> bool:
131 | """
132 | Check if a position is within the bounds of this room.
133 | """
134 | left_x, top_y = self.top
135 | width, height = self.size
136 | return left_x <= x < left_x + width and top_y <= y < top_y + height
137 |
138 |
139 | class RoomGrid(MultiGridEnv):
140 | """
141 | Environment with multiple rooms and random objects.
142 | This is meant to serve as a base class for other environments.
143 | """
144 |
145 | def __init__(
146 | self,
147 | room_size: int = 7,
148 | num_rows: int = 3,
149 | num_cols: int = 3,
150 | **kwargs):
151 | """
152 | Parameters
153 | ----------
154 | room_size : int, default=7
155 | Width and height for each of the rooms
156 | num_rows : int, default=3
157 | Number of rows of rooms
158 | num_cols : int, default=3
159 | Number of columns of rooms
160 | **kwargs
161 | See :attr:`multigrid.base.MultiGridEnv.__init__`
162 | """
163 | assert room_size >= 3
164 | assert num_rows > 0
165 | assert num_cols > 0
166 | self.room_size = room_size
167 | self.num_rows = num_rows
168 | self.num_cols = num_cols
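 |         # Adjacent rooms share a wall, so each room contributes (room_size - 1)
 |         # cells, plus one extra row/column for the final outer wall.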
169 | height = (room_size - 1) * num_rows + 1
170 | width = (room_size - 1) * num_cols + 1
171 | super().__init__(width=width, height=height, **kwargs)
172 |
173 | def get_room(self, col: int, row: int) -> Room:
174 | """
175 | Get the room at the given column and row.
176 |
177 | Parameters
178 | ----------
179 | col : int
180 | Column of the room
181 | row : int
182 | Row of the room
183 | """
184 | assert 0 <= col < self.num_cols
185 | assert 0 <= row < self.num_rows
186 | return self.room_grid[row][col]
187 |
188 | def room_from_pos(self, x: int, y: int) -> Room:
189 | """
190 | Get the room a given position maps to.
191 |
192 | Parameters
193 | ----------
194 | x : int
195 | Grid x-coordinate
196 | y : int
197 | Grid y-coordinate
198 | """
199 | col = x // (self.room_size - 1)
200 | row = y // (self.room_size - 1)
201 | return self.get_room(col, row)
202 |
203 | def _gen_grid(self, width, height):
204 | # Create the grid
205 | self.grid = Grid(width, height)
206 | self.room_grid = [[None] * self.num_cols for _ in range(self.num_rows)]
207 |
208 | # Create rooms
209 | for row in range(self.num_rows):
210 | for col in range(self.num_cols):
211 | room = Room(
212 | (col * (self.room_size - 1), row * (self.room_size - 1)),
213 | (self.room_size, self.room_size),
214 | )
215 | self.room_grid[row][col] = room
216 | self.grid.wall_rect(*room.top, *room.size) # generate walls
217 |
218 | # Create connections between rooms
219 | for row in range(self.num_rows):
220 | for col in range(self.num_cols):
221 | room = self.room_grid[row][col]
222 | if col < self.num_cols - 1:
223 | room.neighbors[Direction.right] = self.room_grid[row][col + 1]
224 | if row < self.num_rows - 1:
225 | room.neighbors[Direction.down] = self.room_grid[row + 1][col]
226 | if col > 0:
227 | room.neighbors[Direction.left] = self.room_grid[row][col - 1]
228 | if row > 0:
229 | room.neighbors[Direction.up] = self.room_grid[row - 1][col]
230 |
231 | # Agents start in the middle, facing right
232 | self.agent_states.dir = Direction.right
233 | self.agent_states.pos = (
234 | (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
235 | (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
236 | )
237 |
238 | def place_in_room(
239 | self, col: int, row: int, obj: WorldObj) -> tuple[WorldObj, tuple[int, int]]:
240 | """
241 | Add an existing object to the given room.
242 |
243 | Parameters
244 | ----------
245 | col : int
246 | Room column
247 | row : int
248 | Room row
249 | obj : WorldObj
250 | Object to add
251 | """
252 | room = self.get_room(col, row)
253 | pos = self.place_obj(
254 | obj, room.top, room.size, reject_fn=reject_next_to, max_tries=1000)
255 | room.objs.append(obj)
256 | return obj, pos
257 |
258 | def add_object(
259 | self,
260 | col: int,
261 | row: int,
262 | kind: Type | None = None,
263 | color: Color | None = None) -> tuple[WorldObj, tuple[int, int]]:
264 | """
265 | Create a new object in the given room.
266 |
267 | Parameters
268 | ----------
269 | col : int
270 | Room column
271 | row : int
272 | Room row
273 | kind : str, optional
274 | Type of object to add (random if not specified)
275 | color : str, optional
276 | Color of the object to add (random if not specified)
277 | """
278 | kind = kind or self._rand_elem([Type.key, Type.ball, Type.box])
279 | color = color or self._rand_color()
280 | obj = WorldObj(type=kind, color=color)
281 | return self.place_in_room(col, row, obj)
282 |
283 | def add_door(
284 | self,
285 | col: int,
286 | row: int,
287 | dir: Direction | None = None,
288 | color: Color | None = None,
289 | locked: bool | None = None,
290 | rand_pos: bool = True) -> tuple[Door, tuple[int, int]]:
291 | """
292 | Add a door to a room, connecting it to a neighbor.
293 |
294 | Parameters
295 | ----------
296 | col : int
297 | Room column
298 | row : int
299 | Room row
300 | dir : Direction, optional
301 | Which wall to put the door on (random if not specified)
302 | color : Color, optional
303 | Color of the door (random if not specified)
304 | locked : bool, optional
305 | Whether the door is locked (random if not specified)
306 | rand_pos : bool, default=True
307 | Whether to place the door at a random position on the room wall
308 | """
309 | room = self.get_room(col, row)
310 |
311 | # Need to make sure that there is a neighbor along this wall
312 | # and that there is not already a door
313 |         if dir is None:
314 |             while dir is None or room.neighbors[dir] is None or room.doors[dir] is not None:
315 |                 dir = self._rand_elem(Direction)
316 | else:
317 | assert room.neighbors[dir] is not None, "no neighbor in this direction"
318 | assert room.doors[dir] is None, "door already exists"
319 |
320 | # Create the door
321 | color = color if color is not None else self._rand_color()
322 | locked = locked if locked is not None else self._rand_bool()
323 | door = Door(color, is_locked=locked)
324 | pos = room.set_door_pos(dir, random=self.np_random if rand_pos else None)
325 | self.put_obj(door, *pos)
326 |
327 | # Connect the door to the neighboring room
328 | room.doors[dir] = door
329 | room.neighbors[dir].doors[(dir + 2) % 4] = door
330 |
331 | return door, pos
332 |
333 | def remove_wall(self, col: int, row: int, dir: Direction):
334 | """
335 | Remove a wall between two rooms.
336 |
337 | Parameters
338 | ----------
339 | col : int
340 | Room column
341 | row : int
342 | Room row
343 | dir : Direction
344 | Direction of the wall to remove
345 | """
346 | room = self.get_room(col, row)
347 | assert room.doors[dir] is None, "door exists on this wall"
348 | assert room.neighbors[dir], "invalid wall"
349 |
350 | tx, ty = room.top
351 | w, h = room.size
352 |
353 | # Remove the wall
354 | if dir == Direction.right:
355 | for i in range(1, h - 1):
356 | self.grid.set(tx + w - 1, ty + i, None)
357 | elif dir == Direction.down:
358 | for i in range(1, w - 1):
359 | self.grid.set(tx + i, ty + h - 1, None)
360 | elif dir == Direction.left:
361 | for i in range(1, h - 1):
362 | self.grid.set(tx, ty + i, None)
363 | elif dir == Direction.up:
364 | for i in range(1, w - 1):
365 | self.grid.set(tx + i, ty, None)
366 | else:
367 | assert False, "invalid wall index"
368 |
369 | # Mark the rooms as connected
370 | room.doors[dir] = True
371 | room.neighbors[dir].doors[(dir + 2) % 4] = True
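 |         # (True serves as a sentinel for an open passage with no Door object;
 |         # Room.locked and connect_all() treat it as an unlocked connection.)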
372 |
373 | def place_agent(
374 | self,
375 | agent: Agent,
376 | col: int | None = None,
377 | row: int | None = None,
378 | rand_dir: bool = True) -> tuple[int, int]:
379 | """
380 | Place an agent in a room.
381 |
382 | Parameters
383 | ----------
384 | agent : Agent
385 | Agent to place
386 | col : int, optional
387 | Room column to place the agent in (random if not specified)
388 | row : int, optional
389 | Room row to place the agent in (random if not specified)
390 | rand_dir : bool, default=True
391 | Whether to select a random agent direction
392 | """
393 | col = col if col is not None else self._rand_int(0, self.num_cols)
394 | row = row if row is not None else self._rand_int(0, self.num_rows)
395 | room = self.get_room(col, row)
396 |
397 | # Find a position that is not right in front of an object
398 | while True:
399 | super().place_agent(agent, room.top, room.size, rand_dir, max_tries=1000)
400 | front_cell = self.grid.get(*agent.front_pos)
401 | if front_cell is None or front_cell.type == Type.wall:
402 | break
403 |
404 | return agent.state.pos
405 |
406 | def connect_all(
407 | self,
408 | door_colors: list[Color] = list(Color),
409 | max_itrs: int = 5000) -> list[Door]:
410 | """
411 | Make sure that all rooms are reachable by the agent from its
412 | starting position.
413 |
414 | Parameters
415 | ----------
416 | door_colors : list[Color], default=list(Color)
417 | Color options for creating doors
418 | max_itrs : int, default=5000
419 | Maximum number of iterations to try to connect all rooms
420 | """
421 | added_doors = []
422 | neighbor_fn = lambda room: [
423 | room.neighbors[dir] for dir in Direction if room.doors[dir] is not None]
424 | start_room = self.get_room(0, 0)
425 |
426 | for i in range(max_itrs):
427 | # If all rooms are reachable, stop
428 | reachable_rooms = bfs(start_room, neighbor_fn)
429 | if len(reachable_rooms) == self.num_rows * self.num_cols:
430 | return added_doors
431 |
432 | # Pick a random room and door position
433 | col = self._rand_int(0, self.num_cols)
434 | row = self._rand_int(0, self.num_rows)
435 | dir = self._rand_elem(Direction)
436 | room = self.get_room(col, row)
437 |
438 | # If there is already a door there, skip
439 | if not room.neighbors[dir] or room.doors[dir]:
440 | continue
441 |
442 | neighbor_room = room.neighbors[dir]
443 | assert neighbor_room is not None
444 | if room.locked or neighbor_room.locked:
445 | continue
446 |
447 | # Add a new door
448 | color = self._rand_elem(door_colors)
449 | door, _ = self.add_door(col, row, dir=dir, color=color, locked=False)
450 | added_doors.append(door)
451 |
452 | raise RecursionError('connect_all() failed')
453 |
454 | def add_distractors(
455 | self,
456 | col: int | None = None,
457 | row: int | None = None,
458 | num_distractors: int = 10,
459 | all_unique: bool = True) -> list[WorldObj]:
460 | """
461 | Add random objects that can potentially distract / confuse the agent.
462 |
463 | Parameters
464 | ----------
465 | col : int, optional
466 | Room column to place the objects in (random if not specified)
467 | row : int, optional
468 | Room row to place the objects in (random if not specified)
469 | num_distractors : int, default=10
470 | Number of distractor objects to add
471 | all_unique : bool, default=True
472 | Whether all distractor objects should be unique with respect to (type, color)
473 | """
474 | # Collect keys for existing room objects
475 | room_objs = (obj for row in self.room_grid for room in row for obj in room.objs)
476 | room_obj_keys = {(obj.type, obj.color) for obj in room_objs}
477 |
478 | # Add distractors
479 | distractors = []
480 | while len(distractors) < num_distractors:
481 | color = self._rand_color()
482 | type = self._rand_elem([Type.key, Type.ball, Type.box])
483 |
484 | if all_unique and (type, color) in room_obj_keys:
485 | continue
486 |
487 | # Add the object to a random room if no room specified
488 | col = col if col is not None else self._rand_int(0, self.num_cols)
489 | row = row if row is not None else self._rand_int(0, self.num_rows)
490 | distractor, _ = self.add_object(col, row, kind=type, color=color)
491 |
492 |             room_obj_keys.add((type, color))
493 | distractors.append(distractor)
494 |
495 | return distractors
496 |
--------------------------------------------------------------------------------
/multigrid/core/world_object.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | import numpy as np
5 |
6 | from numpy.typing import ArrayLike, NDArray as ndarray
7 | from typing import Any, TYPE_CHECKING
8 |
9 | from .constants import Color, State, Type
10 | from ..utils.rendering import (
11 | fill_coords,
12 | point_in_circle,
13 | point_in_line,
14 | point_in_rect,
15 | )
16 |
17 | if TYPE_CHECKING:
18 | from .agent import Agent
19 | from ..base import MultiGridEnv
20 |
21 |
22 |
23 | class WorldObjMeta(type):
24 | """
25 | Metaclass for world objects.
26 |
27 | Each subclass is associated with a unique :class:`Type` enumeration value.
28 |
29 | By default, the type name is the class name (in lowercase), but this can be
30 | overridden by setting the `type_name` attribute in the class definition.
31 | Type names are dynamically added to the :class:`Type` enumeration
32 | if not already present.
33 |
34 | Examples
35 | --------
36 | >>> class A(WorldObj): pass
37 | >>> A().type
38 |     <Type.a: 'a'>
39 |
40 | >>> class B(WorldObj): type_name = 'goal'
41 | >>> B().type
42 |     <Type.goal: 'goal'>
43 |
44 | :meta private:
45 | """
46 |
47 | # Registry of object classes
48 | _TYPE_IDX_TO_CLASS = {}
49 |
50 | def __new__(meta, name, bases, class_dict):
51 | cls = super().__new__(meta, name, bases, class_dict)
52 |
53 | if name != 'WorldObj':
54 | type_name = class_dict.get('type_name', name.lower())
55 |
56 | # Add the object class name to the `Type` enumeration if not already present
57 | if type_name not in set(Type):
58 | Type.add_item(type_name, type_name)
59 |
60 | # Store the object class with its corresponding type index
61 | meta._TYPE_IDX_TO_CLASS[Type(type_name).to_index()] = cls
62 |
63 | return cls
64 |
65 |
66 | class WorldObj(np.ndarray, metaclass=WorldObjMeta):
67 | """
68 | Base class for grid world objects.
69 |
70 | Attributes
71 | ----------
72 | type : Type
73 | The object type
74 | color : Color
75 | The object color
76 | state : State
77 | The object state
78 | contains : WorldObj or None
79 | The object contained by this object, if any
80 | init_pos : tuple[int, int] or None
81 | The initial position of the object
82 | cur_pos : tuple[int, int] or None
83 | The current position of the object
84 | """
85 | # WorldObj vector indices
86 | TYPE = 0
87 | COLOR = 1
88 | STATE = 2
89 |
90 | # WorldObj vector dimension
91 | dim = len([TYPE, COLOR, STATE])
92 |
93 | def __new__(cls, type: str | None = None, color: str = Color.from_index(0)):
94 | """
95 | Parameters
96 | ----------
97 | type : str or None
98 | Object type
99 | color : str
100 | Object color
101 | """
102 | # If not provided, infer the object type from the class
103 | type_name = type or getattr(cls, 'type_name', cls.__name__.lower())
104 | type_idx = Type(type_name).to_index()
105 |
106 | # Use the WorldObj subclass corresponding to the object type
107 | cls = WorldObjMeta._TYPE_IDX_TO_CLASS.get(type_idx, cls)
108 |
109 | # Create the object
110 | obj = np.zeros(cls.dim, dtype=int).view(cls)
111 | obj[WorldObj.TYPE] = type_idx
112 | obj[WorldObj.COLOR] = Color(color).to_index()
113 | obj.contains: WorldObj | None = None # object contained by this object
114 | obj.init_pos: tuple[int, int] | None = None # initial position of the object
115 | obj.cur_pos: tuple[int, int] | None = None # current position of the object
116 |
117 | return obj
118 |
119 | def __bool__(self) -> bool:
120 | return self.type != Type.empty
121 |
122 | def __repr__(self) -> str:
123 | return f"{self.__class__.__name__}(color={self.color})"
124 |
125 | def __str__(self) -> str:
126 | return self.__repr__()
127 |
128 | def __eq__(self, other: Any):
129 | return self is other
130 |
131 | @staticmethod
132 | @functools.cache
133 | def empty() -> 'WorldObj':
134 | """
135 | Return an empty WorldObj instance.
136 | """
137 | return WorldObj(type=Type.empty)
138 |
139 | @staticmethod
140 | def from_array(arr: ArrayLike[int]) -> 'WorldObj' | None:
141 | """
142 | Convert an array to a WorldObj instance.
143 |
144 | Parameters
145 | ----------
146 | arr : ArrayLike[int]
147 | Array encoding the object type, color, and state
148 | """
149 | type_idx = arr[WorldObj.TYPE]
150 |
151 | if type_idx == Type.empty.to_index():
152 | return None
153 |
154 | if type_idx in WorldObj._TYPE_IDX_TO_CLASS:
155 | cls = WorldObj._TYPE_IDX_TO_CLASS[type_idx]
156 | obj = cls.__new__(cls)
157 | obj[...] = arr
158 | return obj
159 |
160 | raise ValueError(f'Unknown object type: {arr[WorldObj.TYPE]}')
161 |
162 | @functools.cached_property
163 | def type(self) -> Type:
164 | """
165 | Return the object type.
166 | """
167 | return Type.from_index(self[WorldObj.TYPE])
168 |
169 | @property
170 | def color(self) -> Color:
171 | """
172 | Return the object color.
173 | """
174 | return Color.from_index(self[WorldObj.COLOR])
175 |
176 | @color.setter
177 | def color(self, value: str):
178 | """
179 | Set the object color.
180 | """
181 | self[WorldObj.COLOR] = Color(value).to_index()
182 |
183 | @property
184 | def state(self) -> str:
185 | """
186 | Return the name of the object state.
187 | """
188 | return State.from_index(self[WorldObj.STATE])
189 |
190 | @state.setter
191 | def state(self, value: str):
192 | """
193 | Set the name of the object state.
194 | """
195 | self[WorldObj.STATE] = State(value).to_index()
196 |
197 | def can_overlap(self) -> bool:
198 | """
199 | Can an agent overlap with this?
200 | """
201 | return self.type == Type.empty
202 |
203 | def can_pickup(self) -> bool:
204 | """
205 | Can an agent pick this up?
206 | """
207 | return False
208 |
209 | def can_contain(self) -> bool:
210 | """
211 | Can this contain another object?
212 | """
213 | return False
214 |
215 | def toggle(self, env: MultiGridEnv, agent: Agent, pos: tuple[int, int]) -> bool:
216 | """
217 | Toggle the state of this object or trigger an action this object performs.
218 |
219 | Parameters
220 | ----------
221 | env : MultiGridEnv
222 | The environment this object is contained in
223 | agent : Agent
224 | The agent performing the toggle action
225 | pos : tuple[int, int]
226 | The (x, y) position of this object in the environment grid
227 |
228 | Returns
229 | -------
230 | success : bool
231 | Whether the toggle action was successful
232 | """
233 | return False
234 |
235 | def encode(self) -> tuple[int, int, int]:
236 | """
237 | Encode a 3-tuple description of this object.
238 |
239 | Returns
240 | -------
241 | type_idx : int
242 | The index of the object type
243 | color_idx : int
244 | The index of the object color
245 | state_idx : int
246 | The index of the object state
247 | """
248 | return tuple(self)
249 |
250 | @staticmethod
251 | def decode(type_idx: int, color_idx: int, state_idx: int) -> 'WorldObj' | None:
252 | """
253 | Create an object from a 3-tuple description.
254 |
255 | Parameters
256 | ----------
257 | type_idx : int
258 | The index of the object type
259 | color_idx : int
260 | The index of the object color
261 | state_idx : int
262 | The index of the object state
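 |
 |         Examples
 |         --------
 |         A round-trip sketch (assuming ``ball`` and ``red`` are registered
 |         :class:`Type` and :class:`Color` members, as used elsewhere here):
 |
 |         >>> obj = WorldObj.decode(*Ball(color='red').encode())
 |         >>> (obj.type, obj.color) == (Type.ball, Color.red)
 |         True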
263 | """
264 | arr = np.array([type_idx, color_idx, state_idx])
265 | return WorldObj.from_array(arr)
266 |
267 | def render(self, img: ndarray[np.uint8]):
268 | """
269 | Draw the world object.
270 |
271 | Parameters
272 | ----------
273 | img : ndarray[int] of shape (width, height, 3)
274 | RGB image array to render object on
275 | """
276 | raise NotImplementedError
277 |
278 |
279 | class Goal(WorldObj):
280 | """
281 | Goal object an agent may be searching for.
282 | """
283 |
284 | def __new__(cls, color: str = Color.green):
285 | return super().__new__(cls, color=color)
286 |
287 | def can_overlap(self) -> bool:
288 | """
289 | :meta private:
290 | """
291 | return True
292 |
293 | def render(self, img):
294 | """
295 | :meta private:
296 | """
297 | fill_coords(img, point_in_rect(0, 1, 0, 1), self.color.rgb())
298 |
299 |
300 | class Floor(WorldObj):
301 | """
302 | Colored floor tile an agent can walk over.
303 | """
304 |
305 | def __new__(cls, color: str = Color.blue):
306 | """
307 | Parameters
308 | ----------
309 | color : str
310 | Object color
311 | """
312 | return super().__new__(cls, color=color)
313 |
314 | def can_overlap(self) -> bool:
315 | """
316 | :meta private:
317 | """
318 | return True
319 |
320 | def render(self, img):
321 | """
322 | :meta private:
323 | """
324 |         # Dim the color to distinguish floor tiles from solid objects
325 | color = self.color.rgb() / 2
326 | fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
327 |
328 |
329 | class Lava(WorldObj):
330 | """
331 | Lava object an agent can fall onto.
332 | """
333 |
334 | def __new__(cls):
335 | """
336 | """
337 | return super().__new__(cls, color=Color.red)
338 |
339 | def can_overlap(self) -> bool:
340 | """
341 | :meta private:
342 | """
343 | return True
344 |
345 | def render(self, img):
346 | """
347 | :meta private:
348 | """
349 | c = (255, 128, 0)
350 |
351 | # Background color
352 | fill_coords(img, point_in_rect(0, 1, 0, 1), c)
353 |
354 | # Little waves
355 | for i in range(3):
356 | ylo = 0.3 + 0.2 * i
357 | yhi = 0.4 + 0.2 * i
358 | fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
359 | fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
360 | fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
361 | fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
362 |
363 |
364 | class Wall(WorldObj):
365 | """
366 | Wall object that agents cannot move through.
367 | """
368 |
369 | @functools.cache # reuse instances, since object is effectively immutable
370 | def __new__(cls, color: str = Color.grey):
371 | """
372 | Parameters
373 | ----------
374 | color : str
375 | Object color
376 | """
377 | return super().__new__(cls, color=color)
378 |
379 | def render(self, img):
380 | """
381 | :meta private:
382 | """
383 | fill_coords(img, point_in_rect(0, 1, 0, 1), self.color.rgb())
384 |
385 |
386 | class Door(WorldObj):
387 | """
388 | Door object that may be opened or closed. Locked doors require a key to open.
389 |
390 | Attributes
391 | ----------
392 | is_open: bool
393 | Whether the door is open
394 | is_locked: bool
395 | Whether the door is locked
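 |
 |     Examples
 |     --------
 |     A sketch of the open/locked state transitions:
 |
 |     >>> door = Door(is_locked=True)
 |     >>> door.is_open
 |     False
 |     >>> door.is_locked = False   # unlocking a closed door leaves it closed
 |     >>> door.state == State.closed
 |     True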
396 | """
397 |
398 | def __new__(
399 | cls, color: str = Color.blue, is_open: bool = False, is_locked: bool = False):
400 | """
401 | Parameters
402 | ----------
403 | color : str
404 | Object color
405 | is_open : bool
406 | Whether the door is open
407 | is_locked : bool
408 | Whether the door is locked
409 | """
410 | door = super().__new__(cls, color=color)
411 | door.is_open = is_open
412 | door.is_locked = is_locked
413 | return door
414 |
415 | def __str__(self):
416 | return f"{self.__class__.__name__}(color={self.color},state={self.state})"
417 |
418 | @property
419 | def is_open(self) -> bool:
420 | """
421 | Whether the door is open.
422 | """
423 | return self.state == State.open
424 |
425 | @is_open.setter
426 | def is_open(self, value: bool):
427 | """
428 | Set the door to be open or closed.
429 | """
430 | if value:
431 | self.state = State.open # set state to open
432 | elif not self.is_locked:
433 | self.state = State.closed # set state to closed (unless already locked)
434 |
435 | @property
436 | def is_locked(self) -> bool:
437 | """
438 | Whether the door is locked.
439 | """
440 | return self.state == State.locked
441 |
442 | @is_locked.setter
443 | def is_locked(self, value: bool):
444 | """
445 | Set the door to be locked or unlocked.
446 | """
447 | if value:
448 | self.state = State.locked # set state to locked
449 | elif not self.is_open:
450 | self.state = State.closed # set state to closed (unless already open)
451 |
452 | def can_overlap(self) -> bool:
453 | """
454 | :meta private:
455 | """
456 | return self.is_open
457 |
458 | def toggle(self, env, agent, pos):
459 | """
460 | :meta private:
461 | """
462 | if self.is_locked:
463 |             # Check if the agent is carrying a matching key to unlock the door
464 | carried_obj = agent.state.carrying
465 | if isinstance(carried_obj, Key) and carried_obj.color == self.color:
466 | self.is_locked = False
467 | self.is_open = True
468 | env.grid.update(*pos)
469 | return True
470 | return False
471 |
472 | self.is_open = not self.is_open
473 | env.grid.update(*pos)
474 | return True
475 |
476 | def render(self, img):
477 | """
478 | :meta private:
479 | """
480 | c = self.color.rgb()
481 |
482 | if self.is_open:
483 | fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
484 | fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
485 | return
486 |
487 | # Door frame and door
488 | if self.is_locked:
489 | fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
490 | fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * c)
491 |
492 | # Draw key slot
493 | fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
494 | else:
495 | fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
496 | fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
497 | fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
498 | fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
499 |
500 | # Draw door handle
501 | fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
502 |
503 |
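# The ``is_open`` and ``is_locked`` properties above are two views over the
# door's single underlying ``state`` field, so they can never disagree. An
# illustrative sketch of the unlock sequence that ``toggle()`` performs:
#
#     >>> door = Door(color=Color.yellow, is_locked=True)
#     >>> door.can_overlap()
#     False
#     >>> door.is_locked = False   # in-game this requires a matching Key
#     >>> door.is_open = True
#     >>> door.can_overlap()
#     True
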
504 | class Key(WorldObj):
505 | """
506 | Key object that can be picked up and used to unlock doors.
507 | """
508 |
509 | def __new__(cls, color: str = Color.blue):
510 | """
511 | Parameters
512 | ----------
513 | color : str
514 | Object color
515 | """
516 | return super().__new__(cls, color=color)
517 |
518 | def can_pickup(self) -> bool:
519 | """
520 | :meta private:
521 | """
522 | return True
523 |
524 | def render(self, img):
525 | """
526 | :meta private:
527 | """
528 | c = self.color.rgb()
529 |
530 | # Vertical quad
531 | fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
532 |
533 | # Teeth
534 | fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
535 | fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
536 |
537 | # Ring
538 | fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
539 | fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
540 |
541 |
542 | class Ball(WorldObj):
543 | """
544 | Ball object that can be picked up by agents.
545 | """
546 |
547 | def __new__(cls, color: str = Color.blue):
548 | """
549 | Parameters
550 | ----------
551 | color : str
552 | Object color
553 | """
554 | return super().__new__(cls, color=color)
555 |
556 | def can_pickup(self) -> bool:
557 | """
558 | :meta private:
559 | """
560 | return True
561 |
562 | def render(self, img):
563 | """
564 | :meta private:
565 | """
566 | fill_coords(img, point_in_circle(0.5, 0.5, 0.31), self.color.rgb())
567 |
568 |
569 | class Box(WorldObj):
570 | """
571 | Box object that may contain other objects.
572 | """
573 |
574 | def __new__(cls, color: str = Color.yellow, contains: WorldObj | None = None):
575 | """
576 | Parameters
577 | ----------
578 | color : str
579 | Object color
580 | contains : WorldObj or None
581 | Object contents
582 | """
583 | box = super().__new__(cls, color=color)
584 | box.contains = contains
585 | return box
586 |
587 | def can_pickup(self) -> bool:
588 | """
589 | :meta private:
590 | """
591 | return True
592 |
593 | def can_contain(self) -> bool:
594 | """
595 | :meta private:
596 | """
597 | return True
598 |
599 | def toggle(self, env, agent, pos):
600 | """
601 | :meta private:
602 | """
603 |         # Replace the box with its contents
604 | env.grid.set(*pos, self.contains)
605 | return True
606 |
607 | def render(self, img):
608 | """
609 | :meta private:
610 | """
611 | # Outline
612 | fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), self.color.rgb())
613 | fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
614 |
615 | # Horizontal slit
616 | fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), self.color.rgb())
617 |
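# A box is a simple container: toggling it replaces the box with whatever it
# holds (or clears the cell if it is empty). An illustrative sketch, where
# ``env``, ``agent``, and ``pos`` stand in for the usual toggle arguments:
#
#     >>> box = Box(color=Color.yellow, contains=Key(color=Color.red))
#     >>> box.toggle(env, agent, pos)   # the cell at ``pos`` now holds the key
#     True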
--------------------------------------------------------------------------------
/multigrid/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gymnasium as gym
4 | import math
5 | import numpy as np
6 | import pygame
7 | import pygame.freetype
8 |
9 | from abc import ABC, abstractmethod
10 | from collections import defaultdict
11 | from gymnasium import spaces
12 | from itertools import repeat
13 | from numpy.typing import NDArray as ndarray
14 | from typing import Any, Callable, Iterable, Literal, SupportsFloat
15 |
16 | from .core.actions import Action
17 | from .core.agent import Agent, AgentState
18 | from .core.constants import Type, TILE_PIXELS
19 | from .core.grid import Grid
20 | from .core.mission import MissionSpace
21 | from .core.world_object import WorldObj
22 | from .utils.obs import gen_obs_grid_encoding
23 | from .utils.random import RandomMixin
24 |
25 |
26 |
27 | ### Typing
28 |
29 | AgentID = int
30 | ObsType = dict[str, Any]
31 |
32 |
33 |
34 | ### Environment
35 |
36 | class MultiGridEnv(gym.Env, RandomMixin, ABC):
37 | """
38 | Base class for multi-agent 2D gridworld environments.
39 |
40 | :Agents:
41 |
42 | The environment can be configured with any fixed number of agents.
43 | Agents are represented by :class:`.Agent` instances, and are
44 | identified by their index, from ``0`` to ``len(env.agents) - 1``.
45 |
46 | :Observation Space:
47 |
48 | The multi-agent observation space is a Dict mapping from agent index to
49 | corresponding agent observation space.
50 |
51 | The standard agent observation is a dictionary with the following entries:
52 |
53 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
54 | Encoding of the agent's view of the environment,
55 | where each grid object is encoded as a 3 dimensional tuple:
56 | (:class:`.Type`, :class:`.Color`, :class:`.State`)
57 | * direction : int
58 | Agent's direction (0: right, 1: down, 2: left, 3: up)
59 | * mission : Mission
60 | Task string corresponding to the current environment configuration
61 |
62 | :Action Space:
63 |
64 | The multi-agent action space is a Dict mapping from agent index to
65 | corresponding agent action space.
66 |
67 | Agent actions are discrete integers, as enumerated in :class:`.Action`.
68 |
69 | Attributes
70 | ----------
71 | agents : list[Agent]
72 | List of agents in the environment
73 | grid : Grid
74 | Environment grid
75 | observation_space : spaces.Dict[AgentID, spaces.Space]
76 | Joint observation space of all agents
77 | action_space : spaces.Dict[AgentID, spaces.Space]
78 | Joint action space of all agents
79 | """
80 | metadata = {
81 | 'render_modes': ['human', 'rgb_array'],
82 | 'render_fps': 20,
83 | }
84 |
85 | def __init__(
86 | self,
87 | mission_space: MissionSpace | str = "maximize reward",
88 | agents: Iterable[Agent] | int = 1,
89 | grid_size: int | None = None,
90 | width: int | None = None,
91 | height: int | None = None,
92 | max_steps: int = 100,
93 | see_through_walls: bool = False,
94 | agent_view_size: int = 7,
95 | allow_agent_overlap: bool = True,
96 | joint_reward: bool = False,
97 | success_termination_mode: Literal['any', 'all'] = 'any',
98 | failure_termination_mode: Literal['any', 'all'] = 'all',
99 | render_mode: str | None = None,
100 | screen_size: int | None = 640,
101 | highlight: bool = True,
102 | tile_size: int = TILE_PIXELS,
103 | agent_pov: bool = False):
104 | """
105 | Parameters
106 | ----------
107 | mission_space : MissionSpace
108 | Space of mission strings (i.e. agent instructions)
109 | agents : int or Iterable[Agent]
110 | Number of agents in the environment (or provide :class:`Agent` instances)
111 | grid_size : int
112 | Size of the environment grid (width and height)
113 | width : int
114 | Width of the environment grid (if `grid_size` is not provided)
115 | height : int
116 | Height of the environment grid (if `grid_size` is not provided)
117 | max_steps : int
118 | Maximum number of steps per episode
119 | see_through_walls : bool
120 | Whether agents can see through walls
121 | agent_view_size : int
122 | Size of agent view (must be odd)
123 | allow_agent_overlap : bool
124 | Whether agents are allowed to overlap
125 | joint_reward : bool
126 | Whether all agents receive the same joint reward
127 | success_termination_mode : 'any' or 'all'
128 | Whether to terminate when any agent completes its mission
129 | or when all agents complete their missions
130 | failure_termination_mode : 'any' or 'all'
131 | Whether to terminate when any agent fails its mission
132 | or when all agents fail their missions
133 | render_mode : str
134 | Rendering mode (human or rgb_array)
135 | screen_size : int
136 | Width and height of the rendering window (in pixels)
137 | highlight : bool
138 | Whether to highlight the view of each agent when rendering
139 | tile_size : int
140 |             Width and height of each grid tile (in pixels)
141 |         agent_pov : bool
142 |             Whether to render agent POV observations (rather than the full environment)
143 |         """
142 | gym.Env.__init__(self)
143 | RandomMixin.__init__(self, self.np_random)
144 |
145 | # Initialize mission space
146 | if isinstance(mission_space, str):
147 | self.mission_space = MissionSpace.from_string(mission_space)
148 | else:
149 | self.mission_space = mission_space
150 |
151 | # Initialize grid
152 | width, height = (grid_size, grid_size) if grid_size else (width, height)
153 | assert width is not None and height is not None
154 | self.width, self.height = width, height
155 | self.grid: Grid = Grid(width, height)
156 |
157 | # Initialize agents
158 | if isinstance(agents, int):
159 | self.num_agents = agents
160 | self.agent_states = AgentState(agents) # joint agent state (vectorized)
161 | self.agents: list[Agent] = []
162 | for i in range(agents):
163 | agent = Agent(
164 | index=i,
165 | mission_space=self.mission_space,
166 | view_size=agent_view_size,
167 | see_through_walls=see_through_walls,
168 | )
169 | agent.state = self.agent_states[i]
170 | self.agents.append(agent)
171 |         elif isinstance(agents, Iterable):
172 |             agents = list(agents) # materialize first (generators have no len())
173 |             assert {agent.index for agent in agents} == set(range(len(agents)))
173 | self.num_agents = len(agents)
174 | self.agent_states = AgentState(self.num_agents)
175 | self.agents: list[Agent] = sorted(agents, key=lambda agent: agent.index)
176 | for agent in self.agents:
177 | self.agent_states[agent.index] = agent.state # copy to joint agent state
178 | agent.state = self.agent_states[agent.index] # reference joint agent state
179 | else:
180 | raise ValueError(f"Invalid argument for agents: {agents}")
181 |
182 | # Action enumeration for this environment
183 | self.actions = Action
184 |
185 | # Range of possible rewards
186 | self.reward_range = (0, 1)
187 |
188 | assert isinstance(
189 | max_steps, int
190 | ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
191 | self.max_steps = max_steps
192 |
193 | # Rendering attributes
194 | self.render_mode = render_mode
195 | self.highlight = highlight
196 | self.tile_size = tile_size
197 | self.agent_pov = agent_pov
198 | self.screen_size = screen_size
199 | self.render_size = None
200 | self.window = None
201 | self.clock = None
202 |
203 | # Other
204 | self.allow_agent_overlap = allow_agent_overlap
205 | self.joint_reward = joint_reward
206 | self.success_termination_mode = success_termination_mode
207 | self.failure_termination_mode = failure_termination_mode
208 |
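# Illustrative construction sketch (``MyEnv`` is a hypothetical subclass that
# implements ``_gen_grid()``). Agents may be given as a count or as explicit
# ``Agent`` instances with indices 0..n-1:
#
#     >>> env = MyEnv(agents=2, grid_size=8, render_mode='rgb_array')
#     >>> agents = [Agent(index=0, view_size=5), Agent(index=1, view_size=5)]
#     >>> env = MyEnv(agents=agents, grid_size=8)
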
209 | @property
210 | def observation_space(self) -> spaces.Dict[AgentID, spaces.Space]:
211 | """
212 | Return the joint observation space of all agents.
213 | """
214 | return spaces.Dict({
215 | agent.index: agent.observation_space
216 | for agent in self.agents
217 | })
218 |
219 | @property
220 | def action_space(self) -> spaces.Dict[AgentID, spaces.Space]:
221 | """
222 | Return the joint action space of all agents.
223 | """
224 | return spaces.Dict({
225 | agent.index: agent.action_space
226 | for agent in self.agents
227 | })
228 |
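# Sketch of the joint spaces for a 2-agent environment with default settings
# (the exact subspaces come from each Agent; the shapes below are assumptions
# for illustration):
#
#     >>> env.observation_space
#     Dict(0: Dict('direction': ..., 'image': ..., 'mission': ...),
#          1: Dict('direction': ..., 'image': ..., 'mission': ...))
#     >>> env.action_space
#     Dict(0: Discrete(7), 1: Discrete(7))
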
229 | @abstractmethod
230 | def _gen_grid(self, width: int, height: int):
231 | """
232 | :meta public:
233 |
234 | Generate the grid for a new episode.
235 |
236 | This method should:
237 |
238 | * Set ``self.grid`` and populate it with :class:`.WorldObj` instances
239 | * Set the positions and directions of each agent
240 |
241 | Parameters
242 | ----------
243 | width : int
244 | Width of the grid
245 | height : int
246 | Height of the grid
247 | """
248 | pass
249 |
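# A minimal ``_gen_grid()`` sketch (assumes ``Grid.wall_rect()`` as in MiniGrid,
# and ``Goal`` from .core.world_object):
#
#     class EmptySketchEnv(MultiGridEnv):
#         def _gen_grid(self, width, height):
#             self.grid = Grid(width, height)
#             self.grid.wall_rect(0, 0, width, height)    # enclosing walls
#             self.put_obj(Goal(), width - 2, height - 2) # goal in the corner
#             for agent in self.agents:
#                 self.place_agent(agent)                 # random empty cell
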
250 | def reset(
251 | self, seed: int | None = None, **kwargs) -> tuple[
252 |             dict[AgentID, ObsType],
253 | dict[AgentID, dict[str, Any]]]:
254 | """
255 | Reset the environment.
256 |
257 | Parameters
258 | ----------
259 | seed : int or None
260 | Seed for random number generator
261 |
262 | Returns
263 | -------
264 | observations : dict[AgentID, ObsType]
265 | Observation for each agent
266 | infos : dict[AgentID, dict[str, Any]]
267 | Additional information for each agent
268 | """
269 | super().reset(seed=seed, **kwargs)
270 |
271 | # Reset agents
272 | self.mission_space.seed(seed)
273 | self.mission = self.mission_space.sample()
274 | self.agent_states = AgentState(self.num_agents)
275 | for agent in self.agents:
276 | agent.state = self.agent_states[agent.index]
277 | agent.reset(mission=self.mission)
278 |
279 | # Generate a new random grid at the start of each episode
280 | self._gen_grid(self.width, self.height)
281 |
282 | # These fields should be defined by _gen_grid
283 | assert np.all(self.agent_states.pos >= 0)
284 | assert np.all(self.agent_states.dir >= 0)
285 |
286 | # Check that agents don't overlap with other objects
287 | for agent in self.agents:
288 | start_cell = self.grid.get(*agent.state.pos)
289 | assert start_cell is None or start_cell.can_overlap()
290 |
291 | # Step count since episode start
292 | self.step_count = 0
293 |
294 | # Return first observation
295 | observations = self.gen_obs()
296 |
297 | # Render environment
298 | if self.render_mode == 'human':
299 | self.render()
300 |
301 | return observations, defaultdict(dict)
302 |
303 | def step(
304 | self,
305 | actions: dict[AgentID, Action]) -> tuple[
306 | dict[AgentID, ObsType],
307 | dict[AgentID, SupportsFloat],
308 | dict[AgentID, bool],
309 | dict[AgentID, bool],
310 | dict[AgentID, dict[str, Any]]]:
311 | """
312 |         Run one timestep of the environment's dynamics
313 | using the provided agent actions.
314 |
315 | Parameters
316 | ----------
317 | actions : dict[AgentID, Action]
318 | Action for each agent acting at this timestep
319 |
320 | Returns
321 | -------
322 | observations : dict[AgentID, ObsType]
323 | Observation for each agent
324 | rewards : dict[AgentID, SupportsFloat]
325 | Reward for each agent
326 | terminations : dict[AgentID, bool]
327 | Whether the episode has been terminated for each agent (success or failure)
328 | truncations : dict[AgentID, bool]
329 | Whether the episode has been truncated for each agent (max steps reached)
330 | infos : dict[AgentID, dict[str, Any]]
331 | Additional information for each agent
332 | """
333 | self.step_count += 1
334 | rewards = self.handle_actions(actions)
335 |
336 | # Generate outputs
337 | observations = self.gen_obs()
338 | terminations = dict(enumerate(self.agent_states.terminated))
339 | truncated = self.step_count >= self.max_steps
340 | truncations = dict(enumerate(repeat(truncated, self.num_agents)))
341 |
342 | # Rendering
343 | if self.render_mode == 'human':
344 | self.render()
345 |
346 | return observations, rewards, terminations, truncations, defaultdict(dict)
347 |
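# Typical interaction loop (an illustrative sketch that samples a random joint
# action each step and stops on the termination/truncation flags above):
#
#     obs, infos = env.reset(seed=0)
#     done = False
#     while not done:
#         actions = {i: env.action_space[i].sample() for i in range(env.num_agents)}
#         obs, rewards, terminations, truncations, infos = env.step(actions)
#         done = all(terminations.values()) or all(truncations.values())
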
348 | def gen_obs(self) -> dict[AgentID, ObsType]:
349 | """
350 | Generate observations for each agent (partially observable, low-res encoding).
351 |
352 | Returns
353 | -------
354 | observations : dict[AgentID, ObsType]
355 | Mapping from agent ID to observation dict, containing:
356 | * 'image': partially observable view of the environment
357 | * 'direction': agent's direction / orientation (acting as a compass)
358 | * 'mission': textual mission string (instructions for the agent)
359 | """
360 | direction = self.agent_states.dir
361 | image = gen_obs_grid_encoding(
362 | self.grid.state,
363 | self.agent_states,
364 | self.agents[0].view_size,
365 | self.agents[0].see_through_walls,
366 | )
367 |
368 | observations = {}
369 | for i in range(self.num_agents):
370 | observations[i] = {
371 | 'image': image[i],
372 | 'direction': direction[i],
373 | 'mission': self.agents[i].mission,
374 | }
375 |
376 | return observations
377 |
378 | def handle_actions(
379 | self, actions: dict[AgentID, Action]) -> dict[AgentID, SupportsFloat]:
380 | """
381 | Handle actions taken by agents.
382 |
383 | Parameters
384 | ----------
385 | actions : dict[AgentID, Action]
386 | Action for each agent acting at this timestep
387 |
388 | Returns
389 | -------
390 | rewards : dict[AgentID, SupportsFloat]
391 | Reward for each agent
392 | """
393 | rewards = {agent_index: 0 for agent_index in range(self.num_agents)}
394 |
395 | # Randomize agent action order
396 | if self.num_agents == 1:
397 | order = (0,)
398 | else:
399 | order = self.np_random.random(size=self.num_agents).argsort()
400 |
401 | # Update agent states, grid states, and reward from actions
402 | for i in order:
403 | if i not in actions:
404 | continue
405 |
406 | agent, action = self.agents[i], actions[i]
407 |
408 | if agent.state.terminated:
409 | continue
410 |
411 | # Rotate left
412 | if action == Action.left:
413 | agent.state.dir = (agent.state.dir - 1) % 4
414 |
415 | # Rotate right
416 | elif action == Action.right:
417 | agent.state.dir = (agent.state.dir + 1) % 4
418 |
419 | # Move forward
420 | elif action == Action.forward:
421 | fwd_pos = agent.front_pos
422 | fwd_obj = self.grid.get(*fwd_pos)
423 |
424 | if fwd_obj is None or fwd_obj.can_overlap():
425 | if not self.allow_agent_overlap:
426 | agent_present = np.bitwise_and.reduce(
427 | self.agent_states.pos == fwd_pos, axis=1).any()
428 | if agent_present:
429 | continue
430 |
431 | agent.state.pos = fwd_pos
432 | if fwd_obj is not None:
433 | if fwd_obj.type == Type.goal:
434 | self.on_success(agent, rewards, {})
435 | if fwd_obj.type == Type.lava:
436 | self.on_failure(agent, rewards, {})
437 |
438 | # Pick up an object
439 | elif action == Action.pickup:
440 | fwd_pos = agent.front_pos
441 | fwd_obj = self.grid.get(*fwd_pos)
442 |
443 | if fwd_obj is not None and fwd_obj.can_pickup():
444 | if agent.state.carrying is None:
445 | agent.state.carrying = fwd_obj
446 | self.grid.set(*fwd_pos, None)
447 |
448 | # Drop an object
449 | elif action == Action.drop:
450 | fwd_pos = agent.front_pos
451 | fwd_obj = self.grid.get(*fwd_pos)
452 |
453 | if agent.state.carrying and fwd_obj is None:
454 | agent_present = np.bitwise_and.reduce(
455 | self.agent_states.pos == fwd_pos, axis=1).any()
456 | if not agent_present:
457 | self.grid.set(*fwd_pos, agent.state.carrying)
458 | agent.state.carrying.cur_pos = fwd_pos
459 | agent.state.carrying = None
460 |
461 | # Toggle/activate an object
462 | elif action == Action.toggle:
463 | fwd_pos = agent.front_pos
464 | fwd_obj = self.grid.get(*fwd_pos)
465 |
466 | if fwd_obj is not None:
467 | fwd_obj.toggle(self, agent, fwd_pos)
468 |
469 | # Done action (not used by default)
470 | elif action == Action.done:
471 | pass
472 |
473 | else:
474 | raise ValueError(f"Unknown action: {action}")
475 |
476 | return rewards
477 |
478 | def on_success(
479 | self,
480 | agent: Agent,
481 | rewards: dict[AgentID, SupportsFloat],
482 | terminations: dict[AgentID, bool]):
483 | """
484 | Callback for when an agent completes its mission.
485 |
486 | Parameters
487 | ----------
488 | agent : Agent
489 | Agent that completed its mission
490 | rewards : dict[AgentID, SupportsFloat]
491 | Reward dictionary to be updated
492 | terminations : dict[AgentID, bool]
493 | Termination dictionary to be updated
494 | """
495 | if self.success_termination_mode == 'any':
496 | self.agent_states.terminated = True # terminate all agents
497 | for i in range(self.num_agents):
498 | terminations[i] = True
499 | else:
500 | agent.state.terminated = True # terminate this agent only
501 | terminations[agent.index] = True
502 |
503 | if self.joint_reward:
504 | for i in range(self.num_agents):
505 | rewards[i] = self._reward() # reward all agents
506 | else:
507 | rewards[agent.index] = self._reward() # reward this agent only
508 |
509 | def on_failure(
510 | self,
511 | agent: Agent,
512 | rewards: dict[AgentID, SupportsFloat],
513 | terminations: dict[AgentID, bool]):
514 | """
515 | Callback for when an agent fails its mission prematurely.
516 |
517 | Parameters
518 | ----------
519 | agent : Agent
520 | Agent that failed its mission
521 | rewards : dict[AgentID, SupportsFloat]
522 | Reward dictionary to be updated
523 | terminations : dict[AgentID, bool]
524 | Termination dictionary to be updated
525 | """
526 | if self.failure_termination_mode == 'any':
527 | self.agent_states.terminated = True # terminate all agents
528 | for i in range(self.num_agents):
529 | terminations[i] = True
530 | else:
531 | agent.state.terminated = True # terminate this agent only
532 | terminations[agent.index] = True
533 |
534 | def is_done(self) -> bool:
535 | """
536 | Return whether the current episode is finished (for all agents).
537 | """
538 | truncated = self.step_count >= self.max_steps
539 | return truncated or all(self.agent_states.terminated)
540 |
541 | def __str__(self):
542 | """
543 |         Produce a pretty string of the environment's grid along with the agents.
544 |         A grid cell is represented by a 2-character string: the first character
545 |         for the object type and the second for its color.
546 | """
547 | # Map of object types to short string
548 | OBJECT_TO_STR = {
549 | 'wall': 'W',
550 | 'floor': 'F',
551 | 'door': 'D',
552 | 'key': 'K',
553 | 'ball': 'A',
554 | 'box': 'B',
555 | 'goal': 'G',
556 | 'lava': 'V',
557 | }
558 |
559 | # Map agent's direction to short string
560 | AGENT_DIR_TO_STR = {0: '>', 1: 'V', 2: '<', 3: '^'}
561 |
562 | # Get agent locations
563 | location_to_agent = {tuple(agent.pos): agent for agent in self.agents}
564 |
565 | output = ""
566 | for j in range(self.grid.height):
567 | for i in range(self.grid.width):
568 | if (i, j) in location_to_agent:
569 | output += 2 * AGENT_DIR_TO_STR[location_to_agent[i, j].dir]
570 | continue
571 |
572 | tile = self.grid.get(i, j)
573 |
574 | if tile is None:
575 | output += ' '
576 | continue
577 |
578 | if tile.type == 'agent':
579 | output += 2 * AGENT_DIR_TO_STR[tile.dir]
580 | continue
581 |
582 | if tile.type == 'door':
583 | if tile.is_open:
584 | output += '__'
585 | elif tile.is_locked:
586 | output += 'L' + tile.color[0].upper()
587 | else:
588 | output += 'D' + tile.color[0].upper()
589 | continue
590 |
591 | output += OBJECT_TO_STR[tile.type] + tile.color[0].upper()
592 |
593 | if j < self.grid.height - 1:
594 | output += '\n'
595 |
596 | return output
597 |
598 | def _reward(self) -> float:
599 | """
600 | Compute the reward to be given upon success.
601 | """
602 | return 1 - 0.9 * (self.step_count / self.max_steps)
603 |
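# Worked example: with ``max_steps=100``, success at step 10 yields
# 1 - 0.9 * (10 / 100) = 0.91, while success on the very last step yields 0.1.
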
604 | def place_obj(
605 | self,
606 | obj: WorldObj | None,
607 |         top: tuple[int, int] | None = None,
608 |         size: tuple[int, int] | None = None,
609 | reject_fn: Callable[[MultiGridEnv, tuple[int, int]], bool] | None = None,
610 | max_tries=math.inf) -> tuple[int, int]:
611 | """
612 | Place an object at an empty position in the grid.
613 |
614 | Parameters
615 | ----------
616 | obj: WorldObj
617 | Object to place in the grid
618 | top: tuple[int, int]
619 | Top-left position of the rectangular area where to place the object
620 | size: tuple[int, int]
621 | Width and height of the rectangular area where to place the object
622 | reject_fn: Callable(env, pos) -> bool
623 | Function to filter out potential positions
624 | max_tries: int
625 | Maximum number of attempts to place the object
626 | """
627 | if top is None:
628 | top = (0, 0)
629 | else:
630 | top = (max(top[0], 0), max(top[1], 0))
631 |
632 | if size is None:
633 | size = (self.grid.width, self.grid.height)
634 |
635 | num_tries = 0
636 |
637 | while True:
638 |             # Handle rare cases where rejection sampling
639 |             # would otherwise get stuck in an infinite loop
640 | if num_tries > max_tries:
641 | raise RecursionError("rejection sampling failed in place_obj")
642 |
643 | num_tries += 1
644 |
645 | pos = (
646 | self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
647 | self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
648 | )
649 |
650 | # Don't place the object on top of another object
651 | if self.grid.get(*pos) is not None:
652 | continue
653 |
654 | # Don't place the object where agents are
655 | if np.bitwise_and.reduce(self.agent_states.pos == pos, axis=1).any():
656 | continue
657 |
658 | # Check if there is a filtering criterion
659 | if reject_fn and reject_fn(self, pos):
660 | continue
661 |
662 | break
663 |
664 | self.grid.set(pos[0], pos[1], obj)
665 |
666 | if obj is not None:
667 | obj.init_pos = pos
668 | obj.cur_pos = pos
669 |
670 | return pos
671 |
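# Illustrative use of ``reject_fn`` to filter candidate positions, e.g. keeping
# an object away from the top-left corner region:
#
#     env.place_obj(Goal(), reject_fn=lambda env, pos: pos[0] + pos[1] < 4)
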
672 | def put_obj(self, obj: WorldObj, i: int, j: int):
673 | """
674 | Put an object at a specific position in the grid.
675 | """
676 | self.grid.set(i, j, obj)
677 | obj.init_pos = (i, j)
678 | obj.cur_pos = (i, j)
679 |
680 | def place_agent(
681 | self,
682 | agent: Agent,
683 | top=None,
684 | size=None,
685 | rand_dir=True,
686 | max_tries=math.inf) -> tuple[int, int]:
687 | """
688 | Set agent starting point at an empty position in the grid.
689 | """
690 | agent.state.pos = (-1, -1)
691 | pos = self.place_obj(None, top, size, max_tries=max_tries)
692 | agent.state.pos = pos
693 |
694 | if rand_dir:
695 | agent.state.dir = self._rand_int(0, 4)
696 |
697 | return pos
698 |
699 | def get_pov_render(self, *args, **kwargs):
700 | """
701 | Render an agent's POV observation for visualization.
702 | """
703 | raise NotImplementedError(
704 | "POV rendering not supported for multiagent environments."
705 | )
706 |
707 | def get_full_render(self, highlight: bool, tile_size: int):
708 | """
709 | Render a non-partial observation for visualization.
710 | """
711 | # Compute agent visibility masks
712 | obs_shape = self.agents[0].observation_space['image'].shape[:-1]
713 | vis_masks = np.zeros((self.num_agents, *obs_shape), dtype=bool)
714 | for i, agent_obs in self.gen_obs().items():
715 | vis_masks[i] = (agent_obs['image'][..., 0] != Type.unseen.to_index())
716 |
717 | # Mask of which cells to highlight
718 | highlight_mask = np.zeros((self.width, self.height), dtype=bool)
719 |
720 | for agent in self.agents:
721 | # Compute the world coordinates of the bottom-left corner
722 | # of the agent's view area
723 | f_vec = agent.state.dir.to_vec()
724 | r_vec = np.array((-f_vec[1], f_vec[0]))
725 | top_left = (
726 | agent.state.pos
727 | + f_vec * (agent.view_size - 1)
728 | - r_vec * (agent.view_size // 2)
729 | )
730 |
731 | # For each cell in the visibility mask
732 | for vis_j in range(0, agent.view_size):
733 | for vis_i in range(0, agent.view_size):
734 | # If this cell is not visible, don't highlight it
735 | if not vis_masks[agent.index][vis_i, vis_j]:
736 | continue
737 |
738 | # Compute the world coordinates of this cell
739 | abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
740 |
741 | if abs_i < 0 or abs_i >= self.width:
742 | continue
743 | if abs_j < 0 or abs_j >= self.height:
744 | continue
745 |
746 | # Mark this cell to be highlighted
747 | highlight_mask[abs_i, abs_j] = True
748 |
749 | # Render the whole grid
750 | img = self.grid.render(
751 | tile_size,
752 | agents=self.agents,
753 | highlight_mask=highlight_mask if highlight else None,
754 | )
755 |
756 | return img
757 |
758 | def get_frame(
759 | self,
760 | highlight: bool = True,
761 | tile_size: int = TILE_PIXELS,
762 | agent_pov: bool = False) -> ndarray[np.uint8]:
763 | """
764 | Returns an RGB image corresponding to the whole environment.
765 |
766 | Parameters
767 | ----------
768 | highlight: bool
769 | Whether to highlight agents' field of view (with a lighter gray color)
770 | tile_size: int
771 |             Width and height of each rendered grid tile (in pixels)
772 | agent_pov: bool
773 | Whether to render agent's POV or the full environment
774 |
775 | Returns
776 | -------
777 | frame: ndarray of shape (H, W, 3)
778 | A frame representing RGB values for the HxW pixel image
779 | """
780 | if agent_pov:
781 | return self.get_pov_render(tile_size)
782 | else:
783 | return self.get_full_render(highlight, tile_size)
784 |
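# Illustrative usage: grab a full-environment frame for logging or video;
# imageio here is an assumed extra dependency, not required by this module:
#
#     frame = env.get_frame(highlight=False, tile_size=32)  # uint8, (H, W, 3)
#     # imageio.imwrite('frame.png', frame)
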
785 | def render(self):
786 | """
787 | Render the environment.
788 | """
789 | img = self.get_frame(self.highlight, self.tile_size)
790 |
791 | if self.render_mode == 'human':
792 | img = np.transpose(img, axes=(1, 0, 2))
793 | screen_size = (
794 |                 int(self.screen_size * min(img.shape[0] / img.shape[1], 1.0)),
795 |                 int(self.screen_size * min(img.shape[1] / img.shape[0], 1.0)),
796 | )
797 | if self.render_size is None:
798 | self.render_size = img.shape[:2]
799 | if self.window is None:
800 | pygame.init()
801 | pygame.display.init()
802 | pygame.display.set_caption(f'multigrid - {self.__class__.__name__}')
803 | self.window = pygame.display.set_mode(screen_size)
804 | if self.clock is None:
805 | self.clock = pygame.time.Clock()
806 | surf = pygame.surfarray.make_surface(img)
807 |
808 | # Create background with mission description
809 | offset = surf.get_size()[0] * 0.1
810 | # offset = 32 if self.agent_pov else 64
811 | bg = pygame.Surface(
812 | (int(surf.get_size()[0] + offset), int(surf.get_size()[1] + offset))
813 | )
814 | bg.convert()
815 | bg.fill((255, 255, 255))
816 | bg.blit(surf, (offset / 2, 0))
817 |
818 | bg = pygame.transform.smoothscale(bg, screen_size)
819 |
820 | font_size = 22
821 | text = str(self.mission)
822 | font = pygame.freetype.SysFont(pygame.font.get_default_font(), font_size)
823 | text_rect = font.get_rect(text, size=font_size)
824 | text_rect.center = bg.get_rect().center
825 | text_rect.y = bg.get_height() - font_size * 1.5
826 | font.render_to(bg, text_rect, text, size=font_size)
827 |
828 | self.window.blit(bg, (0, 0))
829 | pygame.event.pump()
830 | self.clock.tick(self.metadata['render_fps'])
831 | pygame.display.flip()
832 |
833 | elif self.render_mode == 'rgb_array':
834 | return img
835 |
836 | def close(self):
837 | """
838 | Close the rendering window.
839 | """
840 | if self.window:
841 | pygame.quit()
842 |
--------------------------------------------------------------------------------