├── multigrid
│   ├── utils
│   │   ├── __init__.py
│   │   ├── misc.py
│   │   ├── enum.py
│   │   ├── random.py
│   │   ├── minigrid_interface.py
│   │   ├── rendering.py
│   │   └── obs.py
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── actions.py
│   │   ├── constants.py
│   │   ├── mission.py
│   │   ├── grid.py
│   │   ├── agent.py
│   │   ├── roomgrid.py
│   │   └── world_object.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── playground.py
│   │   ├── blockedunlockpickup.py
│   │   ├── empty.py
│   │   ├── redbluedoors.py
│   │   └── locked_hallway.py
│   ├── pettingzoo
│   │   └── __init__.py
│   ├── rllib
│   │   └── __init__.py
│   ├── wrappers.py
│   └── base.py
├── .gitignore
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── scripts
│   ├── README.md
│   ├── visualize.py
│   └── train.py
├── pyproject.toml
├── README.md
└── LICENSE
/multigrid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multigrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import MultiGridEnv 2 | from .core import * 3 | 4 | __version__ = '0.1.0' 5 | -------------------------------------------------------------------------------- /multigrid/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .actions import Action 2 | from .agent import Agent, AgentState 3 | from .constants import * 4 | from .grid import Grid 5 | from .mission import MissionSpace 6 | from .world_object import Ball, Box, Door, Floor, Goal, Key, Lava, Wall, WorldObj 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__ 3 | *egg-info 4 | trained_models 5 | 6 | # PyPI 7 | build/* 8 | dist/* 9 | .idea/ 10 | 11 | # Docs 12 | .DS_Store 13 | _site 14 | .jekyll-cache 15 | __pycache__ 16 | .vscode/ 17 | docs/_build/ 18 | /docs/environments/**/*.md 19 | /docs/environments/**/*.html 20 | !/docs/environments/**/index.md 21 | 22 | # Virtual environments 23 | .env 24 | .venv 25 | venv 26 | -------------------------------------------------------------------------------- /multigrid/core/actions.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | 5 | class Action(enum.IntEnum): 6 | """ 7 | Enumeration of possible actions. 8 | """ 9 | left = 0 #: Turn left 10 | right = enum.auto() #: Turn right 11 | forward = enum.auto() #: Move forward 12 | pickup = enum.auto() #: Pick up an object 13 | drop = enum.auto() #: Drop an object 14 | toggle = enum.auto() #: Toggle / activate an object 15 | done = enum.auto() #: Done completing task 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import setuptools 3 | 4 | 5 | 6 | PACKAGE_DIR = pathlib.Path(__file__).absolute().parent 7 | 8 | 9 | 10 | def get_version(): 11 | """ 12 | Gets the multigrid version. 13 | """ 14 | path = PACKAGE_DIR / 'multigrid' / '__init__.py' 15 | content = path.read_text() 16 | 17 | for line in content.splitlines(): 18 | if line.startswith('__version__'): 19 | return line.strip().split()[-1].strip().strip("'") 20 | 21 | raise RuntimeError("bad version data in __init__.py") 22 | 23 | def get_description(): 24 | """ 25 | Gets the description from the readme. 
26 | """ 27 | with open("README.md") as fh: 28 | long_description = "" 29 | header_count = 0 30 | for line in fh: 31 | if line.startswith('##'): 32 | header_count += 1 33 | if header_count < 2: 34 | long_description += line 35 | else: 36 | break 37 | return long_description 38 | 39 | setuptools.setup( 40 | name='multigrid', 41 | version=get_version(), 42 | long_description=get_description(), 43 | ) 44 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /multigrid/utils/misc.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any 3 | from ..core.constants import Direction 4 | 5 | 6 | 7 | @functools.cache 8 | def front_pos(agent_x: int, agent_y: int, agent_dir: int): 9 | """ 10 | Get the position in front of an agent. 11 | """ 12 | dx, dy = Direction(agent_dir).to_vec() 13 | return (agent_x + dx, agent_y + dy) 14 | 15 | 16 | 17 | class PropertyAlias(property): 18 | """ 19 | A class property that is an alias for an attribute property. 
20 | 21 | Instead of:: 22 | 23 | @property 24 | def x(self): 25 | return self.attr.x 26 | 27 | @x.setter 28 | def x(self, value): 29 | self.attr.x = value 30 | 31 | we can simply declare:: 32 | 33 | x = PropertyAlias('attr', 'x') 34 | """ 35 | 36 | def __init__(self, attr_name: str, attr_property_name: str, doc: str = None) -> None: 37 | """ 38 | Parameters 39 | ---------- 40 | attr_name : str 41 | Name of the base attribute 42 | attr_property_name : str 43 | Name of the property on the base attribute's class 44 | doc : str 45 | Docstring to append to the property's original docstring 46 | """ 47 | prop = lambda obj: getattr(type(getattr(obj, attr_name)), attr_property_name) 48 | fget = lambda obj: prop(obj).fget(getattr(obj, attr_name)) 49 | fset = lambda obj, value: prop(obj).fset(getattr(obj, attr_name), value) 50 | fdel = lambda obj: prop(obj).fdel(getattr(obj, attr_name)) 51 | super().__init__(fget, fset, fdel, doc=doc) 52 | self.__doc__ = doc 53 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Training MultiGrid agents with RLlib 2 | 3 | MultiGrid is compatible with RLlib's multi-agent API. 4 | 5 | This folder provides scripts to train and visualize agents over MultiGrid environments. 6 | 7 | ## Requirements 8 | 9 | Using MultiGrid environments with RLlib requires installation of [rllib](https://docs.ray.io/en/latest/rllib/index.html), and one of [PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/). 10 | 11 | ## Getting Started 12 | 13 | Train 2 agents on the `MultiGrid-Empty-8x8-v0` environment using the PPO algorithm: 14 | 15 | python train.py --algo PPO --env MultiGrid-Empty-8x8-v0 --num-agents 2 --save-dir ~/saved/empty8x8/ 16 | 17 | Visualize behavior from trained agent policies: 18 | 19 | python visualize.py --algo PPO --env MultiGrid-Empty-8x8-v0 --num-agents 2 --load-dir ~/saved/empty8x8/ 20 | 21 | For more options, run ``python train.py --help`` and ``python visualize.py --help``. 22 | 23 | ## Environments 24 | 25 | All of the environment configurations registered in [`multigrid.envs`](../multigrid/envs/__init__.py) can also be used with RLlib, and are registered via `import multigrid.rllib`. 26 | 27 | To use a specific MultiGrid environment configuration by name: 28 | 29 | >>> import multigrid.rllib 30 | >>> from ray.rllib.algorithms.ppo import PPOConfig 31 | >>> algorithm_config = PPOConfig().environment(env='MultiGrid-Empty-8x8-v0') 32 | 33 | To convert a custom `MultiGridEnv` to an RLlib `MultiAgentEnv`: 34 | 35 | >>> from multigrid.rllib import to_rllib_env 36 | >>> MyRLLibEnvClass = to_rllib_env(MyEnvClass) 37 | >>> algorithm_config = PPOConfig().environment(env=MyRLLibEnvClass) 38 | -------------------------------------------------------------------------------- /multigrid/utils/enum.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import aenum as enum 4 | import functools 5 | import numpy as np 6 | 7 | from numpy.typing import ArrayLike, NDArray as ndarray 8 | from typing import Any 9 | 10 | 11 | 12 | ### Helper Functions 13 | 14 | @functools.cache 15 | def _enum_array(enum_cls: enum.EnumMeta): 16 | """ 17 | Return an array of all values of the given enum. 
18 | 19 | Parameters 20 | ---------- 21 | enum_cls : enum.EnumMeta 22 | Enum class 23 | """ 24 | return np.array([item.value for item in enum_cls]) 25 | 26 | @functools.cache 27 | def _enum_index(enum_item: enum.Enum): 28 | """ 29 | Return the index of the given enum item. 30 | 31 | Parameters 32 | ---------- 33 | enum_item : enum.Enum 34 | Enum item 35 | """ 36 | return list(enum_item.__class__).index(enum_item) 37 | 38 | 39 | 40 | ### Enumeration 41 | 42 | class IndexedEnum(enum.Enum): 43 | """ 44 | Enum where each member has a corresponding integer index. 45 | """ 46 | 47 | def __int__(self): 48 | return self.to_index() 49 | 50 | @classmethod 51 | def add_item(cls, name: str, value: Any): 52 | """ 53 | Add a new item to the enumeration. 54 | 55 | Parameters 56 | ---------- 57 | name : str 58 | Name of the new enum item 59 | value : Any 60 | Value of the new enum item 61 | """ 62 | enum.extend_enum(cls, name, value) 63 | _enum_array.cache_clear() 64 | _enum_index.cache_clear() 65 | 66 | @classmethod 67 | def from_index(cls, index: int | ArrayLike[int]) -> enum.Enum | ndarray: 68 | """ 69 | Return the enum item corresponding to the given index. 70 | Also supports vector inputs. 71 | 72 | Parameters 73 | ---------- 74 | index : int or ArrayLike[int] 75 | Enum index (or array of indices) 76 | 77 | Returns 78 | ------- 79 | enum.Enum or ndarray 80 | Enum item (or array of enum item values) 81 | """ 82 | out = _enum_array(cls)[index] 83 | return cls(out) if out.ndim == 0 else out 84 | 85 | def to_index(self) -> int: 86 | """ 87 | Return the integer index of this enum item. 88 | """ 89 | return _enum_index(self) 90 | -------------------------------------------------------------------------------- /multigrid/envs/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ************ 3 | Environments 4 | ************ 5 | 6 | This package contains implementations of several MultiGrid environments. 
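Each configuration listed below is registered with gymnasium when this package
is imported, so environments can be created by name (a brief usage sketch,
mirroring the README):

>>> import gymnasium as gym
>>> import multigrid.envs
>>> env = gym.make('MultiGrid-Empty-8x8-v0', agents=2)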
7 | 8 | ************** 9 | Configurations 10 | ************** 11 | 12 | * `Blocked Unlock Pickup <./multigrid.envs.blockedunlockpickup>`_ 13 | * ``MultiGrid-BlockedUnlockPickup-v0`` 14 | * `Empty <./multigrid.envs.empty>`_ 15 | * ``MultiGrid-Empty-5x5-v0`` 16 | * ``MultiGrid-Empty-Random-5x5-v0`` 17 | * ``MultiGrid-Empty-6x6-v0`` 18 | * ``MultiGrid-Empty-Random-6x6-v0`` 19 | * ``MultiGrid-Empty-8x8-v0`` 20 | * ``MultiGrid-Empty-16x16-v0`` 21 | * `Locked Hallway <./multigrid.envs.locked_hallway>`_ 22 | * ``MultiGrid-LockedHallway-2Rooms-v0`` 23 | * ``MultiGrid-LockedHallway-4Rooms-v0`` 24 | * ``MultiGrid-LockedHallway-6Rooms-v0`` 25 | * `Playground <./multigrid.envs.playground>`_ 26 | * ``MultiGrid-Playground-v0`` 27 | * `Red Blue Doors <./multigrid.envs.redbluedoors>`_ 28 | * ``MultiGrid-RedBlueDoors-6x6-v0`` 29 | * ``MultiGrid-RedBlueDoors-8x8-v0`` 30 | """ 31 | 32 | from .blockedunlockpickup import BlockedUnlockPickupEnv 33 | from .empty import EmptyEnv 34 | from .locked_hallway import LockedHallwayEnv 35 | from .playground import PlaygroundEnv 36 | from .redbluedoors import RedBlueDoorsEnv 37 | 38 | CONFIGURATIONS = { 39 | 'MultiGrid-BlockedUnlockPickup-v0': (BlockedUnlockPickupEnv, {}), 40 | 'MultiGrid-Empty-5x5-v0': (EmptyEnv, {'size': 5}), 41 | 'MultiGrid-Empty-Random-5x5-v0': (EmptyEnv, {'size': 5, 'agent_start_pos': None}), 42 | 'MultiGrid-Empty-6x6-v0': (EmptyEnv, {'size': 6}), 43 | 'MultiGrid-Empty-Random-6x6-v0': (EmptyEnv, {'size': 6, 'agent_start_pos': None}), 44 | 'MultiGrid-Empty-8x8-v0': (EmptyEnv, {}), 45 | 'MultiGrid-Empty-16x16-v0': (EmptyEnv, {'size': 16}), 46 | 'MultiGrid-LockedHallway-2Rooms-v0': (LockedHallwayEnv, {'num_rooms': 2}), 47 | 'MultiGrid-LockedHallway-4Rooms-v0': (LockedHallwayEnv, {'num_rooms': 4}), 48 | 'MultiGrid-LockedHallway-6Rooms-v0': (LockedHallwayEnv, {'num_rooms': 6}), 49 | 'MultiGrid-Playground-v0': (PlaygroundEnv, {}), 50 | 'MultiGrid-RedBlueDoors-6x6-v0': (RedBlueDoorsEnv, {'size': 6}), 51 | 'MultiGrid-RedBlueDoors-8x8-v0': (RedBlueDoorsEnv, {'size': 8}), 52 | } 53 | 54 | # Register environments with gymnasium 55 | from gymnasium.envs.registration import register 56 | for name, (env_cls, config) in CONFIGURATIONS.items(): 57 | register(id=name, entry_point=env_cls, kwargs=config) 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Package ###################################################################### 2 | 3 | [build-system] 4 | requires = ["setuptools >= 61.0.0"] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [project] 8 | name = "multigrid" 9 | description = "Fast multi-agent gridworld reinforcement learning environments." 
10 | readme = "README.md" 11 | requires-python = ">= 3.9" 12 | authors = [{ name = "Ini Oguntola", email = "ini@ini.io" }] 13 | license = { text = "Apache License" } 14 | keywords = ["Memory", "Environment", "Agent", "Multi-Agent", "RL", "Gymnasium", "Cooperative", "Competitive"] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", # change to `5 - Production/Stable` when ready 17 | "License :: OSI Approved :: Apache Software License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | 'Intended Audience :: Science/Research', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | ] 25 | dependencies = [ 26 | "aenum>=1.3.0", 27 | "numba>=0.53.0", 28 | "numpy>=1.18.0", 29 | "gymnasium>=0.26", 30 | "pygame>=2.2.0", 31 | ] 32 | dynamic = ["version"] 33 | 34 | [tool.setuptools] 35 | include-package-data = true 36 | 37 | [tool.setuptools.packages.find] 38 | include = ["multigrid*"] 39 | 40 | # Linters and Test tools ####################################################### 41 | 42 | [tool.black] 43 | safe = true 44 | 45 | [tool.isort] 46 | atomic = true 47 | profile = "black" 48 | append_only = true 49 | src_paths = ["multigrid"] 50 | add_imports = [ "from __future__ import annotations" ] 51 | 52 | [tool.pyright] 53 | include = [ 54 | "multigrid/**", 55 | ] 56 | 57 | exclude = [ 58 | "**/node_modules", 59 | "**/__pycache__", 60 | ] 61 | 62 | strict = [] 63 | 64 | typeCheckingMode = "basic" 65 | pythonVersion = "3.9" 66 | typeshedPath = "typeshed" 67 | enableTypeIgnoreComments = true 68 | 69 | # This is required as the CI pre-commit does not download the module (i.e. numpy) 70 | # Therefore, we have to ignore missing imports 71 | reportMissingImports = "none" 72 | 73 | reportUnknownMemberType = "none" 74 | reportUnknownParameterType = "none" 75 | reportUnknownVariableType = "none" 76 | reportUnknownArgumentType = "none" 77 | reportPrivateUsage = "warning" 78 | reportUntypedFunctionDecorator = "none" 79 | reportMissingTypeStubs = false 80 | reportUnboundVariable = "warning" 81 | reportGeneralTypeIssues = "none" 82 | reportPrivateImportUsage = "none" 83 | -------------------------------------------------------------------------------- /multigrid/utils/random.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Iterable, TypeVar 3 | from ..core.constants import Color 4 | 5 | T = TypeVar('T') 6 | 7 | 8 | 9 | class RandomMixin: 10 | """ 11 | Mixin class for random number generation. 12 | """ 13 | 14 | def __init__(self, random_generator: np.random.Generator): 15 | """ 16 | Parameters 17 | ---------- 18 | random_generator : np.random.Generator 19 | Random number generator 20 | """ 21 | self.__np_random = random_generator 22 | 23 | def _rand_int(self, low: int, high: int) -> int: 24 | """ 25 | Generate random integer in range [low, high). 26 | 27 | :meta public: 28 | """ 29 | return self.__np_random.integers(low, high) 30 | 31 | def _rand_float(self, low: float, high: float) -> float: 32 | """ 33 | Generate random float in range [low, high). 34 | 35 | :meta public: 36 | """ 37 | return self.__np_random.uniform(low, high) 38 | 39 | def _rand_bool(self) -> bool: 40 | """ 41 | Generate random boolean value. 42 | 43 | :meta public: 44 | """ 45 | return self.__np_random.integers(0, 2) == 0 46 | 47 | def _rand_elem(self, iterable: Iterable[T]) -> T: 48 | """ 49 | Pick a random element from a list. 
50 | 51 | :meta public: 52 | """ 53 | lst = list(iterable) 54 | idx = self._rand_int(0, len(lst)) 55 | return lst[idx] 56 | 57 | def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]: 58 | """ 59 | Sample a random subset of distinct elements of a list. 60 | 61 | :meta public: 62 | """ 63 | lst = list(iterable) 64 | assert num_elems <= len(lst) 65 | 66 | out: list[T] = [] 67 | 68 | while len(out) < num_elems: 69 | elem = self._rand_elem(lst) 70 | lst.remove(elem) 71 | out.append(elem) 72 | 73 | return out 74 | 75 | def _rand_perm(self, iterable: Iterable[T]) -> list[T]: 76 | """ 77 | Randomly permute a list. 78 | 79 | :meta public: 80 | """ 81 | lst = list(iterable) 82 | self.__np_random.shuffle(lst) 83 | return lst 84 | 85 | def _rand_color(self) -> Color: 86 | """ 87 | Generate a random color. 88 | 89 | :meta public: 90 | """ 91 | return self._rand_elem(Color) 92 | 93 | def _rand_pos( 94 | self, x_low: int, x_high: int, y_low: int, y_high: int) -> tuple[int, int]: 95 | """ 96 | Generate a random (x, y) position tuple. 97 | 98 | :meta public: 99 | """ 100 | return ( 101 | self.__np_random.integers(x_low, x_high), 102 | self.__np_random.integers(y_low, y_high), 103 | ) 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiGrid 2 | 3 |
4 | [image]
5 | Blocked Unlock Pickup: 2 Agents
6 |
7 |
8 | 9 | The **MultiGrid** library provides a collection of fast multi-agent discrete gridworld environments for reinforcement learning in [Gymnasium](https://github.com/Farama-Foundation/Gymnasium). This is a multi-agent extension of the [minigrid](https://github.com/Farama-Foundation/Minigrid) library, and the interface is designed to be as similar as possible. 10 | 11 | The environments are designed to be fast and easily customizable. Compared to minigrid, the underlying gridworld logic is **significantly optimized**, with environment simulation 10x to 20x faster in our benchmarks. 12 | 13 | Documentation for this library can be found at [ini.github.io/docs/multigrid](https://ini.github.io/docs/multigrid). 14 | 15 | ## Installation 16 | 17 | pip install multigrid 18 | 19 | Alternatively, for an editable install: 20 | 21 | git clone https://github.com/ini/multigrid 22 | cd multigrid 23 | pip install -e . 24 | 25 | This package requires Python 3.9 or later. 26 | 27 | ## Environments 28 | 29 | The `multigrid.envs` package provides implementations of several multi-agent environments. [You can find the full list here](https://ini.github.io/docs/multigrid/multigrid/multigrid.envs). 30 | 31 | ## API 32 | 33 | MultiGrid follows the same pattern as RLlib's [MultiAgentEnv API](https://docs.ray.io/en/latest/rllib/rllib-env.html#multi-agent-and-hierarchical) and PettingZoo's [ParallelEnv API](https://pettingzoo.farama.org/api/parallel/). 34 | 35 | ```python 36 | import gymnasium as gym 37 | import multigrid.envs 38 | 39 | env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human') 40 | 41 | observations, infos = env.reset() 42 | while not env.unwrapped.is_done(): 43 | # this is where you would insert your policy / policies 44 | actions = {agent.index: agent.action_space.sample() for agent in env.unwrapped.agents} 45 | observations, rewards, terminations, truncations, infos = env.step(actions) 46 | 47 | env.close() 48 | ``` 49 | 50 | More information about using MultiGrid directly with other APIs (see also the short PettingZoo sketch at the end of this README): 51 | * [PettingZoo](https://ini.github.io/docs/multigrid/multigrid/multigrid.pettingzoo) 52 | * [RLlib](https://ini.github.io/docs/multigrid/multigrid/multigrid.rllib) 53 | 54 | ## Training Agents 55 | 56 | See the [scripts folder](./scripts) for an example of training agents with RLlib. 57 | 58 | ## Documentation 59 | 60 | Documentation for this package can be found at [ini.github.io/docs/multigrid](https://ini.github.io/docs/multigrid). 
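A minimal sketch of the PettingZoo route mentioned above (mirroring the usage examples in the `multigrid.pettingzoo` module docstring; the random-action loop is illustrative only):

```python
import gymnasium as gym
import multigrid.envs

from multigrid.pettingzoo import PettingZooWrapper

# Wrap a MultiGrid environment instance as a PettingZoo ParallelEnv
env = PettingZooWrapper(gym.make('MultiGrid-Empty-8x8-v0', agents=2))

observations, infos = env.reset()

# Sample a random action for each live agent and step the environment
actions = {agent_id: env.action_space(agent_id).sample() for agent_id in env.agents}
observations, rewards, terminations, truncations, infos = env.step(actions)

env.close()
```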
61 | 62 | ## Citation 63 | 64 | To cite this project please use: 65 | 66 | ``` 67 | @article{oguntola2023theory, 68 | title={Theory of mind as intrinsic motivation for multi-agent reinforcement learning}, 69 | author={Oguntola, Ini and Campbell, Joseph and Stepputtis, Simon and Sycara, Katia}, 70 | journal={arXiv preprint arXiv:2307.01158}, 71 | year={2023} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /multigrid/core/constants.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import numpy as np 3 | 4 | from numpy.typing import NDArray as ndarray 5 | from ..utils.enum import IndexedEnum 6 | 7 | 8 | 9 | #: Tile size for rendering grid cell 10 | TILE_PIXELS = 32 11 | 12 | COLORS = { 13 | 'red': np.array([255, 0, 0]), 14 | 'green': np.array([0, 255, 0]), 15 | 'blue': np.array([0, 0, 255]), 16 | 'purple': np.array([112, 39, 195]), 17 | 'yellow': np.array([255, 255, 0]), 18 | 'grey': np.array([100, 100, 100]), 19 | } 20 | 21 | DIR_TO_VEC = [ 22 | # Pointing right (positive X) 23 | np.array((1, 0)), 24 | # Down (positive Y) 25 | np.array((0, 1)), 26 | # Pointing left (negative X) 27 | np.array((-1, 0)), 28 | # Up (negative Y) 29 | np.array((0, -1)), 30 | ] 31 | 32 | 33 | 34 | class Type(str, IndexedEnum): 35 | """ 36 | Enumeration of object types. 37 | """ 38 | unseen = 'unseen' 39 | empty = 'empty' 40 | wall = 'wall' 41 | floor = 'floor' 42 | door = 'door' 43 | key = 'key' 44 | ball = 'ball' 45 | box = 'box' 46 | goal = 'goal' 47 | lava = 'lava' 48 | agent = 'agent' 49 | 50 | 51 | class Color(str, IndexedEnum): 52 | """ 53 | Enumeration of object colors. 54 | """ 55 | red = 'red' 56 | green = 'green' 57 | blue = 'blue' 58 | purple = 'purple' 59 | yellow = 'yellow' 60 | grey = 'grey' 61 | 62 | @classmethod 63 | def add_color(cls, name: str, rgb: ndarray[np.uint8]): 64 | """ 65 | Add a new color to the ``Color`` enumeration. 66 | 67 | Parameters 68 | ---------- 69 | name : str 70 | Name of the new color 71 | rgb : ndarray[np.uint8] of shape (3,) 72 | RGB value of the new color 73 | """ 74 | cls.add_item(name, name) 75 | COLORS[name] = np.asarray(rgb, dtype=np.uint8) 76 | 77 | @staticmethod 78 | def cycle(n: int) -> tuple['Color', ...]: 79 | """ 80 | Return a cycle of ``n`` colors. 81 | """ 82 | return tuple(Color.from_index(i % len(Color)) for i in range(int(n))) 83 | 84 | def rgb(self) -> ndarray[np.uint8]: 85 | """ 86 | Return the RGB value of this ``Color``. 87 | """ 88 | return COLORS[self] 89 | 90 | 91 | class State(str, IndexedEnum): 92 | """ 93 | Enumeration of object states. 94 | """ 95 | open = 'open' 96 | closed = 'closed' 97 | locked = 'locked' 98 | 99 | 100 | class Direction(enum.IntEnum): 101 | """ 102 | Enumeration of agent directions. 103 | """ 104 | right = 0 105 | down = 1 106 | left = 2 107 | up = 3 108 | 109 | def to_vec(self) -> ndarray[np.int8]: 110 | """ 111 | Return the vector corresponding to this ``Direction``. 
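For example, ``Direction.right.to_vec()`` gives ``array([1, 0])`` and ``Direction.up.to_vec()`` gives ``array([0, -1])``, per the ``DIR_TO_VEC`` table above.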
112 | """ 113 | return DIR_TO_VEC[self] 114 | 115 | 116 | 117 | ### Minigrid Compatibility 118 | 119 | OBJECT_TO_IDX = {t: t.to_index() for t in Type} 120 | IDX_TO_OBJECT = {t.to_index(): t for t in Type} 121 | COLOR_TO_IDX = {c: c.to_index() for c in Color} 122 | IDX_TO_COLOR = {c.to_index(): c for c in Color} 123 | STATE_TO_IDX = {s: s.to_index() for s in State} 124 | COLOR_NAMES = sorted(list(Color)) 125 | -------------------------------------------------------------------------------- /multigrid/pettingzoo/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package provides tools for using MultiGrid environments with 3 | the PettingZoo ParallelEnv API. 4 | 5 | ***** 6 | Usage 7 | ***** 8 | 9 | Wrap an environment instance with :class:`.PettingZooWrapper`: 10 | 11 | >>> import gymnasium as gym 12 | >>> import multigrid.envs 13 | >>> env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human') 14 | 15 | >>> from multigrid.pettingzoo import PettingZooWrapper 16 | >>> env = PettingZooWrapper(env) 17 | 18 | Wrap an environment class with :func:`.to_pettingzoo_env()`: 19 | 20 | >>> from multigrid.envs import EmptyEnv 21 | >>> from multigrid.pettingzoo import to_pettingzoo_env 22 | >>> PZEnv = to_pettingzoo_env(EmptyEnv, metadata={'name': 'empty_v0'}) 23 | >>> env = PZEnv(agents=2, render_mode='human') 24 | """ 25 | 26 | from __future__ import annotations 27 | 28 | import gymnasium as gym 29 | 30 | from gymnasium import spaces 31 | from pettingzoo import ParallelEnv 32 | from typing import Any 33 | 34 | from ..base import AgentID, MultiGridEnv 35 | 36 | 37 | 38 | class PettingZooWrapper(ParallelEnv): 39 | """ 40 | Wrapper for a ``MultiGridEnv`` environment that implements the 41 | PettingZoo ``ParallelEnv`` interface. 42 | """ 43 | 44 | def __init__(self, env: MultiGridEnv): 45 | self.env = env 46 | self.reset = self.env.reset 47 | self.step = self.env.step 48 | self.render = self.env.render 49 | self.close = self.env.close 50 | self.metadata = {} 51 | 52 | @property 53 | def agents(self) -> list[AgentID]: 54 | if self.env.is_done(): 55 | return [] 56 | return [agent.index for agent in self.env.agents if not agent.terminated] 57 | 58 | @property 59 | def possible_agents(self) -> list[AgentID]: 60 | return [agent.index for agent in self.env.agents] 61 | 62 | @property 63 | def observation_spaces(self) -> dict[AgentID, spaces.Space]: 64 | return dict(self.env.observation_space) 65 | 66 | @property 67 | def action_spaces(self) -> dict[AgentID, spaces.Space]: 68 | return dict(self.env.action_space) 69 | 70 | @property 71 | def render_mode(self) -> str | None: 72 | return self.env.render_mode 73 | 74 | def observation_space(self, agent_id: AgentID) -> spaces.Space: 75 | return self.env.observation_space[agent_id] 76 | 77 | def action_space(self, agent_id: AgentID) -> spaces.Space: 78 | return self.env.action_space[agent_id] 79 | 80 | 81 | 82 | def to_pettingzoo_env( 83 | env_cls: type[MultiGridEnv], 84 | *wrappers: gym.Wrapper, 85 | metadata: dict[str, Any] = {}) -> type[ParallelEnv]: 86 | """ 87 | Convert a ``MultiGridEnv`` environment class to a PettingZoo ``ParallelEnv`` class. 88 | 89 | Note that this is a wrapper around the environment **class**, 90 | not environment instances. 
91 | 92 | Parameters 93 | ---------- 94 | env_cls : type[MultiGridEnv] 95 | ``MultiGridEnv`` environment class 96 | wrappers : gym.Wrapper 97 | Gym wrappers to apply to the environment 98 | metadata : dict[str, Any] 99 | Environment metadata 100 | 101 | Returns 102 | ------- 103 | pettingzoo_env_cls : type[ParallelEnv] 104 | PettingZoo ``ParallelEnv`` environment class 105 | """ 106 | class PettingZooEnv(PettingZooWrapper): 107 | def __init__(self, *args, **kwargs): 108 | env = env_cls(*args, **kwargs) 109 | for wrapper in wrappers: 110 | env = wrapper(env) 111 | super().__init__(env) 112 | 113 | PettingZooEnv.__name__ = f"PettingZoo_{env_cls.__name__}" 114 | PettingZooEnv.metadata = metadata 115 | return PettingZooEnv 116 | -------------------------------------------------------------------------------- /multigrid/rllib/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package provides tools for using MultiGrid environments with 3 | the RLlib MultiAgentEnv API. 4 | 5 | ***** 6 | Usage 7 | ***** 8 | 9 | Use a specific environment configuration from :mod:`multigrid.envs` by name: 10 | 11 | >>> import multigrid.rllib # registers environment configurations with RLlib 12 | >>> from ray.rllib.algorithms.ppo import PPOConfig 13 | >>> algorithm_config = PPOConfig().environment(env='MultiGrid-Empty-8x8-v0') 14 | 15 | Wrap an environment instance with :class:`.RLlibWrapper`: 16 | 17 | >>> import gymnasium as gym 18 | >>> import multigrid.envs 19 | >>> env = gym.make('MultiGrid-Empty-8x8-v0', agents=2, render_mode='human') 20 | 21 | >>> from multigrid.rllib import RLlibWrapper 22 | >>> env = RLlibWrapper(env) 23 | 24 | Wrap an environment class with :func:`.to_rllib_env()`: 25 | 26 | >>> from multigrid.envs import EmptyEnv 27 | >>> from multigrid.rllib import to_rllib_env 28 | >>> MyEnv = to_rllib_env(EmptyEnv, default_config={'size': 8}) 29 | >>> config = {'agents': 2, 'render_mode': 'human'} 30 | >>> env = MyEnv(config) 31 | """ 32 | 33 | import gymnasium as gym 34 | 35 | from ray.rllib.env import MultiAgentEnv 36 | from ray.tune.registry import register_env 37 | 38 | from ..base import MultiGridEnv 39 | from ..envs import CONFIGURATIONS 40 | from ..wrappers import OneHotObsWrapper 41 | 42 | 43 | 44 | class RLlibWrapper(MultiAgentEnv): 45 | """ 46 | Wrapper for a ``MultiGridEnv`` environment that implements the 47 | RLlib ``MultiAgentEnv`` interface. 48 | """ 49 | 50 | def __init__(self, env: MultiGridEnv): 51 | super().__init__() 52 | self.env = env 53 | self.agents = list(range(len(env.unwrapped.agents))) 54 | self.possible_agents = self.agents[:] 55 | 56 | def reset(self, *args, **kwargs): 57 | return self.env.reset(*args, **kwargs) 58 | 59 | def step(self, *args, **kwargs): 60 | obs, rewards, terminations, truncations, infos = self.env.step(*args, **kwargs) 61 | terminations['__all__'] = all(terminations.values()) 62 | truncations['__all__'] = all(truncations.values()) 63 | return obs, rewards, terminations, truncations, infos 64 | 65 | def get_observation_space(self, agent_index: int): 66 | return self.env.unwrapped.agents[agent_index].observation_space 67 | 68 | def get_action_space(self, agent_index: int): 69 | return self.env.unwrapped.agents[agent_index].action_space 70 | 71 | 72 | def to_rllib_env( 73 | env_cls: type[MultiGridEnv], 74 | *wrappers: gym.Wrapper, 75 | default_config: dict = {}) -> type[MultiAgentEnv]: 76 | """ 77 | Convert a ``MultiGridEnv`` environment class to an RLlib ``MultiAgentEnv`` class. 
78 | 79 | Note that this is a wrapper around the environment **class**, 80 | not environment instances. 81 | 82 | Parameters 83 | ---------- 84 | env_cls : type[MultiGridEnv] 85 | ``MultiGridEnv`` environment class 86 | wrappers : gym.Wrapper 87 | Gym wrappers to apply to the environment 88 | default_config : dict 89 | Default configuration for the environment 90 | 91 | Returns 92 | ------- 93 | rllib_env_cls : type[MultiAgentEnv] 94 | RLlib ``MultiAgentEnv`` environment class 95 | """ 96 | class RLlibEnv(RLlibWrapper): 97 | def __init__(self, config: dict = {}): 98 | config = {**default_config, **config} 99 | env = env_cls(**config) 100 | for wrapper in wrappers: 101 | env = wrapper(env) 102 | super().__init__(env) 103 | 104 | RLlibEnv.__name__ = f"RLlib_{env_cls.__name__}" 105 | return RLlibEnv 106 | 107 | 108 | 109 | # Register environments with RLlib 110 | for name, (env_cls, config) in CONFIGURATIONS.items(): 111 | register_env(name, to_rllib_env(env_cls, OneHotObsWrapper, default_config=config)) 112 | -------------------------------------------------------------------------------- /multigrid/core/mission.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | from gymnasium import spaces 5 | from typing import Any, Callable, Iterable, Sequence 6 | 7 | 8 | 9 | class Mission(np.ndarray): 10 | """ 11 | Class representing an agent mission. 12 | """ 13 | 14 | def __new__(cls, string: str, index: Iterable[int] | None = None): 15 | """ 16 | Parameters 17 | ---------- 18 | string : str 19 | Mission string 20 | index : Iterable[int] 21 | Index of mission string in :class:`MissionSpace` 22 | """ 23 | mission = np.array(0 if index is None else index) 24 | mission = mission.view(cls) 25 | mission.string = string 26 | return mission.view(cls) 27 | 28 | def __array_finalize__(self, mission): 29 | if mission is None: return 30 | self.string = getattr(mission, 'string', None) 31 | 32 | def __str__(self) -> str: 33 | return self.string 34 | 35 | def __repr__(self) -> str: 36 | return f'{self.__class__.__name__}("{self.string}")' 37 | 38 | def __eq__(self, value: object) -> bool: 39 | return self.string == str(value) 40 | 41 | def __hash__(self) -> int: 42 | return hash(self.string) 43 | 44 | 45 | class MissionSpace(spaces.MultiDiscrete): 46 | """ 47 | Class representing a space over agent missions. 48 | 49 | Examples 50 | -------- 51 | >>> observation_space = MissionSpace( 52 | ... mission_func=lambda color: f"Get the {color} ball.", 53 | ... 
ordered_placeholders=[["green", "blue"]]) 54 | >>> observation_space.seed(123) 55 | >>> observation_space.sample() 56 | Mission("Get the blue ball.") 57 | 58 | >>> observation_space = MissionSpace.from_string("Get the ball.") 59 | >>> observation_space.sample() 60 | Mission("Get the ball.") 61 | """ 62 | 63 | def __init__( 64 | self, 65 | mission_func: Callable[..., str], 66 | ordered_placeholders: Sequence[Sequence[str]] = [], 67 | seed : int | np.random.Generator | None = None): 68 | """ 69 | Parameters 70 | ---------- 71 | mission_func : Callable(*args) -> str 72 | Deterministic function that generates a mission string 73 | ordered_placeholders : Sequence[Sequence[str]] 74 | Sequence of argument groups, ordered by placing order in ``mission_func()`` 75 | seed : int or np.random.Generator or None 76 | Seed for random sampling from the space 77 | """ 78 | self.mission_func = mission_func 79 | self.arg_groups = ordered_placeholders 80 | nvec = tuple(len(group) for group in self.arg_groups) 81 | super().__init__(nvec=nvec if nvec else (1,)) 82 | 83 | def __repr__(self) -> str: 84 | """ 85 | Get a string representation of this space. 86 | """ 87 | if self.arg_groups: 88 | return f'MissionSpace({self.mission_func.__name__}, {self.arg_groups})' 89 | return f"MissionSpace('{self.mission_func()}')" 90 | 91 | def get(self, idx: Iterable[int]) -> Mission: 92 | """ 93 | Get the mission string corresponding to the given index. 94 | 95 | Parameters 96 | ---------- 97 | idx : Iterable[int] 98 | Index of desired argument in each argument group 99 | """ 100 | if self.arg_groups: 101 | args = (self.arg_groups[axis][index] for axis, index in enumerate(idx)) 102 | return Mission(string=self.mission_func(*args), index=idx) 103 | return Mission(string=self.mission_func()) 104 | 105 | def sample(self) -> Mission: 106 | """ 107 | Sample a random mission string. 108 | """ 109 | idx = super().sample() 110 | return self.get(idx) 111 | 112 | def contains(self, x: Any) -> bool: 113 | """ 114 | Check if an item is a valid member of this mission space. 115 | 116 | Parameters 117 | ---------- 118 | x : Any 119 | Item to check 120 | """ 121 | for idx in np.ndindex(tuple(self.nvec)): 122 | if self.get(idx) == x: 123 | return True 124 | return False 125 | 126 | @staticmethod 127 | def from_string(string: str) -> MissionSpace: 128 | """ 129 | Create a mission space containing a single mission string. 130 | 131 | Parameters 132 | ---------- 133 | string : str 134 | Mission string 135 | """ 136 | return MissionSpace(mission_func=lambda: string) 137 | -------------------------------------------------------------------------------- /multigrid/envs/playground.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from multigrid.core.mission import MissionSpace 4 | from multigrid.core.roomgrid import RoomGrid 5 | 6 | 7 | 8 | class PlaygroundEnv(RoomGrid): 9 | """ 10 | .. image:: https://i.imgur.com/QBz99Vh.gif 11 | :width: 380 12 | 13 | *********** 14 | Description 15 | *********** 16 | 17 | Environment with multiple rooms and random objects. 18 | This environment has no specific goals or rewards. 19 | 20 | ************* 21 | Mission Space 22 | ************* 23 | 24 | None 25 | 26 | ***************** 27 | Observation Space 28 | ***************** 29 | 30 | The multi-agent observation space is a Dict mapping from agent index to 31 | corresponding agent observation space. 
32 | 33 | Each agent observation is a dictionary with the following entries: 34 | 35 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 36 | Encoding of the agent's partially observable view of the environment, 37 | where the object at each grid cell is encoded as a vector: 38 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 39 | * direction : int 40 | Agent's direction (0: right, 1: down, 2: left, 3: up) 41 | * mission : Mission 42 | Task string corresponding to the current environment configuration 43 | 44 | ************ 45 | Action Space 46 | ************ 47 | 48 | The multi-agent action space is a Dict mapping from agent index to 49 | corresponding agent action space. 50 | 51 | Agent actions are discrete integer values, given by: 52 | 53 | +-----+--------------+-----------------------------+ 54 | | Num | Name | Action | 55 | +=====+==============+=============================+ 56 | | 0 | left | Turn left | 57 | +-----+--------------+-----------------------------+ 58 | | 1 | right | Turn right | 59 | +-----+--------------+-----------------------------+ 60 | | 2 | forward | Move forward | 61 | +-----+--------------+-----------------------------+ 62 | | 3 | pickup | Pick up an object | 63 | +-----+--------------+-----------------------------+ 64 | | 4 | drop | Drop an object | 65 | +-----+--------------+-----------------------------+ 66 | | 5 | toggle | Toggle / activate an object | 67 | +-----+--------------+-----------------------------+ 68 | | 6 | done | Done completing task | 69 | +-----+--------------+-----------------------------+ 70 | 71 | ******* 72 | Rewards 73 | ******* 74 | 75 | None 76 | 77 | *********** 78 | Termination 79 | *********** 80 | 81 | The episode ends when the following condition is met: 82 | 83 | * Timeout (see ``max_steps``) 84 | 85 | ************************* 86 | Registered Configurations 87 | ************************* 88 | 89 | * ``MultiGrid-Playground-v0`` 90 | """ 91 | 92 | def __init__( 93 | self, 94 | room_size: int = 7, 95 | num_rows: int = 3, 96 | num_cols: int = 3, 97 | max_steps: int = 100, 98 | **kwargs): 99 | """ 100 | Parameters 101 | ---------- 102 | room_size : int, default=7 103 | Width and height for each of the rooms 104 | num_rows : int, default=3 105 | Number of rows of rooms 106 | num_cols : int, default=3 107 | Number of columns of rooms 108 | max_steps : int, default=100 109 | Maximum number of steps per episode 110 | **kwargs 111 | See :attr:`multigrid.base.MultiGridEnv.__init__` 112 | """ 113 | super().__init__( 114 | mission_space=MissionSpace.from_string(""), 115 | num_rows=num_rows, 116 | num_cols=num_cols, 117 | room_size=room_size, 118 | max_steps=max_steps, 119 | **kwargs, 120 | ) 121 | 122 | def _gen_grid(self, width, height): 123 | """ 124 | :meta private: 125 | """ 126 | super()._gen_grid(width, height) 127 | self.connect_all() 128 | 129 | # Place random objects in the world 130 | for i in range(0, 12): 131 | col = self._rand_int(0, self.num_cols) 132 | row = self._rand_int(0, self.num_rows) 133 | self.add_object(col, row) 134 | 135 | # Place agents 136 | for agent in self.agents: 137 | self.place_agent(agent) 138 | -------------------------------------------------------------------------------- /scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | 5 | from ray.rllib.utils.typing import AgentID 6 | from ray.rllib.utils.torch_utils import convert_to_torch_tensor 7 | from typing import 
Callable 8 | 9 | from train import get_algorithm_config, find_checkpoint_dir, get_policy_mapping_fn 10 | from ray.rllib.core.rl_module import RLModule 11 | 12 | 13 | 14 | def visualize( 15 | modules: dict[str, RLModule], 16 | policy_mapping_fn: Callable[[AgentID], str], 17 | num_episodes: int = 10) -> list[np.ndarray]: 18 | """ 19 | Visualize trajectories from trained agents. 20 | 21 | Parameters 22 | ---------- 23 | modules : dict[str, RLModule] 24 | Trained RLlib modules, keyed by policy ID 25 | policy_mapping_fn : Callable(AgentID) -> str 26 | Function mapping agent IDs to policy IDs 27 | num_episodes : int, default=10 28 | Number of episodes to visualize 29 | """ 30 | frames = [] 31 | env = algorithm.env_creator(algorithm.config.env_config)  # uses the global ``algorithm`` built in ``__main__`` below 32 | 33 | for episode in range(num_episodes): 34 | print() 35 | print('-' * 32, '\n', 'Episode', episode, '\n', '-' * 32) 36 | 37 | episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents} 38 | terminations, truncations = {'__all__': False}, {'__all__': False} 39 | observations, infos = env.reset() 40 | 41 | # Get initial states for each agent 42 | states = { 43 | agent_id: modules[policy_mapping_fn(agent_id)].get_initial_state() 44 | for agent_id in env.agents 45 | } 46 | 47 | while not terminations['__all__'] and not truncations['__all__']: 48 | # Store current frame 49 | frames.append(env.env.unwrapped.get_frame()) 50 | 51 | # Compute actions for each agent 52 | actions = {} 53 | observations = convert_to_torch_tensor(observations) 54 | for agent_id in env.agents: 55 | agent_module = modules[policy_mapping_fn(agent_id)] 56 | out = agent_module.forward_inference({'obs': observations[agent_id]}) 57 | logits = out['action_dist_inputs'] 58 | action_dist_class = agent_module.get_inference_action_dist_cls() 59 | action_dist = action_dist_class.from_logits(logits) 60 | actions[agent_id] = action_dist.sample().item() 61 | 62 | # Take actions in environment and accumulate rewards 63 | observations, rewards, terminations, truncations, infos = env.step(actions) 64 | for agent_id in rewards: 65 | episode_rewards[agent_id] += rewards[agent_id] 66 | 67 | frames.append(env.env.unwrapped.get_frame()) 68 | print('Rewards:', episode_rewards) 69 | 70 | env.close() 71 | return frames 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument( 78 | '--algo', type=str, default='PPO', 79 | help="The name of the RLlib-registered algorithm to use.") 80 | parser.add_argument( 81 | '--framework', type=str, choices=['torch', 'tf', 'tf2'], default='torch', 82 | help="Deep learning framework to use.") 83 | parser.add_argument( 84 | '--lstm', action='store_true', help="Use LSTM model.") 85 | parser.add_argument( 86 | '--env', type=str, default='MultiGrid-Empty-8x8-v0', 87 | help="MultiGrid environment to use.") 88 | parser.add_argument( 89 | '--env-config', type=json.loads, default={}, 90 | help="Environment config dict, given as a JSON string (e.g. 
'{\"size\": 8}')") 91 | parser.add_argument( 92 | '--num-agents', type=int, default=2, help="Number of agents in environment.") 93 | parser.add_argument( 94 | '--num-episodes', type=int, default=10, help="Number of episodes to visualize.") 95 | parser.add_argument( 96 | '--load-dir', type=str, 97 | help="Checkpoint directory for loading pre-trained policies.") 98 | parser.add_argument( 99 | '--gif', type=str, help="Store output as GIF at given path.") 100 | 101 | args = parser.parse_args() 102 | args.env_config.update(render_mode='human') 103 | config = get_algorithm_config( 104 | **vars(args), 105 | num_workers=0, 106 | num_gpus=0, 107 | ) 108 | algorithm = config.build() 109 | checkpoint = find_checkpoint_dir(args.load_dir) 110 | policy_mapping_fn = lambda agent_id, *args, **kwargs: f'policy_{agent_id}' 111 | 112 | if checkpoint: 113 | print(f"Loading checkpoint from {checkpoint}") 114 | path = checkpoint / 'learner_group' / 'learner' / 'rl_module/' 115 | modules = RLModule.from_checkpoint(path) 116 | policy_mapping_fn = get_policy_mapping_fn(checkpoint, args.num_agents) 117 | 118 | frames = visualize(modules, policy_mapping_fn, num_episodes=args.num_episodes) 119 | if args.gif: 120 | from array2gif import write_gif 121 | filename = args.gif if args.gif.endswith('.gif') else f'{args.gif}.gif' 122 | print(f"Saving GIF to {filename}") 123 | write_gif(np.array(frames), filename, fps=10) 124 | -------------------------------------------------------------------------------- /multigrid/envs/blockedunlockpickup.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from multigrid.core.constants import Color, Direction, Type 4 | from multigrid.core.mission import MissionSpace 5 | from multigrid.core.roomgrid import RoomGrid 6 | from multigrid.core.world_object import Ball 7 | 8 | 9 | 10 | class BlockedUnlockPickupEnv(RoomGrid): 11 | """ 12 | .. image:: https://i.imgur.com/uSFi059.gif 13 | :width: 275 14 | 15 | *********** 16 | Description 17 | *********** 18 | 19 | The objective is to pick up a box which is placed in another room, behind a 20 | locked door. The door is also blocked by a ball which must be moved before 21 | the door can be unlocked. Hence, agents must learn to move the ball, 22 | pick up the key, open the door and pick up the object in the other 23 | room. 24 | 25 | The standard setting is cooperative, where all agents receive the reward 26 | when the task is completed. 27 | 28 | ************* 29 | Mission Space 30 | ************* 31 | 32 | "pick up the ``{color}`` box" 33 | 34 | ``{color}`` is the color of the box. Can be any :class:`.Color`. 35 | 36 | ***************** 37 | Observation Space 38 | ***************** 39 | 40 | The multi-agent observation space is a Dict mapping from agent index to 41 | corresponding agent observation space. 
42 | 43 | Each agent observation is a dictionary with the following entries: 44 | 45 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 46 | Encoding of the agent's partially observable view of the environment, 47 | where the object at each grid cell is encoded as a vector: 48 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 49 | * direction : int 50 | Agent's direction (0: right, 1: down, 2: left, 3: up) 51 | * mission : Mission 52 | Task string corresponding to the current environment configuration 53 | 54 | ************ 55 | Action Space 56 | ************ 57 | 58 | The multi-agent action space is a Dict mapping from agent index to 59 | corresponding agent action space. 60 | 61 | Agent actions are discrete integer values, given by: 62 | 63 | +-----+--------------+-----------------------------+ 64 | | Num | Name | Action | 65 | +=====+==============+=============================+ 66 | | 0 | left | Turn left | 67 | +-----+--------------+-----------------------------+ 68 | | 1 | right | Turn right | 69 | +-----+--------------+-----------------------------+ 70 | | 2 | forward | Move forward | 71 | +-----+--------------+-----------------------------+ 72 | | 3 | pickup | Pick up an object | 73 | +-----+--------------+-----------------------------+ 74 | | 4 | drop | Drop an object | 75 | +-----+--------------+-----------------------------+ 76 | | 5 | toggle | Toggle / activate an object | 77 | +-----+--------------+-----------------------------+ 78 | | 6 | done | Done completing task | 79 | +-----+--------------+-----------------------------+ 80 | 81 | ******* 82 | Rewards 83 | ******* 84 | 85 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success, 86 | and ``0`` for failure. 87 | 88 | *********** 89 | Termination 90 | *********** 91 | 92 | The episode ends if any one of the following conditions is met: 93 | 94 | * Any agent picks up the correct box 95 | * Timeout (see ``max_steps``) 96 | 97 | ************************* 98 | Registered Configurations 99 | ************************* 100 | 101 | * ``MultiGrid-BlockedUnlockPickup-v0`` 102 | """ 103 | 104 | def __init__( 105 | self, 106 | room_size: int = 6, 107 | max_steps: int | None = None, 108 | joint_reward: bool = True, 109 | **kwargs): 110 | """ 111 | Parameters 112 | ---------- 113 | room_size : int, default=6 114 | Width and height for each of the two rooms 115 | max_steps : int, optional 116 | Maximum number of steps per episode 117 | joint_reward : bool, default=True 118 | Whether all agents receive the reward when the task is completed 119 | **kwargs 120 | See :attr:`multigrid.base.MultiGridEnv.__init__` 121 | """ 122 | assert room_size >= 4 123 | mission_space = MissionSpace( 124 | mission_func=self._gen_mission, 125 | ordered_placeholders=[list(Color), [Type.box, Type.key]], 126 | ) 127 | super().__init__( 128 | mission_space=mission_space, 129 | num_rows=1, 130 | num_cols=2, 131 | room_size=room_size, 132 | max_steps=max_steps or (16 * room_size**2), 133 | joint_reward=joint_reward, 134 | success_termination_mode='any', 135 | **kwargs, 136 | ) 137 | 138 | @staticmethod 139 | def _gen_mission(color: str, obj_type: str): 140 | return f"pick up the {color} {obj_type}" 141 | 142 | def _gen_grid(self, width, height): 143 | """ 144 | :meta private: 145 | """ 146 | super()._gen_grid(width, height) 147 | 148 | # Add a box to the room on the right 149 | self.obj, _ = self.add_object(1, 0, kind=Type.box) 150 | 151 | # Make sure the two rooms are directly connected by a locked door 152 | door, 
pos = self.add_door(0, 0, Direction.right, locked=True) 153 | 154 | # Block the door with a ball 155 | self.grid.set(pos[0] - 1, pos[1], Ball(color=self._rand_color())) 156 | 157 | # Add a key to unlock the door 158 | self.add_object(0, 0, Type.key, door.color) 159 | 160 | # Place agents in the left room 161 | for agent in self.agents: 162 | self.place_agent(agent, 0, 0) 163 | 164 | self.mission = f"pick up the {self.obj.color} {self.obj.type}" 165 | 166 | def step(self, actions): 167 | """ 168 | :meta private: 169 | """ 170 | obs, reward, terminated, truncated, info = super().step(actions) 171 | for agent in self.agents: 172 | if agent.state.carrying == self.obj: 173 | self.on_success(agent, reward, terminated) 174 | 175 | return obs, reward, terminated, truncated, info 176 | -------------------------------------------------------------------------------- /multigrid/utils/minigrid_interface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gymnasium import spaces 4 | from gymnasium.core import ActType, ObsType 5 | from typing import Any, Sequence, SupportsFloat 6 | 7 | from ..core.world_object import WorldObj 8 | from ..base import MultiGridEnv 9 | 10 | 11 | 12 | class MiniGridInterface(MultiGridEnv): 13 | """ 14 | MultiGridEnv interface for compatibility with single-agent MiniGrid environments. 15 | 16 | Most environment implementations deriving from `minigrid.MiniGridEnv` can be 17 | converted to a single-agent `MultiGridEnv` by simply inheriting from 18 | `MiniGridInterface` instead (along with using the multigrid grid and grid objects). 19 | 20 | Examples 21 | -------- 22 | Start with a single-agent minigrid environment: 23 | 24 | >>> from minigrid.core.world_object import Ball, Key, Door 25 | >>> from minigrid.core.grid import Grid 26 | >>> from minigrid import MiniGridEnv 27 | 28 | >>> class MyEnv(MiniGridEnv): 29 | >>> ... # existing class definition 30 | 31 | Now use multigrid imports, keeping the environment class definition the same: 32 | 33 | >>> from multigrid.core.world_object import Ball, Key, Door 34 | >>> from multigrid.core.grid import Grid 35 | >>> from multigrid.utils.minigrid_interface import MiniGridInterface as MiniGridEnv 36 | 37 | >>> class MyEnv(MiniGridEnv): 38 | >>> ... # same class definition 39 | """ 40 | 41 | def reset(self, *args, **kwargs) -> tuple[ObsType, dict[str, Any]]: 42 | """ 43 | Reset the environment. 44 | """ 45 | result = super().reset(*args, **kwargs) 46 | return tuple(item[0] for item in result) 47 | 48 | def step(self, action: ActType) -> tuple[ 49 | ObsType, 50 | SupportsFloat, 51 | bool, 52 | bool, 53 | dict[str, Any]]: 54 | """ 55 | Run one timestep of the environment’s dynamics 56 | using the provided agent action. 57 | """ 58 | result = super().step({0: action}) 59 | return tuple(item[0] for item in result) 60 | 61 | @property 62 | def action_space(self) -> spaces.Space: 63 | """ 64 | Get action space. 65 | """ 66 | assert len(self.agents) == 1, ( 67 | "This property is not supported for multi-agent envs. " 68 | "Use `env.agents[i].action_space` instead." 69 | ) 70 | return self.agents[0].action_space 71 | 72 | @action_space.setter 73 | def action_space(self, space: spaces.Space): 74 | """ 75 | Set action space. 76 | """ 77 | assert len(self.agents) == 1, ( 78 | "This property is not supported for multi-agent envs. " 79 | "Use `env.agents[i].action_space` instead." 
80 | ) 81 | self.agents[0].action_space = space 82 | 83 | @property 84 | def observation_space(self) -> spaces.Space: 85 | """ 86 | Get observation space. 87 | """ 88 | assert len(self.agents) == 1, ( 89 | "This property is not supported for multi-agent envs. " 90 | "Use `env.agents[i].observation_space` instead." 91 | ) 92 | return self.agents[0].observation_space 93 | 94 | @observation_space.setter 95 | def observation_space(self, space: spaces.Space): 96 | """ 97 | Set observation space. 98 | """ 99 | assert len(self.agents) == 1, ( 100 | "This property is not supported for multi-agent envs. " 101 | "Use `env.agents[i].observation_space` instead." 102 | ) 103 | self.agents[0].observation_space = space 104 | 105 | @property 106 | def agent_pos(self) -> np.ndarray[int]: 107 | """ 108 | Get agent position. 109 | """ 110 | assert len(self.agents) == 1, ( 111 | "This property is not supported for multi-agent envs. " 112 | "Use `env.agents[i].pos` instead." 113 | ) 114 | return self.agents[0].pos 115 | 116 | @agent_pos.setter 117 | def agent_pos(self, value: Sequence[int]): 118 | """ 119 | Set agent position. 120 | """ 121 | assert len(self.agents) == 1, ( 122 | "This property is not supported for multi-agent envs. " 123 | "Use `env.agents[i].pos` instead." 124 | ) 125 | if value is not None: 126 | self.agents[0].pos = value 127 | 128 | @property 129 | def agent_dir(self) -> int: 130 | """ 131 | Get agent direction. 132 | """ 133 | assert len(self.agents) == 1, ( 134 | "This property is not supported for multi-agent envs. " 135 | "Use `env.agents[i].dir` instead." 136 | ) 137 | return self.agents[0].dir 138 | 139 | @agent_dir.setter 140 | def agent_dir(self, value: Sequence[int]): 141 | """ 142 | Set agent direction. 143 | """ 144 | assert len(self.agents) == 1, ( 145 | "This property is not supported for multi-agent envs. " 146 | "Use `env.agents[i].dir` instead." 147 | ) 148 | self.agents[0].dir = value 149 | 150 | @property 151 | def carrying(self) -> WorldObj: 152 | """ 153 | Get object carried by agent. 154 | """ 155 | assert len(self.agents) == 1, ( 156 | "This property is not supported for multi-agent envs. " 157 | "Use `env.agents[i].carrying` instead." 158 | ) 159 | return self.agents[0].carrying 160 | 161 | @property 162 | def dir_vec(self): 163 | """ 164 | Get the direction vector for the agent, pointing in the direction 165 | of forward movement. 166 | """ 167 | assert len(self.agents) == 1, ( 168 | "This property is not supported for multi-agent envs. " 169 | "Use `env.agents[i].dir.to_vec()` instead." 170 | ) 171 | return self.agents[0].dir.to_vec() 172 | 173 | @property 174 | def front_pos(self): 175 | """ 176 | Get the position of the cell that is right in front of the agent. 177 | """ 178 | assert len(self.agents) == 1, ( 179 | "This property is not supported for multi-agent envs. " 180 | "Use `env.agents[i].front_pos` instead." 181 | ) 182 | return self.agents[0].front_pos 183 | 184 | def place_agent(self, *args, **kwargs) -> tuple[int, int]: 185 | """ 186 | Set agent starting point at an empty position in the grid. 
187 | """ 188 | return super().place_agent(self.agents[0], *args, **kwargs) 189 | -------------------------------------------------------------------------------- /multigrid/envs/empty.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from multigrid import MultiGridEnv 4 | from multigrid.core import Grid 5 | from multigrid.core.constants import Direction 6 | from multigrid.core.world_object import Goal 7 | 8 | 9 | 10 | class EmptyEnv(MultiGridEnv): 11 | """ 12 | .. image:: https://i.imgur.com/wY0tT7R.gif 13 | :width: 200 14 | 15 | *********** 16 | Description 17 | *********** 18 | 19 | This environment is an empty room, and the goal for each agent is to reach the 20 | green goal square, which provides a sparse reward. A small penalty is subtracted 21 | for the number of steps to reach the goal. 22 | 23 | The standard setting is competitive, where agents race to the goal, and 24 | only the winner receives a reward. 25 | 26 | This environment is useful with small rooms, to validate that your RL algorithm 27 | works correctly, and with large rooms to experiment with sparse rewards and 28 | exploration. The random variants of the environment have the agents starting 29 | at a random position for each episode, while the regular variants have the 30 | agent always starting in the corner opposite to the goal. 31 | 32 | ************* 33 | Mission Space 34 | ************* 35 | 36 | "get to the green goal square" 37 | 38 | ***************** 39 | Observation Space 40 | ***************** 41 | 42 | The multi-agent observation space is a Dict mapping from agent index to 43 | corresponding agent observation space. 44 | 45 | Each agent observation is a dictionary with the following entries: 46 | 47 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 48 | Encoding of the agent's partially observable view of the environment, 49 | where the object at each grid cell is encoded as a vector: 50 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 51 | * direction : int 52 | Agent's direction (0: right, 1: down, 2: left, 3: up) 53 | * mission : Mission 54 | Task string corresponding to the current environment configuration 55 | 56 | ************ 57 | Action Space 58 | ************ 59 | 60 | The multi-agent action space is a Dict mapping from agent index to 61 | corresponding agent action space. 62 | 63 | Agent actions are discrete integer values, given by: 64 | 65 | +-----+--------------+-----------------------------+ 66 | | Num | Name | Action | 67 | +=====+==============+=============================+ 68 | | 0 | left | Turn left | 69 | +-----+--------------+-----------------------------+ 70 | | 1 | right | Turn right | 71 | +-----+--------------+-----------------------------+ 72 | | 2 | forward | Move forward | 73 | +-----+--------------+-----------------------------+ 74 | | 3 | pickup | Pick up an object | 75 | +-----+--------------+-----------------------------+ 76 | | 4 | drop | Drop an object | 77 | +-----+--------------+-----------------------------+ 78 | | 5 | toggle | Toggle / activate an object | 79 | +-----+--------------+-----------------------------+ 80 | | 6 | done | Done completing task | 81 | +-----+--------------+-----------------------------+ 82 | 83 | ******* 84 | Rewards 85 | ******* 86 | 87 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success, 88 | and ``0`` for failure. 
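For example, under the default ``max_steps = 4 * size**2`` (256 for an 8x8 grid), an agent that reaches the goal on step 64 receives a reward of ``1 - 0.9 * (64 / 256) = 0.775``.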
89 | 
90 |     ***********
91 |     Termination
92 |     ***********
93 | 
94 |     The episode ends if any one of the following conditions is met:
95 | 
96 |     * Any agent reaches the goal
97 |     * Timeout (see ``max_steps``)
98 | 
99 |     *************************
100 |     Registered Configurations
101 |     *************************
102 | 
103 |     * ``MultiGrid-Empty-5x5-v0``
104 |     * ``MultiGrid-Empty-Random-5x5-v0``
105 |     * ``MultiGrid-Empty-6x6-v0``
106 |     * ``MultiGrid-Empty-Random-6x6-v0``
107 |     * ``MultiGrid-Empty-8x8-v0``
108 |     * ``MultiGrid-Empty-16x16-v0``
109 |     """
110 | 
111 |     def __init__(
112 |         self,
113 |         size: int = 8,
114 |         agent_start_pos: tuple[int, int] | None = (1, 1),
115 |         agent_start_dir: Direction | None = Direction.right,
116 |         max_steps: int | None = None,
117 |         joint_reward: bool = False,
118 |         success_termination_mode: str = 'any',
119 |         **kwargs):
120 |         """
121 |         Parameters
122 |         ----------
123 |         size : int, default=8
124 |             Width and height of the grid
125 |         agent_start_pos : tuple[int, int], default=(1, 1)
126 |             Starting position of the agents (random if None)
127 |         agent_start_dir : Direction, default=Direction.right
128 |             Starting direction of the agents (random if None)
129 |         max_steps : int, optional
130 |             Maximum number of steps per episode
131 |         joint_reward : bool, default=False
132 |             Whether all agents receive the reward when the task is completed
133 |         success_termination_mode : 'any' or 'all', default='any'
134 |             Whether to terminate the environment when any agent reaches the goal
135 |             or after all agents reach the goal
136 |         **kwargs
137 |             See :attr:`multigrid.base.MultiGridEnv.__init__`
138 |         """
139 |         self.agent_start_pos = agent_start_pos
140 |         self.agent_start_dir = agent_start_dir
141 | 
142 |         super().__init__(
143 |             mission_space="get to the green goal square",
144 |             grid_size=size,
145 |             max_steps=max_steps or (4 * size**2),
146 |             joint_reward=joint_reward,
147 |             success_termination_mode=success_termination_mode,
148 |             **kwargs,
149 |         )
150 | 
151 |     def _gen_grid(self, width, height):
152 |         """
153 |         :meta private:
154 |         """
155 |         # Create an empty grid
156 |         self.grid = Grid(width, height)
157 | 
158 |         # Generate the surrounding walls
159 |         self.grid.wall_rect(0, 0, width, height)
160 | 
161 |         # Place a goal square in the bottom-right corner
162 |         self.put_obj(Goal(), width - 2, height - 2)
163 | 
164 |         # Place the agents
165 |         for agent in self.agents:
166 |             if self.agent_start_pos is not None and self.agent_start_dir is not None:
167 |                 agent.state.pos = self.agent_start_pos
168 |                 agent.state.dir = self.agent_start_dir
169 |             else:
170 |                 self.place_agent(agent)
171 | 
--------------------------------------------------------------------------------
/multigrid/envs/redbluedoors.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from multigrid import MultiGridEnv
4 | from multigrid.core import Action, Grid, MissionSpace
5 | from multigrid.core.constants import Color
6 | from multigrid.core.world_object import Door
7 | 
8 | 
9 | 
10 | class RedBlueDoorsEnv(MultiGridEnv):
11 |     """
12 |     .. image:: https://i.imgur.com/usbavAh.gif
13 |         :width: 400
14 | 
15 |     ***********
16 |     Description
17 |     ***********
18 | 
19 |     This environment is a room with one red and one blue door facing
20 |     opposite directions. Agents must open the red door and then the blue door,
21 |     in that order.
22 | 
23 |     The standard setting is cooperative, where all agents receive the reward
24 |     upon completion of the task.
25 | 26 | ************* 27 | Mission Space 28 | ************* 29 | 30 | "open the red door then the blue door" 31 | 32 | ***************** 33 | Observation Space 34 | ***************** 35 | 36 | The multi-agent observation space is a Dict mapping from agent index to 37 | corresponding agent observation space. 38 | 39 | Each agent observation is a dictionary with the following entries: 40 | 41 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 42 | Encoding of the agent's partially observable view of the environment, 43 | where the object at each grid cell is encoded as a vector: 44 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 45 | * direction : int 46 | Agent's direction (0: right, 1: down, 2: left, 3: up) 47 | * mission : Mission 48 | Task string corresponding to the current environment configuration 49 | 50 | ************ 51 | Action Space 52 | ************ 53 | 54 | The multi-agent action space is a Dict mapping from agent index to 55 | corresponding agent action space. 56 | 57 | Agent actions are discrete integer values, given by: 58 | 59 | +-----+--------------+-----------------------------+ 60 | | Num | Name | Action | 61 | +=====+==============+=============================+ 62 | | 0 | left | Turn left | 63 | +-----+--------------+-----------------------------+ 64 | | 1 | right | Turn right | 65 | +-----+--------------+-----------------------------+ 66 | | 2 | forward | Move forward | 67 | +-----+--------------+-----------------------------+ 68 | | 3 | pickup | Pick up an object | 69 | +-----+--------------+-----------------------------+ 70 | | 4 | drop | Drop an object | 71 | +-----+--------------+-----------------------------+ 72 | | 5 | toggle | Toggle / activate an object | 73 | +-----+--------------+-----------------------------+ 74 | | 6 | done | Done completing task | 75 | +-----+--------------+-----------------------------+ 76 | 77 | ******* 78 | Rewards 79 | ******* 80 | 81 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given for success, 82 | and ``0`` for failure. 
83 | 
84 |     ***********
85 |     Termination
86 |     ***********
87 | 
88 |     The episode ends if any one of the following conditions is met:
89 | 
90 |     * Any agent opens the blue door while the red door is open (success)
91 |     * Any agent opens the blue door while the red door is not open (failure)
92 |     * Timeout (see ``max_steps``)
93 | 
94 |     *************************
95 |     Registered Configurations
96 |     *************************
97 | 
98 |     * ``MultiGrid-RedBlueDoors-6x6-v0``
99 |     * ``MultiGrid-RedBlueDoors-8x8-v0``
100 |     """
101 | 
102 |     def __init__(
103 |         self,
104 |         size: int = 8,
105 |         max_steps: int | None = None,
106 |         joint_reward: bool = True,
107 |         success_termination_mode: str = 'any',
108 |         failure_termination_mode: str = 'any',
109 |         **kwargs):
110 |         """
111 |         Parameters
112 |         ----------
113 |         size : int, default=8
114 |             Width and height of the grid
115 |         max_steps : int, optional
116 |             Maximum number of steps per episode
117 |         joint_reward : bool, default=True
118 |             Whether all agents receive the reward when the task is completed
119 |         success_termination_mode : 'any' or 'all', default='any'
120 |             Whether to terminate the environment when any agent completes the task
121 |             or after all agents complete the task
122 |         failure_termination_mode : 'any' or 'all', default='any'
123 |             Whether to terminate the environment when any agent fails the task
124 |             or after all agents fail the task
125 |         **kwargs
126 |             See :attr:`multigrid.base.MultiGridEnv.__init__`
127 |         """
128 |         self.size = size
129 |         mission_space = MissionSpace.from_string("open the red door then the blue door")
130 |         super().__init__(
131 |             mission_space=mission_space,
132 |             width=(2 * size),
133 |             height=size,
134 |             max_steps=max_steps or (20 * size**2),
135 |             joint_reward=joint_reward,
136 |             success_termination_mode=success_termination_mode,
137 |             failure_termination_mode=failure_termination_mode,
138 |             **kwargs,
139 |         )
140 | 
141 |     def _gen_grid(self, width, height):
142 |         """
143 |         :meta private:
144 |         """
145 |         # Create an empty grid
146 |         self.grid = Grid(width, height)
147 | 
148 |         # Generate the grid walls
149 |         room_top = (width // 4, 0)
150 |         room_size = (width // 2, height)
151 |         self.grid.wall_rect(0, 0, width, height)
152 |         self.grid.wall_rect(*room_top, *room_size)
153 | 
154 |         # Place agents in the middle room
155 |         for agent in self.agents:
156 |             self.place_agent(agent, top=room_top, size=room_size)
157 | 
158 |         # Add a red door at a random position in the left wall
159 |         x = room_top[0]
160 |         y = self._rand_int(1, height - 1)
161 |         self.red_door = Door(Color.red)
162 |         self.grid.set(x, y, self.red_door)
163 | 
164 |         # Add a blue door at a random position in the right wall
165 |         x = room_top[0] + room_size[0] - 1
166 |         y = self._rand_int(1, height - 1)
167 |         self.blue_door = Door(Color.blue)
168 |         self.grid.set(x, y, self.blue_door)
169 | 
170 |     def step(self, actions):
171 |         """
172 |         :meta private:
173 |         """
174 |         obs, reward, terminated, truncated, info = super().step(actions)
175 | 
176 |         for agent_id, action in actions.items():
177 |             if action == Action.toggle:
178 |                 agent = self.agents[agent_id]
179 |                 fwd_obj = self.grid.get(*agent.front_pos)
180 |                 if fwd_obj == self.blue_door and self.blue_door.is_open:
181 |                     if self.red_door.is_open:
182 |                         self.on_success(agent, reward, terminated)
183 |                     else:
184 |                         self.on_failure(agent, reward, terminated)
185 |                         self.blue_door.is_open = False  # close the door again
186 | 
187 |         return obs, reward, terminated, truncated, info
188 | 
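A minimal interaction sketch (the environment IDs above are registered on
``import multigrid.envs``; the agent count is assumed to be configurable via an
``agents`` keyword, as in ``scripts/train.py``):

>>> import gymnasium as gym
>>> import multigrid.envs
>>> from multigrid.core import Action
>>> env = gym.make('MultiGrid-RedBlueDoors-8x8-v0', agents=2)
>>> obs, info = env.reset()
>>> actions = {agent.index: Action.toggle for agent in env.agents}
>>> obs, rewards, terminated, truncated, info = env.step(actions)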
-------------------------------------------------------------------------------- /multigrid/wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gymnasium as gym 4 | import numba as nb 5 | import numpy as np 6 | 7 | from gymnasium import spaces 8 | from gymnasium.core import ObservationWrapper 9 | from numpy.typing import NDArray as ndarray 10 | 11 | from .base import MultiGridEnv, AgentID, ObsType 12 | from .core.constants import Color, Direction, State, Type 13 | from .core.world_object import WorldObj 14 | 15 | 16 | 17 | class FullyObsWrapper(ObservationWrapper): 18 | """ 19 | Fully observable gridworld using a compact grid encoding instead of agent view. 20 | 21 | Examples 22 | -------- 23 | >>> import gymnasium as gym 24 | >>> import multigrid.envs 25 | >>> env = gym.make('MultiGrid-Empty-16x16-v0') 26 | >>> obs, _ = env.reset() 27 | >>> obs[0]['image'].shape 28 | (7, 7, 3) 29 | 30 | >>> from multigrid.wrappers import FullyObsWrapper 31 | >>> env = FullyObsWrapper(env) 32 | >>> obs, _ = env.reset() 33 | >>> obs[0]['image'].shape 34 | (16, 16, 3) 35 | """ 36 | 37 | def __init__(self, env: MultiGridEnv): 38 | """ 39 | """ 40 | super().__init__(env) 41 | 42 | # Update agent observation spaces 43 | for agent in self.env.agents: 44 | agent.observation_space['image'] = spaces.Box( 45 | low=0, high=255, shape=(env.height, env.width, WorldObj.dim), dtype=int) 46 | 47 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]: 48 | """ 49 | :meta private: 50 | """ 51 | img = self.env.grid.encode() 52 | for agent in self.env.agents: 53 | img[agent.state.pos] = agent.encode() 54 | 55 | for agent_id in obs: 56 | obs[agent_id]['image'] = img 57 | 58 | return obs 59 | 60 | 61 | class ImgObsWrapper(ObservationWrapper): 62 | """ 63 | Use the image as the only observation output for each agent. 64 | 65 | Examples 66 | -------- 67 | >>> import gymnasium as gym 68 | >>> import multigrid.envs 69 | >>> env = gym.make('MultiGrid-Empty-8x8-v0') 70 | >>> obs, _ = env.reset() 71 | >>> obs[0].keys() 72 | dict_keys(['image', 'direction', 'mission']) 73 | 74 | >>> from multigrid.wrappers import ImgObsWrapper 75 | >>> env = ImgObsWrapper(env) 76 | >>> obs, _ = env.reset() 77 | >>> obs.shape 78 | (7, 7, 3) 79 | """ 80 | 81 | def __init__(self, env: MultiGridEnv): 82 | """ 83 | """ 84 | super().__init__(env) 85 | 86 | # Update agent observation spaces 87 | for agent in self.env.agents: 88 | agent.observation_space = agent.observation_space['image'] 89 | agent.observation_space.dtype = np.uint8 90 | 91 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]: 92 | """ 93 | :meta private: 94 | """ 95 | for agent_id in obs: 96 | obs[agent_id] = obs[agent_id]['image'].astype(np.uint8) 97 | 98 | return obs 99 | 100 | 101 | class OneHotObsWrapper(ObservationWrapper): 102 | """ 103 | Wrapper to get a one-hot encoding of a partially observable 104 | agent view as observation. 
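    Each cell's ``(type, color, state)`` triple is expanded into concatenated
    one-hot vectors, so the channel dimension becomes
    ``len(Type) + len(Color) + max(len(State), len(Direction))``
    (see ``dim_sizes`` in ``__init__`` below).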
105 | 106 | Examples 107 | -------- 108 | >>> import gymnasium as gym 109 | >>> import multigrid.envs 110 | >>> env = gym.make('MultiGrid-Empty-5x5-v0') 111 | >>> obs, _ = env.reset() 112 | >>> obs[0]['image'][0, :, :] 113 | array([[2, 5, 0], 114 | [2, 5, 0], 115 | [2, 5, 0], 116 | [2, 5, 0], 117 | [2, 5, 0], 118 | [2, 5, 0], 119 | [2, 5, 0]]) 120 | 121 | >>> from multigrid.wrappers import OneHotObsWrapper 122 | >>> env = OneHotObsWrapper(env) 123 | >>> obs, _ = env.reset() 124 | >>> obs[0]['image'][0, :, :] 125 | array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 126 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 127 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 128 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 129 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 130 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 131 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]], 132 | dtype=uint8) 133 | """ 134 | 135 | def __init__(self, env: MultiGridEnv): 136 | """ 137 | """ 138 | super().__init__(env) 139 | self.dim_sizes = np.array([ 140 | len(Type), len(Color), max(len(State), len(Direction))]) 141 | 142 | # Update agent observation spaces 143 | dim = sum(self.dim_sizes) 144 | for agent in self.env.agents: 145 | view_height, view_width, _ = agent.observation_space['image'].shape 146 | agent.observation_space['image'] = spaces.Box( 147 | low=0, high=1, shape=(view_height, view_width, dim), dtype=np.uint8) 148 | 149 | def observation(self, obs: dict[AgentID, ObsType]) -> dict[AgentID, ObsType]: 150 | """ 151 | :meta private: 152 | """ 153 | for agent_id in obs: 154 | obs[agent_id]['image'] = self.one_hot(obs[agent_id]['image'], self.dim_sizes) 155 | 156 | return obs 157 | 158 | @staticmethod 159 | @nb.njit(cache=True) 160 | def one_hot(x: ndarray[np.int], dim_sizes: ndarray[np.int]) -> ndarray[np.uint8]: 161 | """ 162 | Return a one-hot encoding of a 3D integer array, 163 | where each 2D slice is encoded separately. 164 | 165 | Parameters 166 | ---------- 167 | x : ndarray[int] of shape (view_height, view_width, dim) 168 | 3D array of integers to be one-hot encoded 169 | dim_sizes : ndarray[int] of shape (dim,) 170 | Number of possible values for each dimension 171 | 172 | Returns 173 | ------- 174 | out : ndarray[uint8] of shape (view_height, view_width, sum(dim_sizes)) 175 | One-hot encoding 176 | 177 | :meta private: 178 | """ 179 | out = np.zeros((x.shape[0], x.shape[1], sum(dim_sizes)), dtype=np.uint8) 180 | 181 | dim_offset = 0 182 | for d in range(len(dim_sizes)): 183 | for i in range(x.shape[0]): 184 | for j in range(x.shape[1]): 185 | k = dim_offset + x[i, j, d] 186 | out[i, j, k] = 1 187 | 188 | dim_offset += dim_sizes[d] 189 | 190 | return out 191 | 192 | 193 | class SingleAgentWrapper(gym.Wrapper): 194 | """ 195 | Wrapper to convert a multi-agent environment into a 196 | single-agent environment. 
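    Only the first agent (``env.agents[0]``) is exposed: its observation and
    action spaces are used directly, actions are forwarded as ``{0: action}``,
    and the per-agent dicts returned by ``reset()`` and ``step()`` are unpacked
    into plain values.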
197 | 198 | Examples 199 | -------- 200 | >>> import gymnasium as gym 201 | >>> import multigrid.envs 202 | >>> env = gym.make('MultiGrid-Empty-5x5-v0') 203 | >>> obs, _ = env.reset() 204 | >>> obs[0].keys() 205 | dict_keys(['image', 'direction', 'mission']) 206 | 207 | >>> from multigrid.wrappers import SingleAgentWrapper 208 | >>> env = SingleAgentWrapper(env) 209 | >>> obs, _ = env.reset() 210 | >>> obs.keys() 211 | dict_keys(['image', 'direction', 'mission']) 212 | """ 213 | 214 | def __init__(self, env: MultiGridEnv): 215 | """ 216 | """ 217 | super().__init__(env) 218 | self.observation_space = env.agents[0].observation_space 219 | self.action_space = env.agents[0].action_space 220 | 221 | def reset(self, *args, **kwargs): 222 | """ 223 | :meta private: 224 | """ 225 | result = super().reset(*args, **kwargs) 226 | return tuple(item[0] for item in result) 227 | 228 | def step(self, action): 229 | """ 230 | :meta private: 231 | """ 232 | result = super().step({0: action}) 233 | return tuple(item[0] for item in result) 234 | -------------------------------------------------------------------------------- /multigrid/utils/rendering.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import numpy as np 5 | from numpy.typing import NDArray as ndarray 6 | from typing import Callable 7 | 8 | 9 | 10 | # Constants 11 | 12 | FilterFunction = Callable[[float, float], bool] 13 | White = np.array([255, 255, 255]) 14 | 15 | 16 | 17 | # Functions 18 | 19 | def downsample(img: ndarray[np.uint8], factor: int) -> ndarray[np.uint8]: 20 | """ 21 | Downsample an image along both dimensions by some factor. 22 | 23 | Parameters 24 | ---------- 25 | img : ndarray[uint8] of shape (height, width, 3) 26 | The image to downsample 27 | factor : int 28 | The factor by which to downsample the image 29 | 30 | Returns 31 | ------- 32 | img : ndarray[uint8] of shape (height/factor, width/factor, 3) 33 | The downsampled image 34 | """ 35 | assert img.shape[0] % factor == 0 36 | assert img.shape[1] % factor == 0 37 | 38 | img = img.reshape( 39 | [img.shape[0] // factor, factor, img.shape[1] // factor, factor, 3] 40 | ) 41 | img = img.mean(axis=3) 42 | img = img.mean(axis=1) 43 | 44 | return img 45 | 46 | def fill_coords( 47 | img: ndarray[np.uint8], 48 | fn: FilterFunction, 49 | color: ndarray[np.uint8]) -> ndarray[np.uint8]: 50 | """ 51 | Fill pixels of an image with coordinates matching a filter function. 52 | 53 | Parameters 54 | ---------- 55 | img : ndarray[uint8] of shape (height, width, 3) 56 | The image to fill 57 | fn : Callable(float, float) -> bool 58 | The filter function to use for coordinates 59 | color : ndarray[uint8] of shape (3,) 60 | RGB color to fill matching coordinates 61 | 62 | Returns 63 | ------- 64 | img : ndarray[np.uint8] of shape (height, width, 3) 65 | The updated image 66 | """ 67 | for y in range(img.shape[0]): 68 | for x in range(img.shape[1]): 69 | yf = (y + 0.5) / img.shape[0] 70 | xf = (x + 0.5) / img.shape[1] 71 | if fn(xf, yf): 72 | img[y, x] = color 73 | 74 | return img 75 | 76 | def rotate_fn(fin: FilterFunction, cx: float, cy: float, theta: float) -> FilterFunction: 77 | """ 78 | Rotate a coordinate filter function around a center point by some angle. 
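    For example, :meth:`multigrid.core.agent.Agent.render` rotates a triangle
    filter by a multiple of ``0.5 * pi`` to face the agent's direction.
    A sketch using this module's helpers:

    >>> tri = point_in_triangle((0.12, 0.19), (0.87, 0.50), (0.12, 0.81))
    >>> tri_down = rotate_fn(tri, cx=0.5, cy=0.5, theta=0.5 * math.pi)
    >>> img = fill_coords(np.zeros((32, 32, 3), dtype=np.uint8), tri_down, White)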
79 | 80 | Parameters 81 | ---------- 82 | fin : Callable(float, float) -> bool 83 | The filter function to rotate 84 | cx : float 85 | The x-coordinate of the center of rotation 86 | cy : float 87 | The y-coordinate of the center of rotation 88 | theta : float 89 | The angle by which to rotate the filter function (in radians) 90 | 91 | Returns 92 | ------- 93 | fout : Callable(float, float) -> bool 94 | The rotated filter function 95 | """ 96 | def fout(x, y): 97 | x = x - cx 98 | y = y - cy 99 | 100 | x2 = cx + x * math.cos(-theta) - y * math.sin(-theta) 101 | y2 = cy + y * math.cos(-theta) + x * math.sin(-theta) 102 | 103 | return fin(x2, y2) 104 | 105 | return fout 106 | 107 | def point_in_line( 108 | x0: float, y0: float, x1: float, y1: float, r: float) -> FilterFunction: 109 | """ 110 | Return a filter function that returns True for points within distance r 111 | from the line between (x0, y0) and (x1, y1). 112 | 113 | Parameters 114 | ---------- 115 | x0 : float 116 | The x-coordinate of the line start point 117 | y0 : float 118 | The y-coordinate of the line start point 119 | x1 : float 120 | The x-coordinate of the line end point 121 | y1 : float 122 | The y-coordinate of the line end point 123 | r : float 124 | Maximum distance from the line 125 | 126 | Returns 127 | ------- 128 | fn : Callable(float, float) -> bool 129 | Filter function 130 | """ 131 | p0 = np.array([x0, y0], dtype=np.float32) 132 | p1 = np.array([x1, y1], dtype=np.float32) 133 | dir = p1 - p0 134 | dist = np.linalg.norm(dir) 135 | dir = dir / dist 136 | 137 | xmin = min(x0, x1) - r 138 | xmax = max(x0, x1) + r 139 | ymin = min(y0, y1) - r 140 | ymax = max(y0, y1) + r 141 | 142 | def fn(x, y): 143 | # Fast, early escape test 144 | if x < xmin or x > xmax or y < ymin or y > ymax: 145 | return False 146 | 147 | q = np.array([x, y]) 148 | pq = q - p0 149 | 150 | # Closest point on line 151 | a = np.dot(pq, dir) 152 | a = np.clip(a, 0, dist) 153 | p = p0 + a * dir 154 | 155 | dist_to_line = np.linalg.norm(q - p) 156 | return dist_to_line <= r 157 | 158 | return fn 159 | 160 | def point_in_circle(cx: float, cy: float, r: float) -> FilterFunction: 161 | """ 162 | Return a filter function that returns True for points within radius r 163 | from a given point. 164 | 165 | Parameters 166 | ---------- 167 | cx : float 168 | The x-coordinate of the circle center 169 | cy : float 170 | The y-coordinate of the circle center 171 | r : float 172 | The radius of the circle 173 | 174 | Returns 175 | ------- 176 | fn : Callable(float, float) -> bool 177 | Filter function 178 | """ 179 | def fn(x, y): 180 | return (x - cx) * (x - cx) + (y - cy) * (y - cy) <= r * r 181 | 182 | return fn 183 | 184 | def point_in_rect(xmin: float, xmax: float, ymin: float, ymax: float) -> FilterFunction: 185 | """ 186 | Return a filter function that returns True for points within a rectangle. 
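    Coordinates are in the same normalized ``[0, 1]`` frame used by
    :func:`fill_coords`; for example,
    ``fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))``
    is how :meth:`multigrid.core.grid.Grid.render_tile` draws the thin grid
    line along a tile's left edge.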
187 | 188 | Parameters 189 | ---------- 190 | xmin : float 191 | The minimum x-coordinate of the rectangle 192 | xmax : float 193 | The maximum x-coordinate of the rectangle 194 | ymin : float 195 | The minimum y-coordinate of the rectangle 196 | ymax : float 197 | The maximum y-coordinate of the rectangle 198 | 199 | Returns 200 | ------- 201 | fn : Callable(float, float) -> bool 202 | Filter function 203 | """ 204 | def fn(x, y): 205 | return x >= xmin and x <= xmax and y >= ymin and y <= ymax 206 | 207 | return fn 208 | 209 | def point_in_triangle( 210 | a: tuple[float, float], 211 | b: tuple[float, float], 212 | c: tuple[float, float]) -> FilterFunction: 213 | """ 214 | Return a filter function that returns True for points within a triangle. 215 | 216 | Parameters 217 | ---------- 218 | a : tuple[float, float] 219 | The first vertex of the triangle 220 | b : tuple[float, float] 221 | The second vertex of the triangle 222 | c : tuple[float, float] 223 | The third vertex of the triangle 224 | 225 | Returns 226 | ------- 227 | fn : Callable(float, float) -> bool 228 | Filter function 229 | """ 230 | a = np.array(a, dtype=np.float32) 231 | b = np.array(b, dtype=np.float32) 232 | c = np.array(c, dtype=np.float32) 233 | 234 | def fn(x, y): 235 | v0 = c - a 236 | v1 = b - a 237 | v2 = np.array((x, y)) - a 238 | 239 | # Compute dot products 240 | dot00 = np.dot(v0, v0) 241 | dot01 = np.dot(v0, v1) 242 | dot02 = np.dot(v0, v2) 243 | dot11 = np.dot(v1, v1) 244 | dot12 = np.dot(v1, v2) 245 | 246 | # Compute barycentric coordinates 247 | inv_denom = 1 / (dot00 * dot11 - dot01 * dot01) 248 | u = (dot11 * dot02 - dot01 * dot12) * inv_denom 249 | v = (dot00 * dot12 - dot01 * dot02) * inv_denom 250 | 251 | # Check if point is in triangle 252 | return (u >= 0) and (v >= 0) and (u + v) < 1 253 | 254 | return fn 255 | 256 | def highlight_img( 257 | img: ndarray[np.uint8], 258 | color: ndarray[np.uint8] = White, 259 | alpha=0.30) -> ndarray[np.uint8]: 260 | """ 261 | Add highlighting to an image. 262 | 263 | Parameters 264 | ---------- 265 | img : ndarray[uint8] of shape (height, width, 3) 266 | The image to highlight 267 | color : ndarray[uint8] of shape (3,) 268 | RGB color to use for highlighting 269 | alpha : float 270 | The alpha value to use for blending 271 | 272 | Returns 273 | ------- 274 | img : ndarray[uint8] of shape (height, width, 3) 275 | The highlighted image 276 | """ 277 | blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img) 278 | blend_img = blend_img.clip(0, 255).astype(np.uint8) 279 | img[:, :, :] = blend_img 280 | -------------------------------------------------------------------------------- /multigrid/envs/locked_hallway.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from math import ceil 4 | from multigrid import MultiGridEnv 5 | from multigrid.core.actions import Action 6 | from multigrid.core.constants import Color, Direction 7 | from multigrid.core.mission import MissionSpace 8 | from multigrid.core.roomgrid import Room, RoomGrid 9 | from multigrid.core.world_object import Door, Key 10 | 11 | 12 | 13 | class LockedHallwayEnv(RoomGrid): 14 | """ 15 | .. image:: https://i.imgur.com/VylPtnn.gif 16 | :width: 325 17 | 18 | *********** 19 | Description 20 | *********** 21 | 22 | This environment consists of a hallway with multiple locked rooms on either side. 23 | To unlock each door, agents must first find the corresponding key, 24 | which may be in another locked room. 
Agents are rewarded for each door they unlock. 25 | 26 | The standard setting is cooperative, where all agents receive a reward 27 | for each door that is opened. 28 | 29 | ************* 30 | Mission Space 31 | ************* 32 | 33 | "unlock all the doors" 34 | 35 | ***************** 36 | Observation Space 37 | ***************** 38 | 39 | The multi-agent observation space is a Dict mapping from agent index to 40 | corresponding agent observation space. 41 | 42 | Each agent observation is a dictionary with the following entries: 43 | 44 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 45 | Encoding of the agent's partially observable view of the environment, 46 | where each grid cell is encoded as a 3 dimensional tuple: 47 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 48 | * direction : int 49 | Agent's direction (0: right, 1: down, 2: left, 3: up) 50 | * mission : Mission 51 | Task string corresponding to the current environment configuration 52 | 53 | ************ 54 | Action Space 55 | ************ 56 | 57 | The multi-agent action space is a Dict mapping from agent index to 58 | corresponding agent action space. 59 | 60 | Agent actions are discrete integer values, given by: 61 | 62 | +-----+--------------+-----------------------------+ 63 | | Num | Name | Action | 64 | +=====+==============+=============================+ 65 | | 0 | left | Turn left | 66 | +-----+--------------+-----------------------------+ 67 | | 1 | right | Turn right | 68 | +-----+--------------+-----------------------------+ 69 | | 2 | forward | Move forward | 70 | +-----+--------------+-----------------------------+ 71 | | 3 | pickup | Pick up an object | 72 | +-----+--------------+-----------------------------+ 73 | | 4 | drop | Drop an object | 74 | +-----+--------------+-----------------------------+ 75 | | 5 | toggle | Toggle / activate an object | 76 | +-----+--------------+-----------------------------+ 77 | | 6 | done | Done completing task | 78 | +-----+--------------+-----------------------------+ 79 | 80 | ******* 81 | Rewards 82 | ******* 83 | 84 | A reward of ``1 - 0.9 * (step_count / max_steps)`` is given 85 | when a door is unlocked. 
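    Unlike the single-goal environments, reward can be collected several times
    per episode: each door contributes once, at the step it is first unlocked
    (see ``step()`` below).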
86 | 87 | *********** 88 | Termination 89 | *********** 90 | 91 | The episode ends if any one of the following conditions is met: 92 | 93 | * All doors are unlocked 94 | * Timeout (see ``max_steps``) 95 | 96 | ************************* 97 | Registered Configurations 98 | ************************* 99 | 100 | * ``MultiGrid-LockedHallway-2Rooms-v0`` 101 | * ``MultiGrid-LockedHallway-4Rooms-v0`` 102 | * ``MultiGrid-LockedHallway-6Rooms-v0`` 103 | """ 104 | 105 | def __init__( 106 | self, 107 | num_rooms: int = 6, 108 | room_size: int = 5, 109 | max_hallway_keys: int = 1, 110 | max_keys_per_room: int = 2, 111 | max_steps: int | None = None, 112 | joint_reward: bool = True, 113 | **kwargs): 114 | """ 115 | Parameters 116 | ---------- 117 | num_rooms : int, default=6 118 | Number of rooms in the environment 119 | room_size : int, default=5 120 | Width and height for each of the rooms 121 | max_hallway_keys : int, default=1 122 | Maximum number of keys in the hallway 123 | max_keys_per_room : int, default=2 124 | Maximum number of keys in each room 125 | max_steps : int, optional 126 | Maximum number of steps per episode 127 | joint_reward : bool, default=True 128 | Whether all agents receive the same reward 129 | **kwargs 130 | See :attr:`multigrid.base.MultiGridEnv.__init__` 131 | """ 132 | assert room_size >= 4 133 | assert num_rooms % 2 == 0 134 | 135 | self.num_rooms = num_rooms 136 | self.max_hallway_keys = max_hallway_keys 137 | self.max_keys_per_room = max_keys_per_room 138 | 139 | if max_steps is None: 140 | max_steps = 8 * num_rooms * room_size**2 141 | 142 | super().__init__( 143 | mission_space=MissionSpace.from_string("unlock all the doors"), 144 | room_size=room_size, 145 | num_rows=(num_rooms // 2), 146 | num_cols=3, 147 | max_steps=max_steps, 148 | joint_reward=joint_reward, 149 | **kwargs, 150 | ) 151 | 152 | def _gen_grid(self, width, height): 153 | """ 154 | :meta private: 155 | """ 156 | super()._gen_grid(width, height) 157 | 158 | LEFT, HALLWAY, RIGHT = range(3) # columns 159 | color_sequence = list(Color) * ceil(self.num_rooms / len(Color)) 160 | color_sequence = self._rand_perm(color_sequence)[:self.num_rooms] 161 | 162 | # Create hallway 163 | for row in range(self.num_rows - 1): 164 | self.remove_wall(HALLWAY, row, Direction.down) 165 | 166 | # Add doors 167 | self.rooms: dict[Color, Room] = {} 168 | door_colors = self._rand_perm(color_sequence) 169 | for row in range(self.num_rows): 170 | for col, dir in ((LEFT, Direction.right), (RIGHT, Direction.left)): 171 | color = door_colors.pop() 172 | self.rooms[color] = self.get_room(col, row) 173 | self.add_door( 174 | col, row, dir=dir, color=color, locked=True, rand_pos=False) 175 | 176 | # Place keys in hallway 177 | num_hallway_keys = self._rand_int(1, self.max_hallway_keys + 1) 178 | hallway_top = self.get_room(HALLWAY, 0).top 179 | hallway_size = (self.get_room(HALLWAY, 0).size[0], self.height) 180 | for key_color in color_sequence[:num_hallway_keys]: 181 | self.place_obj(Key(color=key_color), top=hallway_top, size=hallway_size) 182 | 183 | # Place keys in rooms 184 | key_index = num_hallway_keys 185 | while key_index < len(color_sequence): 186 | room = self.rooms[color_sequence[key_index - 1]] 187 | num_room_keys = self._rand_int(1, self.max_keys_per_room + 1) 188 | for key_color in color_sequence[key_index : key_index + num_room_keys]: 189 | self.place_obj(Key(color=key_color), top=room.top, size=room.size) 190 | key_index += 1 191 | 192 | # Place agents in hallway 193 | for agent in self.agents: 194 | 
            MultiGridEnv.place_agent(self, agent, top=hallway_top, size=hallway_size)
195 | 
196 |     def reset(self, **kwargs):
197 |         """
198 |         :meta private:
199 |         """
200 |         self.unlocked_doors = []
201 |         return super().reset(**kwargs)
202 | 
203 |     def step(self, actions):
204 |         """
205 |         :meta private:
206 |         """
207 |         observations, rewards, terminations, truncations, infos = super().step(actions)
208 | 
209 |         # Reward for unlocking doors
210 |         for agent_id, action in actions.items():
211 |             if action == Action.toggle:
212 |                 fwd_obj = self.grid.get(*self.agents[agent_id].front_pos)
213 |                 if isinstance(fwd_obj, Door) and not fwd_obj.is_locked:
214 |                     if fwd_obj not in self.unlocked_doors:
215 |                         self.unlocked_doors.append(fwd_obj)
216 |                         if self.joint_reward:
217 |                             for k in rewards:
218 |                                 rewards[k] += self._reward()
219 |                         else:
220 |                             rewards[agent_id] += self._reward()
221 | 
222 |         # Check if all doors are unlocked
223 |         if len(self.unlocked_doors) == len(self.rooms):
224 |             for agent in self.agents:
225 |                 terminations[agent.index] = True
226 | 
227 |         return observations, rewards, terminations, truncations, infos
228 | 
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import argparse
4 | import json
5 | import multigrid.rllib
6 | import os
7 | import random
8 | import ray
9 | import ray.train
10 | import ray.tune
11 | import torch
12 | import torch.nn as nn
13 | 
14 | from multigrid.core.constants import Direction
15 | from pathlib import Path
16 | from ray.rllib.algorithms import PPOConfig
17 | from ray.rllib.core.columns import Columns
18 | from ray.rllib.core.rl_module import MultiRLModuleSpec, RLModuleSpec
19 | from ray.rllib.core.rl_module.apis import ValueFunctionAPI
20 | from ray.rllib.core.rl_module.torch import TorchRLModule
21 | from ray.rllib.utils.from_config import NotProvided
22 | from typing import Callable
23 | 
24 | 
25 | 
26 | ### Helper Methods
27 | 
28 | def get_policy_mapping_fn(
29 |     checkpoint_dir: Path | str | None,
30 |     num_agents: int,
31 | ) -> Callable:
32 |     try:
33 |         policies = sorted([
34 |             path for path in (checkpoint_dir / 'policies').iterdir() if path.is_dir()])
35 | 
36 |         def policy_mapping_fn(agent_id, *args, **kwargs):
37 |             return policies[agent_id % len(policies)].name
38 | 
39 |         print('Loading policies from:', checkpoint_dir)
40 |         for agent_id in range(num_agents):
41 |             print('Agent ID:', agent_id, 'Policy ID:', policy_mapping_fn(agent_id))
42 | 
43 |         return policy_mapping_fn
44 | 
45 |     except Exception:  # no checkpoint to load from (e.g. checkpoint_dir is None)
46 |         return lambda agent_id, *args, **kwargs: f'policy_{agent_id}'
47 | 
48 | def find_checkpoint_dir(search_dir: Path | str | None) -> Path | None:
49 |     try:
50 |         checkpoints = list(Path(search_dir).expanduser().glob('**/rllib_checkpoint.json'))  # glob() is lazy and always truthy; materialize it first
51 |         if checkpoints:
52 |             return sorted(checkpoints, key=os.path.getmtime)[-1].parent
53 |     except Exception:
54 |         return None
55 | 
56 | def preprocess_batch(batch: dict) -> torch.Tensor:
57 |     image = batch['obs']['image']
58 |     direction = batch['obs']['direction']
59 |     direction = 2 * torch.pi * (direction / len(Direction))
60 |     direction = torch.stack([torch.cos(direction), torch.sin(direction)], dim=-1)
61 |     direction = direction[..., None, None, :].expand(*image.shape[:-1], 2)
62 |     x = torch.cat([image, direction], dim=-1).float()
63 |     return x
64 | 
65 | 
66 | 
67 | ### Models
68 | 
69 | class MultiGridEncoder(nn.Module):
70 | 
71 |     def __init__(self, in_channels: int = 23):  # assumed: one-hot image channels + 2 direction channels (see preprocess_batch)
72 |         super().__init__()
73 |         self.model = nn.Sequential(
74 |             nn.Conv2d(in_channels, 16, (3, 3)), nn.ReLU(),
75 |             nn.Conv2d(16, 32, (3, 3)), nn.ReLU(),
76 |             nn.Conv2d(32, 64, (3, 3)), nn.ReLU(),
77 |             nn.Flatten(),
78 |         )
79 | 
80 |     def forward(self, x):
81 |         x = x[None] if x.ndim == 3 else x  # add batch dimension
82 |         x = x.permute(0, 3, 1, 2)  # channels-first (NHWC -> NCHW)
83 |         return self.model(x)
84 | 
85 | 
86 | class AgentModule(TorchRLModule, ValueFunctionAPI):
87 | 
88 |     def setup(self):
89 |         self.base = nn.Identity()
90 |         self.actor = nn.Sequential(
91 |             MultiGridEncoder(in_channels=23),
92 |             nn.Linear(64, 64), nn.ReLU(),
93 |             nn.Linear(64, 7),
94 |         )
95 |         self.critic = nn.Sequential(
96 |             MultiGridEncoder(in_channels=23),
97 |             nn.Linear(64, 64), nn.ReLU(),
98 |             nn.Linear(64, 1),
99 |         )
100 | 
101 |     def _forward(self, batch, **kwargs):
102 |         x = self.base(preprocess_batch(batch))
103 |         return {Columns.ACTION_DIST_INPUTS: self.actor(x)}
104 | 
105 |     def _forward_train(self, batch, **kwargs):
106 |         x = self.base(preprocess_batch(batch))
107 |         return {
108 |             Columns.ACTION_DIST_INPUTS: self.actor(x),
109 |             Columns.EMBEDDINGS: self.critic(batch.get('value_inputs', x)),
110 |         }
111 | 
112 |     def compute_values(self, batch, embeddings=None):
113 |         if embeddings is None:
114 |             x = self.base(preprocess_batch(batch))
115 |             embeddings = self.critic(batch.get('value_inputs', x))
116 | 
117 |         return embeddings.squeeze(-1)
118 | 
119 |     def get_initial_state(self):
120 |         return {}
121 | 
122 | 
123 | 
124 | ### Training
125 | 
126 | def get_algorithm_config(
127 |     env: str = 'MultiGrid-Empty-8x8-v0',
128 |     env_config: dict = {},
129 |     num_agents: int = 2,
130 |     # lstm: bool = False, TODO: implement LSTM model
131 |     num_workers: int = 0,
132 |     num_gpus: int = 0,
133 |     lr: float = NotProvided,
134 |     batch_size: int = NotProvided,
135 |     # centralized_critic: bool = False, TODO: implement centralized critic
136 |     **kwargs,
137 | ) -> PPOConfig:
138 |     """
139 |     Return the RL algorithm configuration dictionary.
140 |     """
141 |     config = PPOConfig()
142 |     config = config.api_stack(
143 |         enable_env_runner_and_connector_v2=True,
144 |         enable_rl_module_and_learner=True,
145 |     )
146 |     config = config.debugging(seed=random.randint(0, 1000000))
147 |     config = config.env_runners(
148 |         num_env_runners=num_workers,
149 |         num_envs_per_env_runner=1,
150 |         num_gpus_per_env_runner=num_gpus if torch.cuda.is_available() else 0,
151 |     )
152 |     config = config.environment(env=env, env_config={**env_config, 'agents': num_agents})
153 |     config = config.framework('torch')
154 |     config = config.multi_agent(
155 |         policies={f'policy_{i}' for i in range(num_agents)},
156 |         policy_mapping_fn=get_policy_mapping_fn(None, num_agents),
157 |         policies_to_train=[f'policy_{i}' for i in range(num_agents)],
158 |     )
159 |     config = config.training(lr=lr, train_batch_size=batch_size)
160 |     config = config.rl_module(
161 |         rl_module_spec=MultiRLModuleSpec(
162 |             rl_module_specs={
163 |                 f'policy_{i}': RLModuleSpec(module_class=AgentModule)
164 |                 for i in range(num_agents)
165 |             }
166 |         )
167 |     )
168 | 
169 |     return config
170 | 
171 | def train(
172 |     config: PPOConfig,
173 |     stop_conditions: dict,
174 |     save_dir: str,
175 |     load_dir: str = None,
176 | ):
177 |     """
178 |     Train an RLlib algorithm.
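
    Parameters
    ----------
    config : PPOConfig
        Algorithm configuration, e.g. from :func:`get_algorithm_config`
    stop_conditions : dict
        Mapping from a result metric key to the value at which training stops
    save_dir : str
        Directory for saving checkpoints, results, and trained policies
    load_dir : str, optional
        Directory searched (via :func:`find_checkpoint_dir`) for an existing
        checkpoint to resume from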
179 | """ 180 | checkpoint = find_checkpoint_dir(load_dir) 181 | if checkpoint: 182 | tuner = ray.tune.Tuner.restore(checkpoint) 183 | else: 184 | tuner = ray.tune.Tuner( 185 | config.algo_class, 186 | param_space=config, 187 | run_config=ray.train.RunConfig( 188 | storage_path=save_dir, 189 | stop=stop_conditions, 190 | verbose=1, 191 | checkpoint_config=ray.train.CheckpointConfig( 192 | checkpoint_frequency=20, 193 | checkpoint_at_end=True, 194 | ), 195 | ), 196 | ) 197 | 198 | results = tuner.fit() 199 | return results 200 | 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser() 205 | 206 | parser.add_argument( 207 | '--algo', type=str, default='PPO', 208 | help="The name of the RLlib-registered algorithm to use.") 209 | parser.add_argument( 210 | '--env', type=str, default='MultiGrid-Empty-8x8-v0', 211 | help="MultiGrid environment to use.") 212 | parser.add_argument( 213 | '--env-config', type=json.loads, default={}, 214 | help="Environment config dict, given as a JSON string (e.g. '{\"size\": 8}')") 215 | parser.add_argument( 216 | '--num-agents', type=int, default=2, 217 | help="Number of agents in environment.") 218 | # parser.add_argument( 219 | # '--lstm', action='store_true', 220 | # help="Use LSTM model.") 221 | # parser.add_argument( 222 | # '--centralized-critic', action='store_true', 223 | # help="Use centralized critic for training.") 224 | parser.add_argument( 225 | '--num-workers', type=int, default=8, 226 | help="Number of rollout workers.") 227 | parser.add_argument( 228 | '--num-gpus', type=int, default=1, 229 | help="Number of GPUs to train on.") 230 | parser.add_argument( 231 | '--num-timesteps', type=int, default=1e7, 232 | help="Total number of timesteps to train.") 233 | parser.add_argument( 234 | '--lr', type=float, 235 | help="Learning rate for training.") 236 | parser.add_argument( 237 | '--load-dir', type=str, 238 | help="Checkpoint directory for loading pre-trained policies.") 239 | parser.add_argument( 240 | '--save-dir', type=str, default='~/ray_results/', 241 | help="Directory for saving checkpoints, results, and trained policies.") 242 | 243 | args = parser.parse_args() 244 | config = get_algorithm_config(**vars(args)) 245 | 246 | print() 247 | print(f"Running with following CLI options: {args}") 248 | print('\n', '-' * 64, '\n', "Training with following configuration:", '\n', '-' * 64) 249 | print() 250 | 251 | stop_conditions = { 252 | 'learners/__all_modules__/num_env_steps_trained_lifetime': args.num_timesteps} 253 | train(config, stop_conditions, args.save_dir, args.load_dir) 254 | -------------------------------------------------------------------------------- /multigrid/core/grid.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | 5 | from collections import defaultdict 6 | from functools import cached_property 7 | from numpy.typing import NDArray as ndarray 8 | from typing import Any, Callable, Iterable 9 | 10 | from .agent import Agent 11 | from .constants import Type, TILE_PIXELS 12 | from .world_object import Wall, WorldObj 13 | 14 | from ..utils.rendering import ( 15 | downsample, 16 | fill_coords, 17 | highlight_img, 18 | point_in_rect, 19 | ) 20 | 21 | 22 | 23 | class Grid: 24 | """ 25 | Class representing a grid of :class:`.WorldObj` objects. 
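    A minimal construction sketch (the encoding dimension of 3 matches the
    image shapes shown in ``multigrid/wrappers.py``):

    >>> grid = Grid(8, 8)
    >>> grid.wall_rect(0, 0, 8, 8)   # walls around the perimeter
    >>> grid.set(3, 3, Wall())       # place an interior wall
    >>> grid.encode().shape
    (8, 8, 3)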
26 | 27 | Attributes 28 | ---------- 29 | width : int 30 | Width of the grid 31 | height : int 32 | Height of the grid 33 | world_objects : dict[tuple[int, int], WorldObj] 34 | Dictionary of world objects in the grid, indexed by (x, y) location 35 | state : ndarray[int] of shape (width, height, WorldObj.dim) 36 | Grid state, where each (x, y) entry is a world object encoding 37 | """ 38 | 39 | # Static cache of pre-renderer tiles 40 | _tile_cache: dict[tuple[Any, ...], Any] = {} 41 | 42 | def __init__(self, width: int, height: int): 43 | """ 44 | Parameters 45 | ---------- 46 | width : int 47 | Width of the grid 48 | height : int 49 | Height of the grid 50 | """ 51 | assert width >= 3 52 | assert height >= 3 53 | self.world_objects: dict[tuple[int, int], WorldObj] = {} # indexed by location 54 | self.state: ndarray[np.int] = np.zeros((width, height, WorldObj.dim), dtype=int) 55 | self.state[...] = WorldObj.empty() 56 | 57 | @cached_property 58 | def width(self) -> int: 59 | """ 60 | Width of the grid. 61 | """ 62 | return self.state.shape[0] 63 | 64 | @cached_property 65 | def height(self) -> int: 66 | """ 67 | Height of the grid. 68 | """ 69 | return self.state.shape[1] 70 | 71 | @property 72 | def grid(self) -> list[WorldObj | None]: 73 | """ 74 | Return a list of all world objects in the grid. 75 | """ 76 | return [self.get(i, j) for i in range(self.width) for j in range(self.height)] 77 | 78 | def set(self, x: int, y: int, obj: WorldObj | None): 79 | """ 80 | Set a world object at the given coordinates. 81 | 82 | Parameters 83 | ---------- 84 | x : int 85 | Grid x-coordinate 86 | y : int 87 | Grid y-coordinate 88 | obj : WorldObj or None 89 | Object to place 90 | """ 91 | # Update world object dictionary 92 | self.world_objects[x, y] = obj 93 | 94 | # Update grid state 95 | if isinstance(obj, WorldObj): 96 | self.state[x, y] = obj 97 | elif obj is None: 98 | self.state[x, y] = WorldObj.empty() 99 | else: 100 | raise TypeError(f"cannot set grid value to {type(obj)}") 101 | 102 | def get(self, x: int, y: int) -> WorldObj | None: 103 | """ 104 | Get the world object at the given coordinates. 105 | 106 | Parameters 107 | ---------- 108 | x : int 109 | Grid x-coordinate 110 | y : int 111 | Grid y-coordinate 112 | """ 113 | # Create WorldObj instance if none exists 114 | if (x, y) not in self.world_objects: 115 | self.world_objects[x, y] = WorldObj.from_array(self.state[x, y]) 116 | 117 | return self.world_objects[x, y] 118 | 119 | def update(self, x: int, y: int): 120 | """ 121 | Update the grid state from the world object at the given coordinates. 122 | 123 | Parameters 124 | ---------- 125 | x : int 126 | Grid x-coordinate 127 | y : int 128 | Grid y-coordinate 129 | """ 130 | if (x, y) in self.world_objects: 131 | self.state[x, y] = self.world_objects[x, y] 132 | 133 | def horz_wall( 134 | self, 135 | x: int, y: int, 136 | length: int | None = None, 137 | obj_type: Callable[[], WorldObj] = Wall): 138 | """ 139 | Create a horizontal wall. 140 | 141 | Parameters 142 | ---------- 143 | x : int 144 | Leftmost x-coordinate of wall 145 | y : int 146 | Y-coordinate of wall 147 | length : int or None 148 | Length of wall. If None, wall extends to the right edge of the grid. 
149 | obj_type : Callable() -> WorldObj 150 | Function that returns a WorldObj instance to use for the wall 151 | """ 152 | length = self.width - x if length is None else length 153 | self.state[x:x+length, y] = obj_type() 154 | 155 | def vert_wall( 156 | self, 157 | x: int, y: int, 158 | length: int | None = None, 159 | obj_type: Callable[[], WorldObj] = Wall): 160 | """ 161 | Create a vertical wall. 162 | 163 | Parameters 164 | ---------- 165 | x : int 166 | X-coordinate of wall 167 | y : int 168 | Topmost y-coordinate of wall 169 | length : int or None 170 | Length of wall. If None, wall extends to the bottom edge of the grid. 171 | obj_type : Callable() -> WorldObj 172 | Function that returns a WorldObj instance to use for the wall 173 | """ 174 | length = self.height - y if length is None else length 175 | self.state[x, y:y+length] = obj_type() 176 | 177 | def wall_rect(self, x: int, y: int, w: int, h: int): 178 | """ 179 | Create a walled rectangle. 180 | 181 | Parameters 182 | ---------- 183 | x : int 184 | X-coordinate of top-left corner 185 | y : int 186 | Y-coordinate of top-left corner 187 | w : int 188 | Width of rectangle 189 | h : int 190 | Height of rectangle 191 | """ 192 | self.horz_wall(x, y, w) 193 | self.horz_wall(x, y + h - 1, w) 194 | self.vert_wall(x, y, h) 195 | self.vert_wall(x + w - 1, y, h) 196 | 197 | @classmethod 198 | def render_tile( 199 | cls, 200 | obj: WorldObj | None = None, 201 | agent: Agent | None = None, 202 | highlight: bool = False, 203 | tile_size: int = TILE_PIXELS, 204 | subdivs: int = 3) -> ndarray[np.uint8]: 205 | """ 206 | Render a tile and cache the result. 207 | 208 | Parameters 209 | ---------- 210 | obj : WorldObj or None 211 | Object to render 212 | agent : Agent or None 213 | Agent to render 214 | highlight : bool 215 | Whether to highlight the tile 216 | tile_size : int 217 | Tile size (in pixels) 218 | subdivs : int 219 | Downsampling factor for supersampling / anti-aliasing 220 | """ 221 | # Hash map lookup key for the cache 222 | key: tuple[Any, ...] = (highlight, tile_size) 223 | if agent: 224 | key += (agent.state.color, agent.state.dir) 225 | else: 226 | key += (None, None) 227 | key = obj.encode() + key if obj else key 228 | 229 | if key in cls._tile_cache: 230 | return cls._tile_cache[key] 231 | 232 | img = np.zeros( 233 | shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8) 234 | 235 | # Draw the grid lines (top and left edges) 236 | fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100)) 237 | fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100)) 238 | 239 | # Draw the object 240 | if obj is not None: 241 | obj.render(img) 242 | 243 | # Draw the agent 244 | if agent is not None and not agent.state.terminated: 245 | agent.render(img) 246 | 247 | # Highlight the cell if needed 248 | if highlight: 249 | highlight_img(img) 250 | 251 | # Downsample the image to perform supersampling/anti-aliasing 252 | img = downsample(img, subdivs) 253 | 254 | # Cache the rendered tile 255 | cls._tile_cache[key] = img 256 | 257 | return img 258 | 259 | def render( 260 | self, 261 | tile_size: int, 262 | agents: Iterable[Agent] = (), 263 | highlight_mask: ndarray[np.bool] | None = None) -> ndarray[np.uint8]: 264 | """ 265 | Render this grid at a given scale. 
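        The result is a single RGB image of shape
        ``(height * tile_size, width * tile_size, 3)``, assembled tile by tile
        via :meth:`render_tile`.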
266 | 
267 |         Parameters
268 |         ----------
269 |         tile_size: int
270 |             Tile size (in pixels)
271 |         agents: Iterable[Agent]
272 |             Agents to render
273 |         highlight_mask: ndarray
274 |             Boolean mask indicating which grid locations to highlight
275 |         """
276 |         if highlight_mask is None:
277 |             highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
278 | 
279 |         # Get agent locations
280 |         # For overlapping agents, non-terminated agents get priority
281 |         location_to_agent = defaultdict(type(None))
282 |         for agent in sorted(agents, key=lambda a: not a.terminated):
283 |             location_to_agent[tuple(agent.pos)] = agent
284 | 
285 |         # Initialize pixel array
286 |         width_px = self.width * tile_size
287 |         height_px = self.height * tile_size
288 |         img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
289 | 
290 |         # Render the grid
291 |         for j in range(0, self.height):
292 |             for i in range(0, self.width):
293 |                 assert highlight_mask is not None
294 |                 cell = self.get(i, j)
295 |                 tile_img = Grid.render_tile(
296 |                     cell,
297 |                     agent=location_to_agent[i, j],
298 |                     highlight=highlight_mask[i, j],
299 |                     tile_size=tile_size,
300 |                 )
301 | 
302 |                 ymin = j * tile_size
303 |                 ymax = (j + 1) * tile_size
304 |                 xmin = i * tile_size
305 |                 xmax = (i + 1) * tile_size
306 |                 img[ymin:ymax, xmin:xmax, :] = tile_img
307 | 
308 |         return img
309 | 
310 |     def encode(self, vis_mask: ndarray[np.bool] | None = None) -> ndarray[np.int]:
311 |         """
312 |         Produce a compact numpy encoding of the grid.
313 | 
314 |         Parameters
315 |         ----------
316 |         vis_mask : ndarray[bool] of shape (width, height)
317 |             Visibility mask
318 |         """
319 |         if vis_mask is None:
320 |             vis_mask = np.ones((self.width, self.height), dtype=bool)
321 | 
322 |         encoding = self.state.copy()
323 |         encoding[~vis_mask, WorldObj.TYPE] = Type.unseen.to_index()  # single advanced-index assignment, so it writes in place (chained indexing would modify a copy)
324 |         return encoding
325 | 
326 |     @staticmethod
327 |     def decode(array: ndarray[np.int]) -> tuple['Grid', ndarray[np.bool]]:
328 |         """
329 |         Decode an array grid encoding back into a `Grid` instance.
330 | 
331 |         Parameters
332 |         ----------
333 |         array : ndarray[int] of shape (width, height, dim)
334 |             Grid encoding
335 | 
336 |         Returns
337 |         -------
338 |         grid : Grid
339 |             Decoded `Grid` instance
340 |         vis_mask : ndarray[bool] of shape (width, height)
341 |             Visibility mask
342 |         """
343 |         width, height, dim = array.shape
344 |         assert dim == WorldObj.dim
345 | 
346 |         vis_mask = (array[..., WorldObj.TYPE] != Type.unseen.to_index())
347 |         grid = Grid(width, height)
348 |         grid.state[vis_mask] = array[vis_mask]
349 |         return grid, vis_mask
350 | 
--------------------------------------------------------------------------------
/multigrid/core/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import numpy as np
4 | 
5 | from gymnasium import spaces
6 | from numpy.typing import ArrayLike, NDArray as ndarray
7 | 
8 | from .actions import Action
9 | from .constants import Color, Direction, Type
10 | from .mission import Mission, MissionSpace
11 | from .world_object import WorldObj
12 | 
13 | from ..utils.misc import front_pos, PropertyAlias
14 | from ..utils.rendering import (
15 |     fill_coords,
16 |     point_in_triangle,
17 |     rotate_fn,
18 | )
19 | 
20 | 
21 | 
22 | class Agent:
23 |     """
24 |     Class representing an agent in the environment.
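    A minimal construction sketch (one action per member of :class:`.Action`):

    >>> agent = Agent(index=0, view_size=7)
    >>> agent.action_space
    Discrete(7)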
25 | 26 | :Observation Space: 27 | 28 | Observations are dictionaries with the following entries: 29 | 30 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 31 | Encoding of the agent's view of the environment 32 | * direction : int 33 | Agent's direction (0: right, 1: down, 2: left, 3: up) 34 | * mission : Mission 35 | Task string corresponding to the current environment configuration 36 | 37 | :Action Space: 38 | 39 | Actions are discrete integers, as enumerated in :class:`.Action`. 40 | 41 | Attributes 42 | ---------- 43 | index : int 44 | Index of the agent in the environment 45 | state : AgentState 46 | State of the agent 47 | mission : Mission 48 | Current mission string for the agent 49 | action_space : gym.spaces.Discrete 50 | Action space for the agent 51 | observation_space : gym.spaces.Dict 52 | Observation space for the agent 53 | """ 54 | 55 | def __init__( 56 | self, 57 | index: int, 58 | mission_space: MissionSpace = MissionSpace.from_string('maximize reward'), 59 | view_size: int = 7, 60 | see_through_walls: bool = False): 61 | """ 62 | Parameters 63 | ---------- 64 | index : int 65 | Index of the agent in the environment 66 | mission_space : MissionSpace 67 | The mission space for the agent 68 | view_size : int 69 | The size of the agent's view (must be odd) 70 | see_through_walls : bool 71 | Whether the agent can see through walls 72 | """ 73 | self.index: int = index 74 | self.state: AgentState = AgentState() 75 | self.mission: Mission = None 76 | 77 | # Number of cells (width and height) in the agent view 78 | assert view_size % 2 == 1 79 | assert view_size >= 3 80 | self.view_size = view_size 81 | self.see_through_walls = see_through_walls 82 | 83 | # Observations are dictionaries containing an 84 | # encoding of the grid and a textual 'mission' string 85 | self.observation_space = spaces.Dict({ 86 | 'image': spaces.Box( 87 | low=0, 88 | high=255, 89 | shape=(view_size, view_size, WorldObj.dim), 90 | dtype=int, 91 | ), 92 | 'direction': spaces.Discrete(len(Direction)), 93 | 'mission': mission_space, 94 | }) 95 | 96 | # Actions are discrete integer values 97 | self.action_space = spaces.Discrete(len(Action)) 98 | 99 | # AgentState Properties 100 | color = PropertyAlias( 101 | 'state', 'color', doc='Alias for :attr:`AgentState.color`.') 102 | dir = PropertyAlias( 103 | 'state', 'dir', doc='Alias for :attr:`AgentState.dir`.') 104 | pos = PropertyAlias( 105 | 'state', 'pos', doc='Alias for :attr:`AgentState.pos`.') 106 | terminated = PropertyAlias( 107 | 'state', 'terminated', doc='Alias for :attr:`AgentState.terminated`.') 108 | carrying = PropertyAlias( 109 | 'state', 'carrying', doc='Alias for :attr:`AgentState.carrying`.') 110 | 111 | @property 112 | def front_pos(self) -> tuple[int, int]: 113 | """ 114 | Get the position of the cell that is directly in front of the agent. 115 | """ 116 | agent_dir = self.state._view[AgentState.DIR] 117 | agent_pos = self.state._view[AgentState.POS] 118 | return front_pos(*agent_pos, agent_dir) 119 | 120 | def reset(self, mission: Mission = Mission('maximize reward')): 121 | """ 122 | Reset the agent to an initial state. 
123 | 124 | Parameters 125 | ---------- 126 | mission : Mission 127 | Mission string to use for the new episode 128 | """ 129 | self.mission = mission 130 | self.state.pos = (-1, -1) 131 | self.state.dir = -1 132 | self.state.terminated = False 133 | self.state.carrying = None 134 | 135 | def encode(self) -> tuple[int, int, int]: 136 | """ 137 | Encode a description of this agent as a 3-tuple of integers. 138 | 139 | Returns 140 | ------- 141 | type_idx : int 142 | The index of the agent type 143 | color_idx : int 144 | The index of the agent color 145 | agent_dir : int 146 | The direction of the agent (0: right, 1: down, 2: left, 3: up) 147 | """ 148 | return (Type.agent.to_index(), self.state.color.to_index(), self.state.dir) 149 | 150 | def render(self, img: ndarray[np.uint8]): 151 | """ 152 | Draw the agent. 153 | 154 | Parameters 155 | ---------- 156 | img : ndarray[int] of shape (width, height, 3) 157 | RGB image array to render agent on 158 | """ 159 | tri_fn = point_in_triangle( 160 | (0.12, 0.19), 161 | (0.87, 0.50), 162 | (0.12, 0.81), 163 | ) 164 | 165 | # Rotate the agent based on its direction 166 | tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * np.pi * self.state.dir) 167 | fill_coords(img, tri_fn, self.state.color.rgb()) 168 | 169 | 170 | class AgentState(np.ndarray): 171 | """ 172 | State for an :class:`.Agent` object. 173 | 174 | ``AgentState`` objects also support vectorized operations, 175 | in which case the ``AgentState`` object represents the states of multiple agents. 176 | 177 | Attributes 178 | ---------- 179 | color : Color or ndarray[str] 180 | Agent color 181 | dir : Direction or ndarray[int] 182 | Agent direction (0: right, 1: down, 2: left, 3: up) 183 | pos : tuple[int, int] or ndarray[int] 184 | Agent (x, y) position 185 | terminated : bool or ndarray[bool] 186 | Whether the agent has terminated 187 | carrying : WorldObj or None or ndarray[object] 188 | Object the agent is carrying 189 | 190 | Examples 191 | -------- 192 | Create a vectorized agent state for 3 agents: 193 | 194 | >>> agent_state = AgentState(3) 195 | >>> agent_state 196 | AgentState(3) 197 | 198 | Access and set state attributes for one agent at a time: 199 | 200 | >>> a = agent_state[0] 201 | >>> a 202 | AgentState() 203 | >>> a.color 204 | 'red' 205 | >>> a.color = 'yellow' 206 | 207 | The underlying vectorized state is automatically updated as well: 208 | 209 | >>> agent_state.color 210 | array(['yellow', 'green', 'blue']) 211 | 212 | Access and set state attributes all at once: 213 | 214 | >>> agent_state.dir 215 | array([-1, -1, -1]) 216 | >>> agent_state.dir = np.random.randint(4, size=(len(agent_state))) 217 | >>> agent_state.dir 218 | array([2, 3, 0]) 219 | >>> a.dir 220 | 2 221 | """ 222 | # State vector indices 223 | TYPE = 0 224 | COLOR = 1 225 | DIR = 2 226 | ENCODING = slice(0, 3) 227 | POS = slice(3, 5) 228 | TERMINATED = 5 229 | CARRYING = slice(6, 6 + WorldObj.dim) 230 | 231 | # State vector dimension 232 | dim = 6 + WorldObj.dim 233 | 234 | def __new__(cls, *dims: int): 235 | """ 236 | Parameters 237 | ---------- 238 | dims : int, optional 239 | Shape of vectorized agent state 240 | """ 241 | obj = np.zeros(dims + (cls.dim,), dtype=int).view(cls) 242 | 243 | # Set default values 244 | obj[..., AgentState.TYPE] = Type.agent 245 | obj[..., AgentState.COLOR].flat = Color.cycle(np.prod(dims)) 246 | obj[..., AgentState.DIR] = -1 247 | obj[..., AgentState.POS] = (-1, -1) 248 | 249 | # Other attributes 250 | obj._carried_obj = np.empty(dims, dtype=object) # object references 
251 | obj._terminated = np.zeros(dims, dtype=bool) # cache for faster access 252 | obj._view = obj.view(np.ndarray) # view of the underlying array (faster indexing) 253 | 254 | return obj 255 | 256 | def __repr__(self): 257 | shape = str(self.shape[:-1]).replace(",)", ")") 258 | return f'{self.__class__.__name__}{shape}' 259 | 260 | def __getitem__(self, idx): 261 | out = super().__getitem__(idx) 262 | if out.shape and out.shape[-1] == self.dim: 263 | out._view = self._view[idx, ...] 264 | out._carried_obj = self._carried_obj[idx, ...] # set carried object reference 265 | out._terminated = self._terminated[idx, ...] # set terminated cache 266 | 267 | return out 268 | 269 | @property 270 | def color(self) -> Color | ndarray[np.str]: 271 | """ 272 | Return the agent color. 273 | """ 274 | return Color.from_index(self._view[..., AgentState.COLOR]) 275 | 276 | @color.setter 277 | def color(self, value: str | ArrayLike[str]): 278 | """ 279 | Set the agent color. 280 | """ 281 | self[..., AgentState.COLOR] = np.vectorize(lambda c: Color(c).to_index())(value) 282 | 283 | @property 284 | def dir(self) -> Direction | ndarray[np.int]: 285 | """ 286 | Return the agent direction. 287 | """ 288 | out = self._view[..., AgentState.DIR] 289 | return Direction(out.item()) if out.ndim == 0 else out 290 | 291 | @dir.setter 292 | def dir(self, value: int | ArrayLike[int]): 293 | """ 294 | Set the agent direction. 295 | """ 296 | self[..., AgentState.DIR] = value 297 | 298 | @property 299 | def pos(self) -> tuple[int, int] | ndarray[np.int]: 300 | """ 301 | Return the agent's (x, y) position. 302 | """ 303 | out = self._view[..., AgentState.POS] 304 | return tuple(out) if out.ndim == 1 else out 305 | 306 | @pos.setter 307 | def pos(self, value: ArrayLike[int] | ArrayLike[ArrayLike[int]]): 308 | """ 309 | Set the agent's (x, y) position. 310 | """ 311 | self[..., AgentState.POS] = value 312 | 313 | @property 314 | def terminated(self) -> bool | ndarray[np.bool]: 315 | """ 316 | Return whether the agent has terminated. 317 | """ 318 | out = self._terminated 319 | return out.item() if out.ndim == 0 else out 320 | 321 | @terminated.setter 322 | def terminated(self, value: bool | ArrayLike[bool]): 323 | """ 324 | Set whether the agent has terminated. 325 | """ 326 | self[..., AgentState.TERMINATED] = value 327 | self._terminated[...] = value 328 | 329 | @property 330 | def carrying(self) -> WorldObj | None | ndarray[np.object]: 331 | """ 332 | Return the object the agent is carrying. 333 | """ 334 | out = self._carried_obj 335 | return out.item() if out.ndim == 0 else out 336 | 337 | @carrying.setter 338 | def carrying(self, obj: WorldObj | None | ArrayLike[object]): 339 | """ 340 | Set the object the agent is carrying. 341 | """ 342 | self[..., AgentState.CARRYING] = WorldObj.empty() if obj is None else obj 343 | if isinstance(obj, (WorldObj, type(None))): 344 | self._carried_obj[...].fill(obj) 345 | else: 346 | self._carried_obj[...] 
= obj 347 | -------------------------------------------------------------------------------- /multigrid/utils/obs.py: -------------------------------------------------------------------------------- 1 | import numba as nb 2 | import numpy as np 3 | 4 | from ..core.agent import AgentState 5 | from ..core.constants import Color, Direction, State, Type 6 | from ..core.world_object import Wall, WorldObj 7 | 8 | from numpy.typing import NDArray as ndarray 9 | 10 | 11 | 12 | ### Constants 13 | 14 | WALL_ENCODING = Wall().encode() 15 | UNSEEN_ENCODING = WorldObj(Type.unseen, Color.from_index(0)).encode() 16 | ENCODE_DIM = WorldObj.dim 17 | 18 | GRID_ENCODING_IDX = slice(None) 19 | 20 | AGENT_DIR_IDX = AgentState.DIR 21 | AGENT_POS_IDX = AgentState.POS 22 | AGENT_TERMINATED_IDX = AgentState.TERMINATED 23 | AGENT_CARRYING_IDX = AgentState.CARRYING 24 | AGENT_ENCODING_IDX = AgentState.ENCODING 25 | 26 | TYPE = WorldObj.TYPE 27 | STATE = WorldObj.STATE 28 | 29 | WALL = int(Type.wall) 30 | DOOR = int(Type.door) 31 | 32 | OPEN = int(State.open) 33 | CLOSED = int(State.closed) 34 | LOCKED = int(State.locked) 35 | 36 | RIGHT = int(Direction.right) 37 | LEFT = int(Direction.left) 38 | UP = int(Direction.up) 39 | DOWN = int(Direction.down) 40 | 41 | 42 | 43 | 44 | ### Observation Functions 45 | 46 | @nb.njit(cache=True) 47 | def see_behind(world_obj: ndarray[np.int_]) -> bool: 48 | """ 49 | Can an agent see behind this object? 50 | 51 | Parameters 52 | ---------- 53 | world_obj : ndarray[int] of shape (encode_dim,) 54 | World object encoding 55 | """ 56 | if world_obj is None: 57 | return True 58 | if world_obj[TYPE] == WALL: 59 | return False 60 | elif world_obj[TYPE] == DOOR and world_obj[STATE] != OPEN: 61 | return False 62 | 63 | return True 64 | 65 | @nb.njit(cache=True) 66 | def gen_obs_grid_encoding( 67 | grid_state: ndarray[np.int_], 68 | agent_state: ndarray[np.int_], 69 | agent_view_size: int, 70 | see_through_walls: bool) -> ndarray[np.int_]: 71 | """ 72 | Generate encoding for the sub-grid observed by an agent (including visibility mask). 73 | 74 | Parameters 75 | ---------- 76 | grid_state : ndarray[int] of shape (width, height, grid_state_dim) 77 | Array representation for each grid object 78 | agent_state : ndarray[int] of shape (num_agents, agent_state_dim) 79 | Array representation for each agent 80 | agent_view_size : int 81 | Width and height of observation sub-grids 82 | see_through_walls : bool 83 | Whether the agent can see through walls 84 | 85 | Returns 86 | ------- 87 | img : ndarray[int] of shape (num_agents, view_size, view_size, encode_dim) 88 | Encoding of observed sub-grid for each agent 89 | """ 90 | obs_grid = gen_obs_grid(grid_state, agent_state, agent_view_size) 91 | 92 | # Generate and apply visibility masks 93 | vis_mask = get_vis_mask(obs_grid) 94 | num_agents = len(agent_state) 95 | for agent in range(num_agents): 96 | if not see_through_walls: 97 | for i in range(agent_view_size): 98 | for j in range(agent_view_size): 99 | if not vis_mask[agent, i, j]: 100 | obs_grid[agent, i, j] = UNSEEN_ENCODING 101 | 102 | return obs_grid 103 | 104 | @nb.njit(cache=True) 105 | def gen_obs_grid_vis_mask( 106 | grid_state: ndarray[np.int_], 107 | agent_state: ndarray[np.int_], 108 | agent_view_size: int) -> ndarray[np.int_]: 109 | """ 110 | Generate visibility mask for the sub-grid observed by an agent. 
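As a quick sketch of how `see_behind` reads these encodings (assuming `Door` and `Color` are imported from the core package, which obs.py does not do by default): a wall or a closed/locked door blocks sight, while an open door does not.

    >>> see_behind(np.array(Wall().encode()))
    False
    >>> see_behind(np.array(Door(Color.blue, is_open=True).encode()))
    True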
111 | 
112 |     Parameters
113 |     ----------
114 |     grid_state : ndarray[int] of shape (width, height, grid_state_dim)
115 |         Array representation for each grid object
116 |     agent_state : ndarray[int] of shape (num_agents, agent_state_dim)
117 |         Array representation for each agent
118 |     agent_view_size : int
119 |         Width and height of observation sub-grids
120 | 
121 |     Returns
122 |     -------
123 |     mask : ndarray[bool] of shape (num_agents, view_size, view_size)
124 |         Boolean visibility mask for each agent
125 |     """
126 |     obs_grid = gen_obs_grid(grid_state, agent_state, agent_view_size)
127 |     return get_vis_mask(obs_grid)
128 | 
129 | 
130 | @nb.njit(cache=True)
131 | def gen_obs_grid(
132 |     grid_state: ndarray[np.int_],
133 |     agent_state: ndarray[np.int_],
134 |     agent_view_size: int) -> ndarray[np.int_]:
135 |     """
136 |     Generate the sub-grid observed by each agent (WITHOUT visibility mask).
137 | 
138 |     Parameters
139 |     ----------
140 |     grid_state : ndarray[int] of shape (width, height, grid_state_dim)
141 |         Array representation for each grid object
142 |     agent_state : ndarray[int] of shape (num_agents, agent_state_dim)
143 |         Array representation for each agent
144 |     agent_view_size : int
145 |         Width and height of observation sub-grids
146 | 
147 |     Returns
148 |     -------
149 |     obs_grid : ndarray[int] of shape (num_agents, view_size, view_size, encode_dim)
150 |         Observed sub-grid for each agent
151 |     """
152 |     num_agents = len(agent_state)
153 |     obs_width, obs_height = agent_view_size, agent_view_size
154 | 
155 |     # Process agent states
156 |     agent_grid_encoding = agent_state[..., AGENT_ENCODING_IDX]
157 |     agent_dir = agent_state[..., AGENT_DIR_IDX]
158 |     agent_pos = agent_state[..., AGENT_POS_IDX]
159 |     agent_terminated = agent_state[..., AGENT_TERMINATED_IDX]
160 |     agent_carrying_encoding = agent_state[..., AGENT_CARRYING_IDX]
161 | 
162 |     # Get grid encoding
163 |     if num_agents > 1:
164 |         grid_encoding = np.empty((*grid_state.shape[:-1], ENCODE_DIM), dtype=np.int_)
165 |         grid_encoding[...] = grid_state[..., GRID_ENCODING_IDX]
166 | 
167 |         # Insert agent grid encodings
168 |         for agent in range(num_agents):
169 |             if not agent_terminated[agent]:
170 |                 i, j = agent_pos[agent]
171 |                 grid_encoding[i, j, GRID_ENCODING_IDX] = agent_grid_encoding[agent]
172 |     else:
173 |         grid_encoding = grid_state[..., GRID_ENCODING_IDX]
174 | 
175 |     # Get top left corner of observation grids
176 |     top_left = get_view_exts(agent_dir, agent_pos, agent_view_size)
177 |     topX, topY = top_left[:, 0], top_left[:, 1]
178 | 
179 |     # Populate observation grids
180 |     num_left_rotations = (agent_dir + 1) % 4
181 |     obs_grid = np.empty((num_agents, obs_width, obs_height, ENCODE_DIM), dtype=np.int_)
182 |     for agent in range(num_agents):
183 |         for i in range(0, obs_width):
184 |             for j in range(0, obs_height):
185 |                 # Absolute coordinates in world grid
186 |                 x, y = topX[agent] + i, topY[agent] + j
187 | 
188 |                 # Rotated relative coordinates for observation grid
189 |                 if num_left_rotations[agent] == 0:
190 |                     i_rot, j_rot = i, j
191 |                 elif num_left_rotations[agent] == 1:
192 |                     i_rot, j_rot = j, obs_width - i - 1
193 |                 elif num_left_rotations[agent] == 2:
194 |                     i_rot, j_rot = obs_width - i - 1, obs_height - j - 1
195 |                 elif num_left_rotations[agent] == 3:
196 |                     i_rot, j_rot = obs_height - j - 1, i
197 | 
198 |                 # Set observation grid
199 |                 if 0 <= x < grid_encoding.shape[0] and 0 <= y < grid_encoding.shape[1]:
200 |                     obs_grid[agent, i_rot, j_rot] = grid_encoding[x, y]
201 |                 else:
202 |                     obs_grid[agent, i_rot, j_rot] = WALL_ENCODING
203 | 
204 |     # Make it so each agent sees what it's carrying
205 |     # We do this by placing the carried object at the agent position
206 |     # in each agent's partially observable view
207 |     obs_grid[:, obs_width // 2, obs_height - 1] = agent_carrying_encoding
208 | 
209 |     return obs_grid
210 | 
211 | @nb.njit(cache=True)
212 | def get_see_behind_mask(grid_array: ndarray[np.int_]) -> ndarray[np.bool_]:
213 |     """
214 |     Return boolean mask indicating which grid locations can be seen through.
215 | 
216 |     Parameters
217 |     ----------
218 |     grid_array : ndarray[int] of shape (num_agents, width, height, dim)
219 |         Grid object array for each agent
220 | 
221 |     Returns
222 |     -------
223 |     see_behind_mask : ndarray[bool] of shape (num_agents, width, height)
224 |         Boolean visibility mask
225 |     """
226 |     num_agents, width, height = grid_array.shape[:3]
227 |     see_behind_mask = np.zeros((num_agents, width, height), dtype=np.bool_)
228 |     for agent in range(num_agents):
229 |         for i in range(width):
230 |             for j in range(height):
231 |                 see_behind_mask[agent, i, j] = see_behind(grid_array[agent, i, j])
232 | 
233 |     return see_behind_mask
234 | 
235 | @nb.njit(cache=True)
236 | def get_vis_mask(obs_grid: ndarray[np.int_]) -> ndarray[np.bool_]:
237 |     """
238 |     Generate a boolean mask indicating which grid locations are visible to each agent.
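To make the rotation bookkeeping in `gen_obs_grid` concrete (a worked example, not additional API): an agent facing right has num_left_rotations = (0 + 1) % 4 = 1, so a world-aligned offset (i, j) = (0, 0) lands at (i_rot, j_rot) = (j, obs_width - i - 1) = (0, obs_width - 1). Combined with `get_view_exts`, this is what keeps the agent itself at the bottom-center cell (obs_width // 2, obs_height - 1) of its egocentric view, the same cell later overwritten with its carried object.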
239 | 240 | Parameters 241 | ---------- 242 | obs_grid : ndarray[int] of shape (num_agents, width, height, dim) 243 | Grid object array for each agent observation 244 | 245 | Returns 246 | ------- 247 | vis_mask : ndarray[bool] of shape (num_agents, width, height) 248 | Boolean visibility mask for each agent 249 | """ 250 | num_agents, width, height = obs_grid.shape[:3] 251 | see_behind_mask = get_see_behind_mask(obs_grid) 252 | vis_mask = np.zeros((num_agents, width, height), dtype=np.bool_) 253 | vis_mask[:, width // 2, height - 1] = True # agent relative position 254 | 255 | for agent in range(num_agents): 256 | for j in range(height - 1, -1, -1): 257 | # Forward pass 258 | for i in range(0, width - 1): 259 | if vis_mask[agent, i, j] and see_behind_mask[agent, i, j]: 260 | vis_mask[agent, i + 1, j] = True 261 | if j > 0: 262 | vis_mask[agent, i + 1, j - 1] = True 263 | vis_mask[agent, i, j - 1] = True 264 | 265 | # Backward pass 266 | for i in range(width - 1, 0, -1): 267 | if vis_mask[agent, i, j] and see_behind_mask[agent, i, j]: 268 | vis_mask[agent, i - 1, j] = True 269 | if j > 0: 270 | vis_mask[agent, i - 1, j - 1] = True 271 | vis_mask[agent, i, j - 1] = True 272 | 273 | return vis_mask 274 | 275 | @nb.njit(cache=True) 276 | def get_view_exts( 277 | agent_dir: ndarray[np.int_], 278 | agent_pos: ndarray[np.int_], 279 | agent_view_size: int) -> ndarray[np.int_]: 280 | """ 281 | Get the extents of the square set of grid cells visible to each agent. 282 | 283 | Parameters 284 | ---------- 285 | agent_dir : ndarray[int] of shape (num_agents,) 286 | Direction of each agent 287 | agent_pos : ndarray[int] of shape (num_agents, 2) 288 | The (x, y) position of each agent 289 | agent_view_size : int 290 | Width and height of agent view 291 | 292 | Returns 293 | ------- 294 | top_left : ndarray[int] of shape (num_agents, 2) 295 | The (x, y) coordinates of the top-left corner of each agent's observable view 296 | """ 297 | agent_x, agent_y = agent_pos[:, 0], agent_pos[:, 1] 298 | top_left = np.zeros((agent_dir.shape[0], 2), dtype=np.int_) 299 | 300 | # Facing right 301 | top_left[agent_dir == RIGHT, 0] = agent_x[agent_dir == RIGHT] 302 | top_left[agent_dir == RIGHT, 1] = agent_y[agent_dir == RIGHT] - agent_view_size // 2 303 | 304 | # Facing down 305 | top_left[agent_dir == DOWN, 0] = agent_x[agent_dir == DOWN] - agent_view_size // 2 306 | top_left[agent_dir == DOWN, 1] = agent_y[agent_dir == DOWN] 307 | 308 | # Facing left 309 | top_left[agent_dir == LEFT, 0] = agent_x[agent_dir == LEFT] - agent_view_size + 1 310 | top_left[agent_dir == LEFT, 1] = agent_y[agent_dir == LEFT] - agent_view_size // 2 311 | 312 | # Facing up 313 | top_left[agent_dir == UP, 0] = agent_x[agent_dir == UP] - agent_view_size // 2 314 | top_left[agent_dir == UP, 1] = agent_y[agent_dir == UP] - agent_view_size + 1 315 | 316 | return top_left 317 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Maxime Chevalier-Boisvert 190 | Copyright 2022 Farama Foundation 191 | Copyright 2023 Ini Oguntola 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. -------------------------------------------------------------------------------- /multigrid/core/roomgrid.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | 5 | from collections import deque 6 | from typing import Callable, Iterable, TypeVar 7 | 8 | from .agent import Agent 9 | from .constants import Color, Direction, Type 10 | from .grid import Grid 11 | from .world_object import Door, WorldObj 12 | from ..base import MultiGridEnv 13 | 14 | 15 | 16 | T = TypeVar('T') 17 | 18 | 19 | 20 | def bfs(start_node: T, neighbor_fn: Callable[[T], Iterable[T]]) -> set[T]: 21 | """ 22 | Run a breadth-first search from a starting node. 23 | 24 | Parameters 25 | ---------- 26 | start_node : T 27 | Start node 28 | neighbor_fn : Callable(T) -> Iterable[T] 29 | Function that returns the neighbors of a node 30 | 31 | Returns 32 | ------- 33 | visited : set[T] 34 | Set of nodes reachable from the start node 35 | """ 36 | visited, queue = set(), deque([start_node]) 37 | while queue: 38 | node = queue.popleft() 39 | if node not in visited: 40 | visited.add(node) 41 | queue.extend(neighbor_fn(node)) 42 | 43 | return visited 44 | 45 | def reject_next_to(env: MultiGridEnv, pos: tuple[int, int]): 46 | """ 47 | Function to filter out object positions that are right next to 48 | the agent's starting point 49 | """ 50 | return any(np.linalg.norm(pos - env.agent_states.pos, axis=-1) <= 1) 51 | 52 | 53 | class Room: 54 | """ 55 | Room as an area inside a grid. 56 | """ 57 | 58 | def __init__(self, top: tuple[int, int], size: tuple[int, int]): 59 | """ 60 | Parameters 61 | ---------- 62 | top : tuple[int, int] 63 | Top-left position of the room 64 | size : tuple[int, int] 65 | Room size as (width, height) 66 | """ 67 | self.top, self.size = top, size 68 | Point = tuple[int, int] # typing alias 69 | 70 | # Mapping of door objects and door positions 71 | self.doors: dict[Direction, Door | None] = {d: None for d in Direction} 72 | self.door_pos: dict[Direction, Point | None] = {d: None for d in Direction} 73 | 74 | # Mapping of rooms adjacent to this one 75 | self.neighbors: dict[Direction, Room | None] = {d: None for d in Direction} 76 | 77 | # List of objects contained in this room 78 | self.objs = [] 79 | 80 | @property 81 | def locked(self) -> bool: 82 | """ 83 | Return whether this room is behind a locked door. 84 | """ 85 | return any(door and door.is_locked for door in self.doors.values()) 86 | 87 | def set_door_pos( 88 | self, 89 | dir: Direction, 90 | random: np.random.Generator | None = None) -> tuple[int, int]: 91 | """ 92 | Set door position in the given direction. 
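`bfs` above is deliberately generic; a minimal sketch on a toy adjacency map (no multigrid types involved):

    >>> graph = {1: [2], 2: [3], 3: [], 4: []}
    >>> sorted(bfs(1, graph.__getitem__))
    [1, 2, 3]

Node 4 is never reached, so it is not visited.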
93 | 94 | Parameters 95 | ---------- 96 | dir : Direction 97 | Direction of wall to place door 98 | random : np.random.Generator, optional 99 | Random number generator (if provided, door position will be random) 100 | """ 101 | left, top = self.top 102 | right, bottom = self.top[0] + self.size[0] - 1, self.top[1] + self.size[1] - 1, 103 | 104 | if dir == Direction.right: 105 | if random: 106 | self.door_pos[dir] = (right, random.integers(top + 1, bottom)) 107 | else: 108 | self.door_pos[dir] = (right, (top + bottom) // 2) 109 | 110 | elif dir == Direction.down: 111 | if random: 112 | self.door_pos[dir] = (random.integers(left + 1, right), bottom) 113 | else: 114 | self.door_pos[dir] = ((left + right) // 2, bottom) 115 | 116 | elif dir == Direction.left: 117 | if random: 118 | self.door_pos[dir] = (left, random.integers(top + 1, bottom)) 119 | else: 120 | self.door_pos[dir] = (left, (top + bottom) // 2) 121 | 122 | elif dir == Direction.up: 123 | if random: 124 | self.door_pos[dir] = (random.integers(left + 1, right), top) 125 | else: 126 | self.door_pos[dir] = ((left + right) // 2, top) 127 | 128 | return self.door_pos[dir] 129 | 130 | def pos_inside(self, x: int, y: int) -> bool: 131 | """ 132 | Check if a position is within the bounds of this room. 133 | """ 134 | left_x, top_y = self.top 135 | width, height = self.size 136 | return left_x <= x < left_x + width and top_y <= y < top_y + height 137 | 138 | 139 | class RoomGrid(MultiGridEnv): 140 | """ 141 | Environment with multiple rooms and random objects. 142 | This is meant to serve as a base class for other environments. 143 | """ 144 | 145 | def __init__( 146 | self, 147 | room_size: int = 7, 148 | num_rows: int = 3, 149 | num_cols: int = 3, 150 | **kwargs): 151 | """ 152 | Parameters 153 | ---------- 154 | room_size : int, default=7 155 | Width and height for each of the rooms 156 | num_rows : int, default=3 157 | Number of rows of rooms 158 | num_cols : int, default=3 159 | Number of columns of rooms 160 | **kwargs 161 | See :attr:`multigrid.base.MultiGridEnv.__init__` 162 | """ 163 | assert room_size >= 3 164 | assert num_rows > 0 165 | assert num_cols > 0 166 | self.room_size = room_size 167 | self.num_rows = num_rows 168 | self.num_cols = num_cols 169 | height = (room_size - 1) * num_rows + 1 170 | width = (room_size - 1) * num_cols + 1 171 | super().__init__(width=width, height=height, **kwargs) 172 | 173 | def get_room(self, col: int, row: int) -> Room: 174 | """ 175 | Get the room at the given column and row. 176 | 177 | Parameters 178 | ---------- 179 | col : int 180 | Column of the room 181 | row : int 182 | Row of the room 183 | """ 184 | assert 0 <= col < self.num_cols 185 | assert 0 <= row < self.num_rows 186 | return self.room_grid[row][col] 187 | 188 | def room_from_pos(self, x: int, y: int) -> Room: 189 | """ 190 | Get the room a given position maps to. 
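Because adjacent rooms share a wall, the grid dimensions computed in `RoomGrid.__init__` are (room_size - 1) per room plus one closing wall; with the defaults (room_size=7, num_rows=num_cols=3):

    >>> (7 - 1) * 3 + 1
    19

so the default RoomGrid is 19 x 19 cells.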
191 | 192 | Parameters 193 | ---------- 194 | x : int 195 | Grid x-coordinate 196 | y : int 197 | Grid y-coordinate 198 | """ 199 | col = x // (self.room_size - 1) 200 | row = y // (self.room_size - 1) 201 | return self.get_room(col, row) 202 | 203 | def _gen_grid(self, width, height): 204 | # Create the grid 205 | self.grid = Grid(width, height) 206 | self.room_grid = [[None] * self.num_cols for _ in range(self.num_rows)] 207 | 208 | # Create rooms 209 | for row in range(self.num_rows): 210 | for col in range(self.num_cols): 211 | room = Room( 212 | (col * (self.room_size - 1), row * (self.room_size - 1)), 213 | (self.room_size, self.room_size), 214 | ) 215 | self.room_grid[row][col] = room 216 | self.grid.wall_rect(*room.top, *room.size) # generate walls 217 | 218 | # Create connections between rooms 219 | for row in range(self.num_rows): 220 | for col in range(self.num_cols): 221 | room = self.room_grid[row][col] 222 | if col < self.num_cols - 1: 223 | room.neighbors[Direction.right] = self.room_grid[row][col + 1] 224 | if row < self.num_rows - 1: 225 | room.neighbors[Direction.down] = self.room_grid[row + 1][col] 226 | if col > 0: 227 | room.neighbors[Direction.left] = self.room_grid[row][col - 1] 228 | if row > 0: 229 | room.neighbors[Direction.up] = self.room_grid[row - 1][col] 230 | 231 | # Agents start in the middle, facing right 232 | self.agent_states.dir = Direction.right 233 | self.agent_states.pos = ( 234 | (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2), 235 | (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2), 236 | ) 237 | 238 | def place_in_room( 239 | self, col: int, row: int, obj: WorldObj) -> tuple[WorldObj, tuple[int, int]]: 240 | """ 241 | Add an existing object to the given room. 242 | 243 | Parameters 244 | ---------- 245 | col : int 246 | Room column 247 | row : int 248 | Room row 249 | obj : WorldObj 250 | Object to add 251 | """ 252 | room = self.get_room(col, row) 253 | pos = self.place_obj( 254 | obj, room.top, room.size, reject_fn=reject_next_to, max_tries=1000) 255 | room.objs.append(obj) 256 | return obj, pos 257 | 258 | def add_object( 259 | self, 260 | col: int, 261 | row: int, 262 | kind: Type | None = None, 263 | color: Color | None = None) -> tuple[WorldObj, tuple[int, int]]: 264 | """ 265 | Create a new object in the given room. 266 | 267 | Parameters 268 | ---------- 269 | col : int 270 | Room column 271 | row : int 272 | Room row 273 | kind : str, optional 274 | Type of object to add (random if not specified) 275 | color : str, optional 276 | Color of the object to add (random if not specified) 277 | """ 278 | kind = kind or self._rand_elem([Type.key, Type.ball, Type.box]) 279 | color = color or self._rand_color() 280 | obj = WorldObj(type=kind, color=color) 281 | return self.place_in_room(col, row, obj) 282 | 283 | def add_door( 284 | self, 285 | col: int, 286 | row: int, 287 | dir: Direction | None = None, 288 | color: Color | None = None, 289 | locked: bool | None = None, 290 | rand_pos: bool = True) -> tuple[Door, tuple[int, int]]: 291 | """ 292 | Add a door to a room, connecting it to a neighbor. 
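The floor division in `room_from_pos` assigns each shared wall to the room on its right/bottom side; e.g. with room_size=7, x = 5 is still column 0 while x = 6 (the shared wall) maps to column 1:

    >>> 5 // (7 - 1), 6 // (7 - 1)
    (0, 1)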
293 | 
294 |         Parameters
295 |         ----------
296 |         col : int
297 |             Room column
298 |         row : int
299 |             Room row
300 |         dir : Direction, optional
301 |             Which wall to put the door on (random if not specified)
302 |         color : Color, optional
303 |             Color of the door (random if not specified)
304 |         locked : bool, optional
305 |             Whether the door is locked (random if not specified)
306 |         rand_pos : bool, default=True
307 |             Whether to place the door at a random position on the room wall
308 |         """
309 |         room = self.get_room(col, row)
310 | 
311 |         # Need to make sure that there is a neighbor along this wall
312 |         # and that there is not already a door
313 |         if dir is None:
314 |             while dir is None or room.neighbors[dir] is None or room.doors[dir] is not None:
315 |                 dir = self._rand_elem(Direction)
316 |         else:
317 |             assert room.neighbors[dir] is not None, "no neighbor in this direction"
318 |             assert room.doors[dir] is None, "door already exists"
319 | 
320 |         # Create the door
321 |         color = color if color is not None else self._rand_color()
322 |         locked = locked if locked is not None else self._rand_bool()
323 |         door = Door(color, is_locked=locked)
324 |         pos = room.set_door_pos(dir, random=self.np_random if rand_pos else None)
325 |         self.put_obj(door, *pos)
326 | 
327 |         # Connect the door to the neighboring room
328 |         room.doors[dir] = door
329 |         room.neighbors[dir].doors[(dir + 2) % 4] = door
330 | 
331 |         return door, pos
332 | 
333 |     def remove_wall(self, col: int, row: int, dir: Direction):
334 |         """
335 |         Remove a wall between two rooms.
336 | 
337 |         Parameters
338 |         ----------
339 |         col : int
340 |             Room column
341 |         row : int
342 |             Room row
343 |         dir : Direction
344 |             Direction of the wall to remove
345 |         """
346 |         room = self.get_room(col, row)
347 |         assert room.doors[dir] is None, "door exists on this wall"
348 |         assert room.neighbors[dir], "invalid wall"
349 | 
350 |         tx, ty = room.top
351 |         w, h = room.size
352 | 
353 |         # Remove the wall
354 |         if dir == Direction.right:
355 |             for i in range(1, h - 1):
356 |                 self.grid.set(tx + w - 1, ty + i, None)
357 |         elif dir == Direction.down:
358 |             for i in range(1, w - 1):
359 |                 self.grid.set(tx + i, ty + h - 1, None)
360 |         elif dir == Direction.left:
361 |             for i in range(1, h - 1):
362 |                 self.grid.set(tx, ty + i, None)
363 |         elif dir == Direction.up:
364 |             for i in range(1, w - 1):
365 |                 self.grid.set(tx + i, ty, None)
366 |         else:
367 |             assert False, "invalid wall index"
368 | 
369 |         # Mark the rooms as connected
370 |         room.doors[dir] = True
371 |         room.neighbors[dir].doors[(dir + 2) % 4] = True
372 | 
373 |     def place_agent(
374 |         self,
375 |         agent: Agent,
376 |         col: int | None = None,
377 |         row: int | None = None,
378 |         rand_dir: bool = True) -> tuple[int, int]:
379 |         """
380 |         Place an agent in a room.
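A minimal usage sketch for `add_door` (assuming `env` is some RoomGrid instance with at least two columns of rooms; `env` is illustrative, not part of this module):

    >>> door, pos = env.add_door(0, 0, dir=Direction.right, color=Color.red, locked=False)
    >>> env.get_room(1, 0).doors[Direction.left] is door
    True

The `(dir + 2) % 4` arithmetic registers the same Door on the opposite wall of the neighboring room (right <-> left, down <-> up), since Direction values are 0, 1, 2, 3 for right, down, left, up.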
381 | 382 | Parameters 383 | ---------- 384 | agent : Agent 385 | Agent to place 386 | col : int, optional 387 | Room column to place the agent in (random if not specified) 388 | row : int, optional 389 | Room row to place the agent in (random if not specified) 390 | rand_dir : bool, default=True 391 | Whether to select a random agent direction 392 | """ 393 | col = col if col is not None else self._rand_int(0, self.num_cols) 394 | row = row if row is not None else self._rand_int(0, self.num_rows) 395 | room = self.get_room(col, row) 396 | 397 | # Find a position that is not right in front of an object 398 | while True: 399 | super().place_agent(agent, room.top, room.size, rand_dir, max_tries=1000) 400 | front_cell = self.grid.get(*agent.front_pos) 401 | if front_cell is None or front_cell.type == Type.wall: 402 | break 403 | 404 | return agent.state.pos 405 | 406 | def connect_all( 407 | self, 408 | door_colors: list[Color] = list(Color), 409 | max_itrs: int = 5000) -> list[Door]: 410 | """ 411 | Make sure that all rooms are reachable by the agent from its 412 | starting position. 413 | 414 | Parameters 415 | ---------- 416 | door_colors : list[Color], default=list(Color) 417 | Color options for creating doors 418 | max_itrs : int, default=5000 419 | Maximum number of iterations to try to connect all rooms 420 | """ 421 | added_doors = [] 422 | neighbor_fn = lambda room: [ 423 | room.neighbors[dir] for dir in Direction if room.doors[dir] is not None] 424 | start_room = self.get_room(0, 0) 425 | 426 | for i in range(max_itrs): 427 | # If all rooms are reachable, stop 428 | reachable_rooms = bfs(start_room, neighbor_fn) 429 | if len(reachable_rooms) == self.num_rows * self.num_cols: 430 | return added_doors 431 | 432 | # Pick a random room and door position 433 | col = self._rand_int(0, self.num_cols) 434 | row = self._rand_int(0, self.num_rows) 435 | dir = self._rand_elem(Direction) 436 | room = self.get_room(col, row) 437 | 438 | # If there is already a door there, skip 439 | if not room.neighbors[dir] or room.doors[dir]: 440 | continue 441 | 442 | neighbor_room = room.neighbors[dir] 443 | assert neighbor_room is not None 444 | if room.locked or neighbor_room.locked: 445 | continue 446 | 447 | # Add a new door 448 | color = self._rand_elem(door_colors) 449 | door, _ = self.add_door(col, row, dir=dir, color=color, locked=False) 450 | added_doors.append(door) 451 | 452 | raise RecursionError('connect_all() failed') 453 | 454 | def add_distractors( 455 | self, 456 | col: int | None = None, 457 | row: int | None = None, 458 | num_distractors: int = 10, 459 | all_unique: bool = True) -> list[WorldObj]: 460 | """ 461 | Add random objects that can potentially distract / confuse the agent. 
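`connect_all` would typically be called near the end of a subclass's `_gen_grid`, once the rooms and any key objects are placed (a sketch; `self` stands for the env under construction):

    >>> added = self.connect_all(door_colors=[Color.red, Color.blue])

Each iteration samples a (room, wall) pair, skips walls that already have doors or that touch a locked room, and stops as soon as a BFS from room (0, 0) reaches all num_rows * num_cols rooms.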
462 | 
463 |         Parameters
464 |         ----------
465 |         col : int, optional
466 |             Room column to place the objects in (random if not specified)
467 |         row : int, optional
468 |             Room row to place the objects in (random if not specified)
469 |         num_distractors : int, default=10
470 |             Number of distractor objects to add
471 |         all_unique : bool, default=True
472 |             Whether all distractor objects should be unique with respect to (type, color)
473 |         """
474 |         # Collect keys for existing room objects
475 |         room_objs = (obj for row in self.room_grid for room in row for obj in room.objs)
476 |         room_obj_keys = {(obj.type, obj.color) for obj in room_objs}
477 | 
478 |         # Add distractors
479 |         distractors = []
480 |         while len(distractors) < num_distractors:
481 |             color = self._rand_color()
482 |             type = self._rand_elem([Type.key, Type.ball, Type.box])
483 | 
484 |             if all_unique and (type, color) in room_obj_keys:
485 |                 continue
486 | 
487 |             # Add the object to a random room if no room specified
488 |             col = col if col is not None else self._rand_int(0, self.num_cols)
489 |             row = row if row is not None else self._rand_int(0, self.num_rows)
490 |             distractor, _ = self.add_object(col, row, kind=type, color=color)
491 | 
492 |             room_obj_keys.add((type, color))  # room_obj_keys is a set
493 |             distractors.append(distractor)
494 | 
495 |         return distractors
496 | 
--------------------------------------------------------------------------------
/multigrid/core/world_object.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import functools
4 | import numpy as np
5 | 
6 | from numpy.typing import ArrayLike, NDArray as ndarray
7 | from typing import Any, TYPE_CHECKING
8 | 
9 | from .constants import Color, State, Type
10 | from ..utils.rendering import (
11 |     fill_coords,
12 |     point_in_circle,
13 |     point_in_line,
14 |     point_in_rect,
15 | )
16 | 
17 | if TYPE_CHECKING:
18 |     from .agent import Agent
19 |     from ..base import MultiGridEnv
20 | 
21 | 
22 | 
23 | class WorldObjMeta(type):
24 |     """
25 |     Metaclass for world objects.
26 | 
27 |     Each subclass is associated with a unique :class:`Type` enumeration value.
28 | 
29 |     By default, the type name is the class name (in lowercase), but this can be
30 |     overridden by setting the `type_name` attribute in the class definition.
31 |     Type names are dynamically added to the :class:`Type` enumeration
32 |     if not already present.
33 | 
34 |     Examples
35 |     --------
36 |     >>> class A(WorldObj): pass
37 |     >>> A().type
38 |     <Type.a: 'a'>
39 | 
40 |     >>> class B(WorldObj): type_name = 'goal'
41 |     >>> B().type
42 |     <Type.goal: 'goal'>
43 | 
44 |     :meta private:
45 |     """
46 | 
47 |     # Registry of object classes
48 |     _TYPE_IDX_TO_CLASS = {}
49 | 
50 |     def __new__(meta, name, bases, class_dict):
51 |         cls = super().__new__(meta, name, bases, class_dict)
52 | 
53 |         if name != 'WorldObj':
54 |             type_name = class_dict.get('type_name', name.lower())
55 | 
56 |             # Add the object class name to the `Type` enumeration if not already present
57 |             if type_name not in set(Type):
58 |                 Type.add_item(type_name, type_name)
59 | 
60 |             # Store the object class with its corresponding type index
61 |             meta._TYPE_IDX_TO_CLASS[Type(type_name).to_index()] = cls
62 | 
63 |         return cls
64 | 
65 | 
66 | class WorldObj(np.ndarray, metaclass=WorldObjMeta):
67 |     """
68 |     Base class for grid world objects.
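Because world objects are fixed-length integer vectors, they round-trip through their array form, and the metaclass registry restores the right subclass (a small sketch using types defined later in this module):

    >>> key = Key(Color.yellow)
    >>> obj = WorldObj.from_array(np.array(key.encode()))
    >>> type(obj) is Key and obj.type == Type.key
    True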
69 | 70 | Attributes 71 | ---------- 72 | type : Type 73 | The object type 74 | color : Color 75 | The object color 76 | state : State 77 | The object state 78 | contains : WorldObj or None 79 | The object contained by this object, if any 80 | init_pos : tuple[int, int] or None 81 | The initial position of the object 82 | cur_pos : tuple[int, int] or None 83 | The current position of the object 84 | """ 85 | # WorldObj vector indices 86 | TYPE = 0 87 | COLOR = 1 88 | STATE = 2 89 | 90 | # WorldObj vector dimension 91 | dim = len([TYPE, COLOR, STATE]) 92 | 93 | def __new__(cls, type: str | None = None, color: str = Color.from_index(0)): 94 | """ 95 | Parameters 96 | ---------- 97 | type : str or None 98 | Object type 99 | color : str 100 | Object color 101 | """ 102 | # If not provided, infer the object type from the class 103 | type_name = type or getattr(cls, 'type_name', cls.__name__.lower()) 104 | type_idx = Type(type_name).to_index() 105 | 106 | # Use the WorldObj subclass corresponding to the object type 107 | cls = WorldObjMeta._TYPE_IDX_TO_CLASS.get(type_idx, cls) 108 | 109 | # Create the object 110 | obj = np.zeros(cls.dim, dtype=int).view(cls) 111 | obj[WorldObj.TYPE] = type_idx 112 | obj[WorldObj.COLOR] = Color(color).to_index() 113 | obj.contains: WorldObj | None = None # object contained by this object 114 | obj.init_pos: tuple[int, int] | None = None # initial position of the object 115 | obj.cur_pos: tuple[int, int] | None = None # current position of the object 116 | 117 | return obj 118 | 119 | def __bool__(self) -> bool: 120 | return self.type != Type.empty 121 | 122 | def __repr__(self) -> str: 123 | return f"{self.__class__.__name__}(color={self.color})" 124 | 125 | def __str__(self) -> str: 126 | return self.__repr__() 127 | 128 | def __eq__(self, other: Any): 129 | return self is other 130 | 131 | @staticmethod 132 | @functools.cache 133 | def empty() -> 'WorldObj': 134 | """ 135 | Return an empty WorldObj instance. 136 | """ 137 | return WorldObj(type=Type.empty) 138 | 139 | @staticmethod 140 | def from_array(arr: ArrayLike[int]) -> 'WorldObj' | None: 141 | """ 142 | Convert an array to a WorldObj instance. 143 | 144 | Parameters 145 | ---------- 146 | arr : ArrayLike[int] 147 | Array encoding the object type, color, and state 148 | """ 149 | type_idx = arr[WorldObj.TYPE] 150 | 151 | if type_idx == Type.empty.to_index(): 152 | return None 153 | 154 | if type_idx in WorldObj._TYPE_IDX_TO_CLASS: 155 | cls = WorldObj._TYPE_IDX_TO_CLASS[type_idx] 156 | obj = cls.__new__(cls) 157 | obj[...] = arr 158 | return obj 159 | 160 | raise ValueError(f'Unknown object type: {arr[WorldObj.TYPE]}') 161 | 162 | @functools.cached_property 163 | def type(self) -> Type: 164 | """ 165 | Return the object type. 166 | """ 167 | return Type.from_index(self[WorldObj.TYPE]) 168 | 169 | @property 170 | def color(self) -> Color: 171 | """ 172 | Return the object color. 173 | """ 174 | return Color.from_index(self[WorldObj.COLOR]) 175 | 176 | @color.setter 177 | def color(self, value: str): 178 | """ 179 | Set the object color. 180 | """ 181 | self[WorldObj.COLOR] = Color(value).to_index() 182 | 183 | @property 184 | def state(self) -> str: 185 | """ 186 | Return the name of the object state. 187 | """ 188 | return State.from_index(self[WorldObj.STATE]) 189 | 190 | @state.setter 191 | def state(self, value: str): 192 | """ 193 | Set the name of the object state. 
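Empty grid cells are represented by a cached sentinel instead of None inside array code, and `__bool__` makes that sentinel falsy (a small sketch):

    >>> WorldObj.empty() is WorldObj.empty()
    True
    >>> bool(WorldObj.empty())
    False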
194 | """ 195 | self[WorldObj.STATE] = State(value).to_index() 196 | 197 | def can_overlap(self) -> bool: 198 | """ 199 | Can an agent overlap with this? 200 | """ 201 | return self.type == Type.empty 202 | 203 | def can_pickup(self) -> bool: 204 | """ 205 | Can an agent pick this up? 206 | """ 207 | return False 208 | 209 | def can_contain(self) -> bool: 210 | """ 211 | Can this contain another object? 212 | """ 213 | return False 214 | 215 | def toggle(self, env: MultiGridEnv, agent: Agent, pos: tuple[int, int]) -> bool: 216 | """ 217 | Toggle the state of this object or trigger an action this object performs. 218 | 219 | Parameters 220 | ---------- 221 | env : MultiGridEnv 222 | The environment this object is contained in 223 | agent : Agent 224 | The agent performing the toggle action 225 | pos : tuple[int, int] 226 | The (x, y) position of this object in the environment grid 227 | 228 | Returns 229 | ------- 230 | success : bool 231 | Whether the toggle action was successful 232 | """ 233 | return False 234 | 235 | def encode(self) -> tuple[int, int, int]: 236 | """ 237 | Encode a 3-tuple description of this object. 238 | 239 | Returns 240 | ------- 241 | type_idx : int 242 | The index of the object type 243 | color_idx : int 244 | The index of the object color 245 | state_idx : int 246 | The index of the object state 247 | """ 248 | return tuple(self) 249 | 250 | @staticmethod 251 | def decode(type_idx: int, color_idx: int, state_idx: int) -> 'WorldObj' | None: 252 | """ 253 | Create an object from a 3-tuple description. 254 | 255 | Parameters 256 | ---------- 257 | type_idx : int 258 | The index of the object type 259 | color_idx : int 260 | The index of the object color 261 | state_idx : int 262 | The index of the object state 263 | """ 264 | arr = np.array([type_idx, color_idx, state_idx]) 265 | return WorldObj.from_array(arr) 266 | 267 | def render(self, img: ndarray[np.uint8]): 268 | """ 269 | Draw the world object. 270 | 271 | Parameters 272 | ---------- 273 | img : ndarray[int] of shape (width, height, 3) 274 | RGB image array to render object on 275 | """ 276 | raise NotImplementedError 277 | 278 | 279 | class Goal(WorldObj): 280 | """ 281 | Goal object an agent may be searching for. 282 | """ 283 | 284 | def __new__(cls, color: str = Color.green): 285 | return super().__new__(cls, color=color) 286 | 287 | def can_overlap(self) -> bool: 288 | """ 289 | :meta private: 290 | """ 291 | return True 292 | 293 | def render(self, img): 294 | """ 295 | :meta private: 296 | """ 297 | fill_coords(img, point_in_rect(0, 1, 0, 1), self.color.rgb()) 298 | 299 | 300 | class Floor(WorldObj): 301 | """ 302 | Colored floor tile an agent can walk over. 303 | """ 304 | 305 | def __new__(cls, color: str = Color.blue): 306 | """ 307 | Parameters 308 | ---------- 309 | color : str 310 | Object color 311 | """ 312 | return super().__new__(cls, color=color) 313 | 314 | def can_overlap(self) -> bool: 315 | """ 316 | :meta private: 317 | """ 318 | return True 319 | 320 | def render(self, img): 321 | """ 322 | :meta private: 323 | """ 324 | # Give the floor a pale color 325 | color = self.color.rgb() / 2 326 | fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color) 327 | 328 | 329 | class Lava(WorldObj): 330 | """ 331 | Lava object an agent can fall onto. 
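The overlap rules compose from these overrides: the base class only permits overlap with empty cells, and walkable subclasses opt in (a small sketch; `Wall` is defined further below):

    >>> Goal().can_overlap(), Floor().can_overlap(), Wall().can_overlap()
    (True, True, False)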
332 | """ 333 | 334 | def __new__(cls): 335 | """ 336 | """ 337 | return super().__new__(cls, color=Color.red) 338 | 339 | def can_overlap(self) -> bool: 340 | """ 341 | :meta private: 342 | """ 343 | return True 344 | 345 | def render(self, img): 346 | """ 347 | :meta private: 348 | """ 349 | c = (255, 128, 0) 350 | 351 | # Background color 352 | fill_coords(img, point_in_rect(0, 1, 0, 1), c) 353 | 354 | # Little waves 355 | for i in range(3): 356 | ylo = 0.3 + 0.2 * i 357 | yhi = 0.4 + 0.2 * i 358 | fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0)) 359 | fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0)) 360 | fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0)) 361 | fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0)) 362 | 363 | 364 | class Wall(WorldObj): 365 | """ 366 | Wall object that agents cannot move through. 367 | """ 368 | 369 | @functools.cache # reuse instances, since object is effectively immutable 370 | def __new__(cls, color: str = Color.grey): 371 | """ 372 | Parameters 373 | ---------- 374 | color : str 375 | Object color 376 | """ 377 | return super().__new__(cls, color=color) 378 | 379 | def render(self, img): 380 | """ 381 | :meta private: 382 | """ 383 | fill_coords(img, point_in_rect(0, 1, 0, 1), self.color.rgb()) 384 | 385 | 386 | class Door(WorldObj): 387 | """ 388 | Door object that may be opened or closed. Locked doors require a key to open. 389 | 390 | Attributes 391 | ---------- 392 | is_open: bool 393 | Whether the door is open 394 | is_locked: bool 395 | Whether the door is locked 396 | """ 397 | 398 | def __new__( 399 | cls, color: str = Color.blue, is_open: bool = False, is_locked: bool = False): 400 | """ 401 | Parameters 402 | ---------- 403 | color : str 404 | Object color 405 | is_open : bool 406 | Whether the door is open 407 | is_locked : bool 408 | Whether the door is locked 409 | """ 410 | door = super().__new__(cls, color=color) 411 | door.is_open = is_open 412 | door.is_locked = is_locked 413 | return door 414 | 415 | def __str__(self): 416 | return f"{self.__class__.__name__}(color={self.color},state={self.state})" 417 | 418 | @property 419 | def is_open(self) -> bool: 420 | """ 421 | Whether the door is open. 422 | """ 423 | return self.state == State.open 424 | 425 | @is_open.setter 426 | def is_open(self, value: bool): 427 | """ 428 | Set the door to be open or closed. 429 | """ 430 | if value: 431 | self.state = State.open # set state to open 432 | elif not self.is_locked: 433 | self.state = State.closed # set state to closed (unless already locked) 434 | 435 | @property 436 | def is_locked(self) -> bool: 437 | """ 438 | Whether the door is locked. 439 | """ 440 | return self.state == State.locked 441 | 442 | @is_locked.setter 443 | def is_locked(self, value: bool): 444 | """ 445 | Set the door to be locked or unlocked. 
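Since walls are effectively immutable, the functools.cache on `Wall.__new__` means equal constructor arguments yield the same shared instance (a small sketch):

    >>> Wall() is Wall()
    True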
446 | """ 447 | if value: 448 | self.state = State.locked # set state to locked 449 | elif not self.is_open: 450 | self.state = State.closed # set state to closed (unless already open) 451 | 452 | def can_overlap(self) -> bool: 453 | """ 454 | :meta private: 455 | """ 456 | return self.is_open 457 | 458 | def toggle(self, env, agent, pos): 459 | """ 460 | :meta private: 461 | """ 462 | if self.is_locked: 463 | # Check if the player has the right key to unlock the door 464 | carried_obj = agent.state.carrying 465 | if isinstance(carried_obj, Key) and carried_obj.color == self.color: 466 | self.is_locked = False 467 | self.is_open = True 468 | env.grid.update(*pos) 469 | return True 470 | return False 471 | 472 | self.is_open = not self.is_open 473 | env.grid.update(*pos) 474 | return True 475 | 476 | def render(self, img): 477 | """ 478 | :meta private: 479 | """ 480 | c = self.color.rgb() 481 | 482 | if self.is_open: 483 | fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c) 484 | fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0)) 485 | return 486 | 487 | # Door frame and door 488 | if self.is_locked: 489 | fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c) 490 | fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * c) 491 | 492 | # Draw key slot 493 | fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c) 494 | else: 495 | fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c) 496 | fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0)) 497 | fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c) 498 | fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0)) 499 | 500 | # Draw door handle 501 | fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c) 502 | 503 | 504 | class Key(WorldObj): 505 | """ 506 | Key object that can be picked up and used to unlock doors. 507 | """ 508 | 509 | def __new__(cls, color: str = Color.blue): 510 | """ 511 | Parameters 512 | ---------- 513 | color : str 514 | Object color 515 | """ 516 | return super().__new__(cls, color=color) 517 | 518 | def can_pickup(self) -> bool: 519 | """ 520 | :meta private: 521 | """ 522 | return True 523 | 524 | def render(self, img): 525 | """ 526 | :meta private: 527 | """ 528 | c = self.color.rgb() 529 | 530 | # Vertical quad 531 | fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c) 532 | 533 | # Teeth 534 | fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c) 535 | fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c) 536 | 537 | # Ring 538 | fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c) 539 | fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0)) 540 | 541 | 542 | class Ball(WorldObj): 543 | """ 544 | Ball object that can be picked up by agents. 545 | """ 546 | 547 | def __new__(cls, color: str = Color.blue): 548 | """ 549 | Parameters 550 | ---------- 551 | color : str 552 | Object color 553 | """ 554 | return super().__new__(cls, color=color) 555 | 556 | def can_pickup(self) -> bool: 557 | """ 558 | :meta private: 559 | """ 560 | return True 561 | 562 | def render(self, img): 563 | """ 564 | :meta private: 565 | """ 566 | fill_coords(img, point_in_circle(0.5, 0.5, 0.31), self.color.rgb()) 567 | 568 | 569 | class Box(WorldObj): 570 | """ 571 | Box object that may contain other objects. 
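Both flags of a door live in its single `state` field, so the two setters interact rather than acting as independent booleans (a small sketch; `State` comes from .constants):

    >>> door = Door(Color.red, is_locked=True)
    >>> door.is_open, door.is_locked
    (False, True)
    >>> door.is_locked = False
    >>> door.state == State.closed
    True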
572 | """ 573 | 574 | def __new__(cls, color: str = Color.yellow, contains: WorldObj | None = None): 575 | """ 576 | Parameters 577 | ---------- 578 | color : str 579 | Object color 580 | contains : WorldObj or None 581 | Object contents 582 | """ 583 | box = super().__new__(cls, color=color) 584 | box.contains = contains 585 | return box 586 | 587 | def can_pickup(self) -> bool: 588 | """ 589 | :meta private: 590 | """ 591 | return True 592 | 593 | def can_contain(self) -> bool: 594 | """ 595 | :meta private: 596 | """ 597 | return True 598 | 599 | def toggle(self, env, agent, pos): 600 | """ 601 | :meta private: 602 | """ 603 | # Replace the box by its contents 604 | env.grid.set(*pos, self.contains) 605 | return True 606 | 607 | def render(self, img): 608 | """ 609 | :meta private: 610 | """ 611 | # Outline 612 | fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), self.color.rgb()) 613 | fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0)) 614 | 615 | # Horizontal slit 616 | fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), self.color.rgb()) 617 | -------------------------------------------------------------------------------- /multigrid/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gymnasium as gym 4 | import math 5 | import numpy as np 6 | import pygame 7 | import pygame.freetype 8 | 9 | from abc import ABC, abstractmethod 10 | from collections import defaultdict 11 | from gymnasium import spaces 12 | from itertools import repeat 13 | from numpy.typing import NDArray as ndarray 14 | from typing import Any, Callable, Iterable, Literal, SupportsFloat 15 | 16 | from .core.actions import Action 17 | from .core.agent import Agent, AgentState 18 | from .core.constants import Type, TILE_PIXELS 19 | from .core.grid import Grid 20 | from .core.mission import MissionSpace 21 | from .core.world_object import WorldObj 22 | from .utils.obs import gen_obs_grid_encoding 23 | from .utils.random import RandomMixin 24 | 25 | 26 | 27 | ### Typing 28 | 29 | AgentID = int 30 | ObsType = dict[str, Any] 31 | 32 | 33 | 34 | ### Environment 35 | 36 | class MultiGridEnv(gym.Env, RandomMixin, ABC): 37 | """ 38 | Base class for multi-agent 2D gridworld environments. 39 | 40 | :Agents: 41 | 42 | The environment can be configured with any fixed number of agents. 43 | Agents are represented by :class:`.Agent` instances, and are 44 | identified by their index, from ``0`` to ``len(env.agents) - 1``. 45 | 46 | :Observation Space: 47 | 48 | The multi-agent observation space is a Dict mapping from agent index to 49 | corresponding agent observation space. 50 | 51 | The standard agent observation is a dictionary with the following entries: 52 | 53 | * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`) 54 | Encoding of the agent's view of the environment, 55 | where each grid object is encoded as a 3 dimensional tuple: 56 | (:class:`.Type`, :class:`.Color`, :class:`.State`) 57 | * direction : int 58 | Agent's direction (0: right, 1: down, 2: left, 3: up) 59 | * mission : Mission 60 | Task string corresponding to the current environment configuration 61 | 62 | :Action Space: 63 | 64 | The multi-agent action space is a Dict mapping from agent index to 65 | corresponding agent action space. 66 | 67 | Agent actions are discrete integers, as enumerated in :class:`.Action`. 
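Concretely, interaction with the environment is dict-in / dict-out, keyed by agent index (a sketch, assuming `env` is some concrete MultiGridEnv subclass and the usual Gymnasium 5-tuple step API with per-agent entries):

    >>> actions = {i: Action.forward for i in range(env.num_agents)}
    >>> obs, rewards, terminations, truncations, infos = env.step(actions)
    >>> sorted(obs.keys()) == list(range(env.num_agents))
    True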
68 |
69 | Attributes
70 | ----------
71 | agents : list[Agent]
72 | List of agents in the environment
73 | grid : Grid
74 | Environment grid
75 | observation_space : spaces.Dict[AgentID, spaces.Space]
76 | Joint observation space of all agents
77 | action_space : spaces.Dict[AgentID, spaces.Space]
78 | Joint action space of all agents
79 | """
80 | metadata = {
81 | 'render_modes': ['human', 'rgb_array'],
82 | 'render_fps': 20,
83 | }
84 |
85 | def __init__(
86 | self,
87 | mission_space: MissionSpace | str = "maximize reward",
88 | agents: Iterable[Agent] | int = 1,
89 | grid_size: int | None = None,
90 | width: int | None = None,
91 | height: int | None = None,
92 | max_steps: int = 100,
93 | see_through_walls: bool = False,
94 | agent_view_size: int = 7,
95 | allow_agent_overlap: bool = True,
96 | joint_reward: bool = False,
97 | success_termination_mode: Literal['any', 'all'] = 'any',
98 | failure_termination_mode: Literal['any', 'all'] = 'all',
99 | render_mode: str | None = None,
100 | screen_size: int | None = 640,
101 | highlight: bool = True,
102 | tile_size: int = TILE_PIXELS,
103 | agent_pov: bool = False):
104 | """
105 | Parameters
106 | ----------
107 | mission_space : MissionSpace or str
108 | Space of mission strings (i.e. agent instructions)
109 | agents : int or Iterable[Agent]
110 | Number of agents in the environment (or provide :class:`Agent` instances)
111 | grid_size : int
112 | Size of the environment grid (width and height)
113 | width : int
114 | Width of the environment grid (if `grid_size` is not provided)
115 | height : int
116 | Height of the environment grid (if `grid_size` is not provided)
117 | max_steps : int
118 | Maximum number of steps per episode
119 | see_through_walls : bool
120 | Whether agents can see through walls
121 | agent_view_size : int
122 | Size of agent view (must be odd)
123 | allow_agent_overlap : bool
124 | Whether agents are allowed to overlap
125 | joint_reward : bool
126 | Whether all agents receive the same joint reward
127 | success_termination_mode : 'any' or 'all'
128 | Whether to terminate when any agent completes its mission
129 | or when all agents complete their missions
130 | failure_termination_mode : 'any' or 'all'
131 | Whether to terminate when any agent fails its mission
132 | or when all agents fail their missions
133 | render_mode : str
134 | Rendering mode (human or rgb_array)
135 | screen_size : int
136 | Width and height of the rendering window (in pixels)
137 | highlight : bool
138 | Whether to highlight the view of each agent when rendering
139 | tile_size : int
140 | Width and height of each grid tile (in pixels)
agent_pov : bool
Whether to render agent's POV or the full environment
141 | """
142 | gym.Env.__init__(self)
143 | RandomMixin.__init__(self, self.np_random)
144 |
145 | # Initialize mission space
146 | if isinstance(mission_space, str):
147 | self.mission_space = MissionSpace.from_string(mission_space)
148 | else:
149 | self.mission_space = mission_space
150 |
151 | # Initialize grid
152 | width, height = (grid_size, grid_size) if grid_size else (width, height)
153 | assert width is not None and height is not None
154 | self.width, self.height = width, height
155 | self.grid: Grid = Grid(width, height)
156 |
157 | # Initialize agents
158 | if isinstance(agents, int):
159 | self.num_agents = agents
160 | self.agent_states = AgentState(agents) # joint agent state (vectorized)
161 | self.agents: list[Agent] = []
162 | for i in range(agents):
163 | agent = Agent(
164 | index=i,
165 | mission_space=self.mission_space,
166 | view_size=agent_view_size,
167 |
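# Each auto-constructed agent shares the environment's mission space and
# view settings; its state is re-bound to the vectorized AgentState
# batch below, so per-agent updates stay in sync with the joint state.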
see_through_walls=see_through_walls,
168 | )
169 | agent.state = self.agent_states[i]
170 | self.agents.append(agent)
171 | elif isinstance(agents, Iterable):
172 | assert {agent.index for agent in agents} == set(range(len(agents)))
173 | self.num_agents = len(agents)
174 | self.agent_states = AgentState(self.num_agents)
175 | self.agents: list[Agent] = sorted(agents, key=lambda agent: agent.index)
176 | for agent in self.agents:
177 | self.agent_states[agent.index] = agent.state # copy to joint agent state
178 | agent.state = self.agent_states[agent.index] # reference joint agent state
179 | else:
180 | raise ValueError(f"Invalid argument for agents: {agents}")
181 |
182 | # Action enumeration for this environment
183 | self.actions = Action
184 |
185 | # Range of possible rewards
186 | self.reward_range = (0, 1)
187 |
188 | assert isinstance(
189 | max_steps, int
190 | ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
191 | self.max_steps = max_steps
192 |
193 | # Rendering attributes
194 | self.render_mode = render_mode
195 | self.highlight = highlight
196 | self.tile_size = tile_size
197 | self.agent_pov = agent_pov
198 | self.screen_size = screen_size
199 | self.render_size = None
200 | self.window = None
201 | self.clock = None
202 |
203 | # Other
204 | self.allow_agent_overlap = allow_agent_overlap
205 | self.joint_reward = joint_reward
206 | self.success_termination_mode = success_termination_mode
207 | self.failure_termination_mode = failure_termination_mode
208 |
209 | @property
210 | def observation_space(self) -> spaces.Dict[AgentID, spaces.Space]:
211 | """
212 | Return the joint observation space of all agents.
213 | """
214 | return spaces.Dict({
215 | agent.index: agent.observation_space
216 | for agent in self.agents
217 | })
218 |
219 | @property
220 | def action_space(self) -> spaces.Dict[AgentID, spaces.Space]:
221 | """
222 | Return the joint action space of all agents.
223 | """
224 | return spaces.Dict({
225 | agent.index: agent.action_space
226 | for agent in self.agents
227 | })
228 |
229 | @abstractmethod
230 | def _gen_grid(self, width: int, height: int):
231 | """
232 | :meta public:
233 |
234 | Generate the grid for a new episode.
235 |
236 | This method should:
237 |
238 | * Set ``self.grid`` and populate it with :class:`.WorldObj` instances
239 | * Set the positions and directions of each agent
240 |
241 | Parameters
242 | ----------
243 | width : int
244 | Width of the grid
245 | height : int
246 | Height of the grid
247 | """
248 | pass
249 |
250 | def reset(
251 | self, seed: int | None = None, **kwargs) -> tuple[
252 | dict[AgentID, ObsType],
253 | dict[AgentID, dict[str, Any]]]:
254 | """
255 | Reset the environment.
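Re-seeds the mission space, re-allocates the joint agent state, and
calls :meth:`_gen_grid` to build a fresh grid before returning the
first joint observation. A sketch, with ``MyEnv`` as any concrete
subclass:

>>> env = MyEnv(agents=2)
>>> observations, infos = env.reset(seed=42)
>>> sorted(observations.keys())
[0, 1]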
256 | 257 | Parameters 258 | ---------- 259 | seed : int or None 260 | Seed for random number generator 261 | 262 | Returns 263 | ------- 264 | observations : dict[AgentID, ObsType] 265 | Observation for each agent 266 | infos : dict[AgentID, dict[str, Any]] 267 | Additional information for each agent 268 | """ 269 | super().reset(seed=seed, **kwargs) 270 | 271 | # Reset agents 272 | self.mission_space.seed(seed) 273 | self.mission = self.mission_space.sample() 274 | self.agent_states = AgentState(self.num_agents) 275 | for agent in self.agents: 276 | agent.state = self.agent_states[agent.index] 277 | agent.reset(mission=self.mission) 278 | 279 | # Generate a new random grid at the start of each episode 280 | self._gen_grid(self.width, self.height) 281 | 282 | # These fields should be defined by _gen_grid 283 | assert np.all(self.agent_states.pos >= 0) 284 | assert np.all(self.agent_states.dir >= 0) 285 | 286 | # Check that agents don't overlap with other objects 287 | for agent in self.agents: 288 | start_cell = self.grid.get(*agent.state.pos) 289 | assert start_cell is None or start_cell.can_overlap() 290 | 291 | # Step count since episode start 292 | self.step_count = 0 293 | 294 | # Return first observation 295 | observations = self.gen_obs() 296 | 297 | # Render environment 298 | if self.render_mode == 'human': 299 | self.render() 300 | 301 | return observations, defaultdict(dict) 302 | 303 | def step( 304 | self, 305 | actions: dict[AgentID, Action]) -> tuple[ 306 | dict[AgentID, ObsType], 307 | dict[AgentID, SupportsFloat], 308 | dict[AgentID, bool], 309 | dict[AgentID, bool], 310 | dict[AgentID, dict[str, Any]]]: 311 | """ 312 | Run one timestep of the environment’s dynamics 313 | using the provided agent actions. 314 | 315 | Parameters 316 | ---------- 317 | actions : dict[AgentID, Action] 318 | Action for each agent acting at this timestep 319 | 320 | Returns 321 | ------- 322 | observations : dict[AgentID, ObsType] 323 | Observation for each agent 324 | rewards : dict[AgentID, SupportsFloat] 325 | Reward for each agent 326 | terminations : dict[AgentID, bool] 327 | Whether the episode has been terminated for each agent (success or failure) 328 | truncations : dict[AgentID, bool] 329 | Whether the episode has been truncated for each agent (max steps reached) 330 | infos : dict[AgentID, dict[str, Any]] 331 | Additional information for each agent 332 | """ 333 | self.step_count += 1 334 | rewards = self.handle_actions(actions) 335 | 336 | # Generate outputs 337 | observations = self.gen_obs() 338 | terminations = dict(enumerate(self.agent_states.terminated)) 339 | truncated = self.step_count >= self.max_steps 340 | truncations = dict(enumerate(repeat(truncated, self.num_agents))) 341 | 342 | # Rendering 343 | if self.render_mode == 'human': 344 | self.render() 345 | 346 | return observations, rewards, terminations, truncations, defaultdict(dict) 347 | 348 | def gen_obs(self) -> dict[AgentID, ObsType]: 349 | """ 350 | Generate observations for each agent (partially observable, low-res encoding). 
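All agent views are derived from a single call to
:func:`~multigrid.utils.obs.gen_obs_grid_encoding`, so the encoding
work is shared across agents. Continuing the ``MyEnv`` sketch, with
the default ``agent_view_size`` of 7 and the 3-channel object encoding:

>>> obs = env.gen_obs()
>>> obs[0]['image'].shape
(7, 7, 3)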
351 | 352 | Returns 353 | ------- 354 | observations : dict[AgentID, ObsType] 355 | Mapping from agent ID to observation dict, containing: 356 | * 'image': partially observable view of the environment 357 | * 'direction': agent's direction / orientation (acting as a compass) 358 | * 'mission': textual mission string (instructions for the agent) 359 | """ 360 | direction = self.agent_states.dir 361 | image = gen_obs_grid_encoding( 362 | self.grid.state, 363 | self.agent_states, 364 | self.agents[0].view_size, 365 | self.agents[0].see_through_walls, 366 | ) 367 | 368 | observations = {} 369 | for i in range(self.num_agents): 370 | observations[i] = { 371 | 'image': image[i], 372 | 'direction': direction[i], 373 | 'mission': self.agents[i].mission, 374 | } 375 | 376 | return observations 377 | 378 | def handle_actions( 379 | self, actions: dict[AgentID, Action]) -> dict[AgentID, SupportsFloat]: 380 | """ 381 | Handle actions taken by agents. 382 | 383 | Parameters 384 | ---------- 385 | actions : dict[AgentID, Action] 386 | Action for each agent acting at this timestep 387 | 388 | Returns 389 | ------- 390 | rewards : dict[AgentID, SupportsFloat] 391 | Reward for each agent 392 | """ 393 | rewards = {agent_index: 0 for agent_index in range(self.num_agents)} 394 | 395 | # Randomize agent action order 396 | if self.num_agents == 1: 397 | order = (0,) 398 | else: 399 | order = self.np_random.random(size=self.num_agents).argsort() 400 | 401 | # Update agent states, grid states, and reward from actions 402 | for i in order: 403 | if i not in actions: 404 | continue 405 | 406 | agent, action = self.agents[i], actions[i] 407 | 408 | if agent.state.terminated: 409 | continue 410 | 411 | # Rotate left 412 | if action == Action.left: 413 | agent.state.dir = (agent.state.dir - 1) % 4 414 | 415 | # Rotate right 416 | elif action == Action.right: 417 | agent.state.dir = (agent.state.dir + 1) % 4 418 | 419 | # Move forward 420 | elif action == Action.forward: 421 | fwd_pos = agent.front_pos 422 | fwd_obj = self.grid.get(*fwd_pos) 423 | 424 | if fwd_obj is None or fwd_obj.can_overlap(): 425 | if not self.allow_agent_overlap: 426 | agent_present = np.bitwise_and.reduce( 427 | self.agent_states.pos == fwd_pos, axis=1).any() 428 | if agent_present: 429 | continue 430 | 431 | agent.state.pos = fwd_pos 432 | if fwd_obj is not None: 433 | if fwd_obj.type == Type.goal: 434 | self.on_success(agent, rewards, {}) 435 | if fwd_obj.type == Type.lava: 436 | self.on_failure(agent, rewards, {}) 437 | 438 | # Pick up an object 439 | elif action == Action.pickup: 440 | fwd_pos = agent.front_pos 441 | fwd_obj = self.grid.get(*fwd_pos) 442 | 443 | if fwd_obj is not None and fwd_obj.can_pickup(): 444 | if agent.state.carrying is None: 445 | agent.state.carrying = fwd_obj 446 | self.grid.set(*fwd_pos, None) 447 | 448 | # Drop an object 449 | elif action == Action.drop: 450 | fwd_pos = agent.front_pos 451 | fwd_obj = self.grid.get(*fwd_pos) 452 | 453 | if agent.state.carrying and fwd_obj is None: 454 | agent_present = np.bitwise_and.reduce( 455 | self.agent_states.pos == fwd_pos, axis=1).any() 456 | if not agent_present: 457 | self.grid.set(*fwd_pos, agent.state.carrying) 458 | agent.state.carrying.cur_pos = fwd_pos 459 | agent.state.carrying = None 460 | 461 | # Toggle/activate an object 462 | elif action == Action.toggle: 463 | fwd_pos = agent.front_pos 464 | fwd_obj = self.grid.get(*fwd_pos) 465 | 466 | if fwd_obj is not None: 467 | fwd_obj.toggle(self, agent, fwd_pos) 468 | 469 | # Done action (not used by default) 470 | 
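# (Kept as an explicit branch so that any unrecognized action value
# still falls through to the ValueError below.)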
elif action == Action.done:
471 | pass
472 |
473 | else:
474 | raise ValueError(f"Unknown action: {action}")
475 |
476 | return rewards
477 |
478 | def on_success(
479 | self,
480 | agent: Agent,
481 | rewards: dict[AgentID, SupportsFloat],
482 | terminations: dict[AgentID, bool]):
483 | """
484 | Callback for when an agent completes its mission.
485 |
486 | Parameters
487 | ----------
488 | agent : Agent
489 | Agent that completed its mission
490 | rewards : dict[AgentID, SupportsFloat]
491 | Reward dictionary to be updated
492 | terminations : dict[AgentID, bool]
493 | Termination dictionary to be updated
494 | """
495 | if self.success_termination_mode == 'any':
496 | self.agent_states.terminated = True # terminate all agents
497 | for i in range(self.num_agents):
498 | terminations[i] = True
499 | else:
500 | agent.state.terminated = True # terminate this agent only
501 | terminations[agent.index] = True
502 |
503 | if self.joint_reward:
504 | for i in range(self.num_agents):
505 | rewards[i] = self._reward() # reward all agents
506 | else:
507 | rewards[agent.index] = self._reward() # reward this agent only
508 |
509 | def on_failure(
510 | self,
511 | agent: Agent,
512 | rewards: dict[AgentID, SupportsFloat],
513 | terminations: dict[AgentID, bool]):
514 | """
515 | Callback for when an agent fails its mission prematurely.
516 |
517 | Parameters
518 | ----------
519 | agent : Agent
520 | Agent that failed its mission
521 | rewards : dict[AgentID, SupportsFloat]
522 | Reward dictionary to be updated
523 | terminations : dict[AgentID, bool]
524 | Termination dictionary to be updated
525 | """
526 | if self.failure_termination_mode == 'any':
527 | self.agent_states.terminated = True # terminate all agents
528 | for i in range(self.num_agents):
529 | terminations[i] = True
530 | else:
531 | agent.state.terminated = True # terminate this agent only
532 | terminations[agent.index] = True
533 |
534 | def is_done(self) -> bool:
535 | """
536 | Return whether the current episode is finished (for all agents).
537 | """
538 | truncated = self.step_count >= self.max_steps
539 | return truncated or all(self.agent_states.terminated)
540 |
541 | def __str__(self):
542 | """
543 | Produce a pretty string of the environment's grid along with the agents.
544 | A grid cell is represented by a 2-character string, the first one for
545 | the object and the second one for the color.
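For example, an open door prints as ``__``, a locked red door as
``LR``, a blue key as ``KB``, an empty cell as two spaces, and an
agent facing right as ``>>``.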
546 | """
547 | # Map of object types to short string
548 | OBJECT_TO_STR = {
549 | 'wall': 'W',
550 | 'floor': 'F',
551 | 'door': 'D',
552 | 'key': 'K',
553 | 'ball': 'A',
554 | 'box': 'B',
555 | 'goal': 'G',
556 | 'lava': 'V',
557 | }
558 |
559 | # Map agent's direction to short string
560 | AGENT_DIR_TO_STR = {0: '>', 1: 'V', 2: '<', 3: '^'}
561 |
562 | # Get agent locations
563 | location_to_agent = {tuple(agent.pos): agent for agent in self.agents}
564 |
565 | output = ""
566 | for j in range(self.grid.height):
567 | for i in range(self.grid.width):
568 | if (i, j) in location_to_agent:
569 | output += 2 * AGENT_DIR_TO_STR[location_to_agent[i, j].dir]
570 | continue
571 |
572 | tile = self.grid.get(i, j)
573 |
574 | if tile is None:
575 | output += '  ' # two spaces, since each cell is 2 characters wide
576 | continue
577 |
578 | if tile.type == 'agent':
579 | output += 2 * AGENT_DIR_TO_STR[tile.dir]
580 | continue
581 |
582 | if tile.type == 'door':
583 | if tile.is_open:
584 | output += '__'
585 | elif tile.is_locked:
586 | output += 'L' + tile.color[0].upper()
587 | else:
588 | output += 'D' + tile.color[0].upper()
589 | continue
590 |
591 | output += OBJECT_TO_STR[tile.type] + tile.color[0].upper()
592 |
593 | if j < self.grid.height - 1:
594 | output += '\n'
595 |
596 | return output
597 |
598 | def _reward(self) -> float:
599 | """
600 | Compute the reward to be given upon success.
601 | """
602 | return 1 - 0.9 * (self.step_count / self.max_steps)
603 |
604 | def place_obj(
605 | self,
606 | obj: WorldObj | None,
607 | top: tuple[int, int] = None,
608 | size: tuple[int, int] = None,
609 | reject_fn: Callable[[MultiGridEnv, tuple[int, int]], bool] | None = None,
610 | max_tries=math.inf) -> tuple[int, int]:
611 | """
612 | Place an object at an empty position in the grid.
613 |
614 | Parameters
615 | ----------
616 | obj: WorldObj
617 | Object to place in the grid
618 | top: tuple[int, int]
619 | Top-left position of the rectangular area where to place the object
620 | size: tuple[int, int]
621 | Width and height of the rectangular area where to place the object
622 | reject_fn: Callable(env, pos) -> bool
623 | Function to filter out potential positions
624 | max_tries: int
625 | Maximum number of attempts to place the object
626 | """
627 | if top is None:
628 | top = (0, 0)
629 | else:
630 | top = (max(top[0], 0), max(top[1], 0))
631 |
632 | if size is None:
633 | size = (self.grid.width, self.grid.height)
634 |
635 | num_tries = 0
636 |
637 | while True:
638 | # This is to handle rare cases where rejection sampling
639 | # gets stuck in an infinite loop
640 | if num_tries > max_tries:
641 | raise RecursionError("rejection sampling failed in place_obj")
642 |
643 | num_tries += 1
644 |
645 | pos = (
646 | self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
647 | self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
648 | )
649 |
650 | # Don't place the object on top of another object
651 | if self.grid.get(*pos) is not None:
652 | continue
653 |
654 | # Don't place the object where agents are
655 | if np.bitwise_and.reduce(self.agent_states.pos == pos, axis=1).any():
656 | continue
657 |
658 | # Check if there is a filtering criterion
659 | if reject_fn and reject_fn(self, pos):
660 | continue
661 |
662 | break
663 |
664 | self.grid.set(pos[0], pos[1], obj)
665 |
666 | if obj is not None:
667 | obj.init_pos = pos
668 | obj.cur_pos = pos
669 |
670 | return pos
671 |
672 | def put_obj(self, obj: WorldObj, i: int, j: int):
673 | """
674 | Put an object at a specific position in the grid.
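Unlike :meth:`place_obj`, no emptiness or rejection checks are made;
the target cell is simply overwritten. For example
(``Goal`` as exported from :mod:`multigrid.core`)::

    env.put_obj(Goal(), env.width - 2, env.height - 2)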
675 | """ 676 | self.grid.set(i, j, obj) 677 | obj.init_pos = (i, j) 678 | obj.cur_pos = (i, j) 679 | 680 | def place_agent( 681 | self, 682 | agent: Agent, 683 | top=None, 684 | size=None, 685 | rand_dir=True, 686 | max_tries=math.inf) -> tuple[int, int]: 687 | """ 688 | Set agent starting point at an empty position in the grid. 689 | """ 690 | agent.state.pos = (-1, -1) 691 | pos = self.place_obj(None, top, size, max_tries=max_tries) 692 | agent.state.pos = pos 693 | 694 | if rand_dir: 695 | agent.state.dir = self._rand_int(0, 4) 696 | 697 | return pos 698 | 699 | def get_pov_render(self, *args, **kwargs): 700 | """ 701 | Render an agent's POV observation for visualization. 702 | """ 703 | raise NotImplementedError( 704 | "POV rendering not supported for multiagent environments." 705 | ) 706 | 707 | def get_full_render(self, highlight: bool, tile_size: int): 708 | """ 709 | Render a non-partial observation for visualization. 710 | """ 711 | # Compute agent visibility masks 712 | obs_shape = self.agents[0].observation_space['image'].shape[:-1] 713 | vis_masks = np.zeros((self.num_agents, *obs_shape), dtype=bool) 714 | for i, agent_obs in self.gen_obs().items(): 715 | vis_masks[i] = (agent_obs['image'][..., 0] != Type.unseen.to_index()) 716 | 717 | # Mask of which cells to highlight 718 | highlight_mask = np.zeros((self.width, self.height), dtype=bool) 719 | 720 | for agent in self.agents: 721 | # Compute the world coordinates of the bottom-left corner 722 | # of the agent's view area 723 | f_vec = agent.state.dir.to_vec() 724 | r_vec = np.array((-f_vec[1], f_vec[0])) 725 | top_left = ( 726 | agent.state.pos 727 | + f_vec * (agent.view_size - 1) 728 | - r_vec * (agent.view_size // 2) 729 | ) 730 | 731 | # For each cell in the visibility mask 732 | for vis_j in range(0, agent.view_size): 733 | for vis_i in range(0, agent.view_size): 734 | # If this cell is not visible, don't highlight it 735 | if not vis_masks[agent.index][vis_i, vis_j]: 736 | continue 737 | 738 | # Compute the world coordinates of this cell 739 | abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i) 740 | 741 | if abs_i < 0 or abs_i >= self.width: 742 | continue 743 | if abs_j < 0 or abs_j >= self.height: 744 | continue 745 | 746 | # Mark this cell to be highlighted 747 | highlight_mask[abs_i, abs_j] = True 748 | 749 | # Render the whole grid 750 | img = self.grid.render( 751 | tile_size, 752 | agents=self.agents, 753 | highlight_mask=highlight_mask if highlight else None, 754 | ) 755 | 756 | return img 757 | 758 | def get_frame( 759 | self, 760 | highlight: bool = True, 761 | tile_size: int = TILE_PIXELS, 762 | agent_pov: bool = False) -> ndarray[np.uint8]: 763 | """ 764 | Returns an RGB image corresponding to the whole environment. 765 | 766 | Parameters 767 | ---------- 768 | highlight: bool 769 | Whether to highlight agents' field of view (with a lighter gray color) 770 | tile_size: int 771 | How many pixels will form a tile from the NxM grid 772 | agent_pov: bool 773 | Whether to render agent's POV or the full environment 774 | 775 | Returns 776 | ------- 777 | frame: ndarray of shape (H, W, 3) 778 | A frame representing RGB values for the HxW pixel image 779 | """ 780 | if agent_pov: 781 | return self.get_pov_render(tile_size) 782 | else: 783 | return self.get_full_render(highlight, tile_size) 784 | 785 | def render(self): 786 | """ 787 | Render the environment. 
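In ``'human'`` mode this draws to a pygame window scaled to
``screen_size``; in ``'rgb_array'`` mode it returns the frame instead.
For example (``MyEnv`` as any concrete subclass)::

    env = MyEnv(agents=2, render_mode='rgb_array')
    env.reset(seed=0)
    frame = env.render()    # ndarray of shape (H, W, 3)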
788 | """ 789 | img = self.get_frame(self.highlight, self.tile_size) 790 | 791 | if self.render_mode == 'human': 792 | img = np.transpose(img, axes=(1, 0, 2)) 793 | screen_size = ( 794 | self.screen_size * min(img.shape[0] / img.shape[1], 1.0), 795 | self.screen_size * min(img.shape[1] / img.shape[0], 1.0), 796 | ) 797 | if self.render_size is None: 798 | self.render_size = img.shape[:2] 799 | if self.window is None: 800 | pygame.init() 801 | pygame.display.init() 802 | pygame.display.set_caption(f'multigrid - {self.__class__.__name__}') 803 | self.window = pygame.display.set_mode(screen_size) 804 | if self.clock is None: 805 | self.clock = pygame.time.Clock() 806 | surf = pygame.surfarray.make_surface(img) 807 | 808 | # Create background with mission description 809 | offset = surf.get_size()[0] * 0.1 810 | # offset = 32 if self.agent_pov else 64 811 | bg = pygame.Surface( 812 | (int(surf.get_size()[0] + offset), int(surf.get_size()[1] + offset)) 813 | ) 814 | bg.convert() 815 | bg.fill((255, 255, 255)) 816 | bg.blit(surf, (offset / 2, 0)) 817 | 818 | bg = pygame.transform.smoothscale(bg, screen_size) 819 | 820 | font_size = 22 821 | text = str(self.mission) 822 | font = pygame.freetype.SysFont(pygame.font.get_default_font(), font_size) 823 | text_rect = font.get_rect(text, size=font_size) 824 | text_rect.center = bg.get_rect().center 825 | text_rect.y = bg.get_height() - font_size * 1.5 826 | font.render_to(bg, text_rect, text, size=font_size) 827 | 828 | self.window.blit(bg, (0, 0)) 829 | pygame.event.pump() 830 | self.clock.tick(self.metadata['render_fps']) 831 | pygame.display.flip() 832 | 833 | elif self.render_mode == 'rgb_array': 834 | return img 835 | 836 | def close(self): 837 | """ 838 | Close the rendering window. 839 | """ 840 | if self.window: 841 | pygame.quit() 842 | --------------------------------------------------------------------------------
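A minimal end-to-end usage sketch of the API above. The ``EmptyRoom``
subclass here is hypothetical (see multigrid/envs/empty.py for the
package's own empty environment), and ``Grid.wall_rect`` is assumed to
exist as in MiniGrid; everything else uses only methods shown in base.py.

from multigrid import MultiGridEnv
from multigrid.core import Goal, Grid

class EmptyRoom(MultiGridEnv):
    """Hypothetical minimal environment: walled room with one goal tile."""

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)      # assumption: as in MiniGrid
        self.put_obj(Goal(), width - 2, height - 2)   # stepping here triggers on_success()
        for agent in self.agents:
            self.place_agent(agent)                   # random empty cell and direction

env = EmptyRoom(agents=2, grid_size=8, max_steps=50)
observations, infos = env.reset(seed=0)
while not env.is_done():
    actions = {i: env.action_space[i].sample() for i in range(env.num_agents)}
    observations, rewards, terminations, truncations, infos = env.step(actions)
env.close()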