├── entity_gym ├── py.typed ├── __init__.py ├── serialization │ ├── __init__.py │ ├── msgpack_ragged.py │ ├── sample_recorder.py │ └── sample_loader.py ├── env │ ├── common.py │ ├── __init__.py │ ├── add_metrics_wrapper.py │ ├── env_list.py │ ├── action.py │ ├── validator.py │ ├── parallel_env_list.py │ ├── environment.py │ └── vec_env.py ├── main.py ├── tests │ ├── test_validator.py │ └── test_sample_recorder.py ├── dataclass_utils.py ├── examples │ ├── __init__.py │ ├── multi_armed_bandit.py │ ├── xor.py │ ├── tutorial.py │ ├── not_hotdog.py │ ├── cherry_pick.py │ ├── count.py │ ├── rock_paper_scissors.py │ ├── floor_is_lava.py │ ├── pick_matching_balls.py │ ├── move_to_origin.py │ ├── minefield.py │ ├── minesweeper.py │ └── multi_snake.py ├── ragged_dict.py ├── simple_trace.py └── runner.py ├── docs ├── requirements.txt ├── source │ ├── tutorials.rst │ ├── index.rst │ ├── conf.py │ ├── quick-start-guide.rst │ └── complex-action-spaces.rst ├── Makefile └── make.bat ├── .gitignore ├── .readthedocs.yaml ├── pyproject.toml ├── LICENSE-MIT ├── mypy.ini ├── .pre-commit-config.yaml ├── README.md ├── .github └── workflows │ ├── publish.yaml │ └── checks.yaml └── LICENSE-APACHE /entity_gym/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==4.5.0 2 | sphinx_rtd_theme==1.0.0 3 | autoapi==2.0.1 4 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | quick-start-guide 8 | complex-action-spaces 9 | -------------------------------------------------------------------------------- /entity_gym/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Entity Gym is an open source Python library that defines an entity based API for reinforcement learning environments. 
3 | """ 4 | -------------------------------------------------------------------------------- /entity_gym/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from .sample_loader import Trace # noqa 2 | from .sample_recorder import Sample, SampleRecorder, SampleRecordingVecEnv # noqa 3 | -------------------------------------------------------------------------------- /entity_gym/env/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Sequence, Union 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | 6 | Features = Union[npt.NDArray[np.float32], Sequence[Sequence[float]]] 7 | EntityID = Any 8 | EntityName = str 9 | ActionName = str 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.egg-info 3 | *.pyc 4 | wandb/ 5 | runs/ 6 | .dmypy.json 7 | __pycache__ 8 | 9 | # Pycharm 10 | .idea/ 11 | 12 | # Meld 13 | *.orig 14 | 15 | # Docs 16 | entity_gym/docs/build 17 | entity_gym/docs/source/generated 18 | entity_gym/docs/source/entity_gym 19 | -------------------------------------------------------------------------------- /entity_gym/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from entity_gym.examples import ENV_REGISTRY 4 | from entity_gym.runner import CliRunner 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--env", type=str, default="MoveToOrigin") 9 | args = parser.parse_args() 10 | 11 | envs = ENV_REGISTRY 12 | if args.env not in envs: 13 | raise ValueError( 14 | f"Unknown environment {args.env}\nValid environments are {list(envs.keys())}" 15 | ) 16 | else: 17 | env_cls = envs[args.env] 18 | 19 | print(env_cls) 20 | env = env_cls() 21 | CliRunner(env).run() 22 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Entity Gym's documentation! 2 | ====================================== 3 | 4 | **Entity Gym** is an open source Python library that defines an entity based API for reinforcement learning environments. 5 | Entity Gym extends the standard paradigm of fixed-size observation spaces by allowing observations to contain dynamically-sized lists of entities. 6 | This enables a seamless and highly efficient interface with simulators, games, and other complex environments whose state can be naturally expressed as a collection of entities. 7 | 8 | Contents 9 | -------- 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | 14 | tutorials 15 | entity_gym/entity_gym 16 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | apt_packages: 18 | - graphviz 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: docs/source/conf.py 23 | 24 | python: 25 | # Install our python package before building the docs 26 | install: 27 | - method: pip 28 | path: . 29 | - requirements: docs/requirements.txt 30 | -------------------------------------------------------------------------------- /entity_gym/tests/test_validator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ragged_buffer import RaggedBufferI64 3 | 4 | from entity_gym.env.env_list import EnvList 5 | from entity_gym.env.validator import ValidatingEnv 6 | from entity_gym.examples.minesweeper import MineSweeper 7 | 8 | 9 | def test_env_list() -> None: 10 | # 100 environments 11 | envs = EnvList(lambda: ValidatingEnv(MineSweeper()), 100) 12 | obs_space = envs.obs_space() 13 | 14 | obs_reset = envs.reset(obs_space) 15 | assert len(obs_reset.done) == 100 16 | 17 | actions = { 18 | "Move": RaggedBufferI64.from_array(np.zeros((100, 2, 1), np.int64)), 19 | "Fire Orbital Cannon": RaggedBufferI64.from_array( 20 | np.zeros((100, 0, 1), np.int64) 21 | ), 22 | } 23 | obs_act = envs.act(actions, obs_space) 24 | 25 | assert len(obs_act.done) == 100 26 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "entity-gym" 3 | version = "0.1.10" 4 | description = "Entity Gym" 5 | authors = ["Clemens Winter "] 6 | license = "MIT" 7 | readme = "README.md" 8 | classifiers = [ 9 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 10 | "License :: OSI Approved :: Apache Software License", 11 | "License :: OSI Approved :: MIT License", 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = ">=3.7.1,<3.11" 16 | ragged-buffer = "^0.4.3" 17 | msgpack = "^1.0.3" 18 | msgpack-numpy = "^0.4.7" 19 | cloudpickle = "^2.0.0" 20 | tqdm = "^4.63.1" 21 | click = "^8.1.3" 22 | 23 | [tool.poetry.dev-dependencies] 24 | sphinx-rtd-theme = "^1.0.0" 25 | Sphinx = "^4.5.0" 26 | autoapi = "^2.0.1" 27 | pytest = "^7.1.2" 28 | pre-commit = "^2.19.0" 29 | mypy = "^0.950" 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0"] 33 | build-backend = "poetry.core.masonry.api" 34 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Entity Neural Network developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /entity_gym/dataclass_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import is_dataclass 2 | from typing import Any, Dict, List, Type 3 | 4 | import numpy as np 5 | 6 | from .env import Entity, ObsSpace 7 | 8 | 9 | def obs_space_from_dataclasses(*dss: Type) -> ObsSpace: 10 | entities = {} 11 | for ds in dss: 12 | if not is_dataclass(ds): 13 | raise ValueError(f"{ds} is not a dataclass") 14 | # TODO: check field types are valid 15 | entities[ds.__name__] = Entity( 16 | features=list( 17 | key for key in ds.__dataclass_fields__.keys() if not key.startswith("_") 18 | ), 19 | ) 20 | return ObsSpace(entities=entities) 21 | 22 | 23 | def extract_features( 24 | entities: Dict[str, List[Any]], obs_filter: ObsSpace 25 | ) -> Dict[str, np.ndarray]: 26 | selectors = {} 27 | for entity_name, entity in obs_filter.entities.items(): 28 | selectors[entity_name] = np.array( 29 | [[getattr(e, f) for f in entity.features] for e in entities[entity_name]], 30 | dtype=np.float32, 31 | ).reshape(len(entities[entity_name]), len(entity.features)) 32 | return selectors 33 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | disallow_untyped_defs = True 3 | disallow_any_unimported = True 4 | no_implicit_optional = True 5 | check_untyped_defs = True 6 | warn_return_any = True 7 | show_error_codes = True 8 | warn_unused_ignores = True 9 | 10 | exclude = wandb 11 | 12 | [mypy-torch_scatter.*] 13 | ignore_missing_imports = True 14 | [mypy-wandb.*] 15 | ignore_missing_imports = True 16 | [mypy-msgpack.*] 17 | ignore_missing_imports = True 18 | [mypy-msgpack_numpy.*] 19 | ignore_missing_imports = True 20 | [mypy-tqdm.*] 21 | ignore_missing_imports = True 22 | [mypy-griddly.*] 23 | ignore_missing_imports = True 24 | [mypy-cloudpickle.*] 25 | ignore_missing_imports = True 26 | [mypy-optuna.*] 27 | ignore_missing_imports = True 28 | [mypy-orjson.*] 29 | ignore_missing_imports = True 30 | [mypy-gym_microrts.*] 31 | ignore_missing_imports = True 32 | [mypy-jpype.*] 33 | ignore_missing_imports = True 34 | [mypy-rts.*] 35 | ignore_missing_imports = True 36 | [mypy-ai.*] 37 | ignore_missing_imports = True 38 | [mypy-ts.*] 39 | ignore_missing_imports = True 40 | [mypy-PIL.*] 41 | ignore_missing_imports = True 42 | [mypy-web_pdb.*] 43 | ignore_missing_imports = True 44 | [mypy-procgen.*] 45 | ignore_missing_imports = True 46 | [mypy-opensimplex.*] 47 | ignore_missing_imports = True 48 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | rev: v2.31.1 4 | hooks: 5 | - id: pyupgrade 6 | args: 7 | - --py38-plus 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: 13 | - --profile=black 14 | - --skip-glob=wandb/**/* 15 | - --thirdparty=wandb 16 | - repo: https://github.com/myint/autoflake 17 | rev: v1.4 18 | hooks: 19 | - id: autoflake 20 | args: 21 | - -r 22 | - --exclude=wandb 23 | - --in-place 24 | - --remove-unused-variables 25 | - --remove-all-unused-imports 26 | - repo: https://github.com/python/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | args: 31 | - --exclude=wandb 32 | - repo: 
https://github.com/codespell-project/codespell 33 | rev: v2.1.0 34 | hooks: 35 | - id: codespell 36 | args: 37 | - --ignore-words-list=nd,reacher,thist,ths,magent,crate 38 | - --skip=docs/css/termynal.css,docs/js/termynal.js 39 | - repo: local 40 | hooks: 41 | - id: mypy 42 | name: mypy 43 | entry: mypy 44 | language: system 45 | types: [python] 46 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.abspath("../..")) 6 | 7 | # -- Project information 8 | 9 | project = "Entity Gym" 10 | copyright = "2021-2022, Clemens Winter" 11 | author = "Clemens Winter" 12 | 13 | # -- General configuration 14 | 15 | extensions = [ 16 | "sphinx.ext.duration", 17 | "sphinx.ext.doctest", 18 | "sphinx.ext.autodoc", 19 | "sphinx.ext.autosummary", 20 | "sphinx.ext.intersphinx", 21 | "sphinx.ext.inheritance_diagram", 22 | "autoapi.sphinx", 23 | ] 24 | # extensions = ["sphinx.ext.autodoc", "sphinx.ext.inheritance_diagram", "autoapi.sphinx"] 25 | 26 | intersphinx_mapping = { 27 | "python": ("https://docs.python.org/3/", None), 28 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None), 29 | } 30 | intersphinx_disabled_domains = ["std"] 31 | 32 | autoapi_modules = { 33 | "entity_gym": { 34 | "prune": True, 35 | } 36 | } 37 | 38 | templates_path = ["_templates"] 39 | 40 | # -- Options for HTML output 41 | 42 | html_theme = "sphinx_rtd_theme" 43 | 44 | # -- Options for EPUB output 45 | epub_show_urls = "footnote" 46 | 47 | autosummary_generate = True 48 | 49 | autosummary_ignore_module_all = False 50 | autosummary_imported_members = False 51 | -------------------------------------------------------------------------------- /entity_gym/examples/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type 2 | 3 | from entity_gym.env import Environment 4 | from entity_gym.examples.cherry_pick import CherryPick 5 | from entity_gym.examples.count import Count 6 | from entity_gym.examples.floor_is_lava import FloorIsLava 7 | from entity_gym.examples.minefield import Minefield 8 | from entity_gym.examples.minesweeper import MineSweeper 9 | from entity_gym.examples.move_to_origin import MoveToOrigin 10 | from entity_gym.examples.multi_armed_bandit import MultiArmedBandit 11 | from entity_gym.examples.multi_snake import MultiSnake 12 | from entity_gym.examples.not_hotdog import NotHotdog 13 | from entity_gym.examples.pick_matching_balls import PickMatchingBalls 14 | from entity_gym.examples.rock_paper_scissors import RockPaperScissors 15 | from entity_gym.examples.tutorial import TreasureHunt 16 | from entity_gym.examples.xor import Xor 17 | 18 | ENV_REGISTRY: Dict[str, Type[Environment]] = { 19 | "MoveToOrigin": MoveToOrigin, 20 | "CherryPick": CherryPick, 21 | "PickMatchingBalls": PickMatchingBalls, 22 | "Minefield": Minefield, 23 | "MultiSnake": MultiSnake, 24 | "MultiArmedBandit": MultiArmedBandit, 25 | "NotHotdog": NotHotdog, 26 | "Xor": Xor, 27 | "Count": Count, 28 | "FloorIsLava": FloorIsLava, 29 | "MineSweeper": MineSweeper, 30 | "RockPaperScissors": RockPaperScissors, 31 | "TreasureHunt": TreasureHunt, 32 | } 33 | 34 | __all__ = [ 35 | "MoveToOrigin", 36 | "CherryPick", 37 | "PickMatchingBalls", 38 | "Minefield", 39 | "MultiSnake", 40 | "MultiArmedBandit", 41 | "NotHotdog", 42 | "Xor", 43 | 
"Count", 44 | "FloorIsLava", 45 | "MineSweeper", 46 | "RockPaperScissors", 47 | "TreasureHunt", 48 | ] 49 | -------------------------------------------------------------------------------- /entity_gym/ragged_dict.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, Generic, Mapping, Type, TypeVar 3 | 4 | import numpy as np 5 | import numpy.typing as npt 6 | from ragged_buffer import RaggedBuffer 7 | 8 | from .env.vec_env import VecActionMask 9 | 10 | ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True) 11 | 12 | 13 | @dataclass 14 | class RaggedBatchDict(Generic[ScalarType]): 15 | rb_cls: Type[RaggedBuffer[ScalarType]] 16 | buffers: Dict[str, RaggedBuffer[ScalarType]] = field(default_factory=dict) 17 | 18 | def extend(self, batch: Mapping[str, RaggedBuffer[ScalarType]]) -> None: 19 | for k, v in batch.items(): 20 | if k not in self.buffers: 21 | self.buffers[k] = v 22 | else: 23 | self.buffers[k].extend(v) 24 | 25 | def clear(self) -> None: 26 | for buffer in self.buffers.values(): 27 | buffer.clear() 28 | 29 | def __getitem__( 30 | self, index: npt.NDArray[np.int64] 31 | ) -> Dict[str, RaggedBuffer[ScalarType]]: 32 | return {k: v[index] for k, v in self.buffers.items()} 33 | 34 | 35 | @dataclass 36 | class RaggedActionDict: 37 | buffers: Dict[str, VecActionMask] = field(default_factory=dict) 38 | 39 | def extend(self, batch: Mapping[str, VecActionMask]) -> None: 40 | for k, v in batch.items(): 41 | if k not in self.buffers: 42 | self.buffers[k] = v 43 | else: 44 | self.buffers[k].extend(v) 45 | 46 | def clear(self) -> None: 47 | for buffer in self.buffers.values(): 48 | buffer.clear() 49 | 50 | def __getitem__(self, index: npt.NDArray[np.int64]) -> Dict[str, VecActionMask]: 51 | return {k: v[index] for k, v in self.buffers.items()} 52 | -------------------------------------------------------------------------------- /entity_gym/examples/multi_armed_bandit.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Mapping 3 | 4 | from entity_gym.env import Action, ActionSpace, Environment, Observation, ObsSpace 5 | from entity_gym.env.environment import ( 6 | GlobalCategoricalAction, 7 | GlobalCategoricalActionMask, 8 | GlobalCategoricalActionSpace, 9 | ) 10 | 11 | 12 | @dataclass 13 | class MultiArmedBandit(Environment): 14 | """ 15 | Task with single cateorical action with 5 choices which gives a reward of 1 for choosing action 0 and reward of 0 otherwise. 
16 | """ 17 | 18 | def obs_space(cls) -> ObsSpace: 19 | return ObsSpace(global_features=["step"]) 20 | 21 | def action_space(cls) -> Dict[str, ActionSpace]: 22 | return { 23 | "pull": GlobalCategoricalActionSpace(["A", "B", "C", "D", "E"]), 24 | } 25 | 26 | def reset(self) -> Observation: 27 | self.step = 0 28 | self._total_reward = 0.0 29 | return self.observe() 30 | 31 | def act(self, actions: Mapping[str, Action]) -> Observation: 32 | self.step += 1 33 | 34 | a = actions["pull"] 35 | assert isinstance( 36 | a, GlobalCategoricalAction 37 | ), f"{a} is not a GlobalCategoricalAction" 38 | if a.label == "A": 39 | reward = 1 / 32.0 40 | else: 41 | reward = 0 42 | done = self.step >= 32 43 | self._total_reward += reward 44 | return self.observe(done, reward) 45 | 46 | def observe(self, done: bool = False, reward: float = 0) -> Observation: 47 | return Observation( 48 | global_features=[self.step], 49 | actions={ 50 | "pull": GlobalCategoricalActionMask(), 51 | }, 52 | reward=reward, 53 | done=done, 54 | ) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Entity Gym 2 | 3 | [![Actions Status](https://github.com/entity-neural-network/entity-gym/workflows/Checks/badge.svg)](https://github.com/entity-neural-network/entity-gym/actions) 4 | [![PyPI](https://img.shields.io/pypi/v/entity-gym.svg?style=flat-square)](https://pypi.org/project/entity-gym/) 5 | [![Documentation Status](https://readthedocs.org/projects/entity-gym/badge/?version=latest&style=flat-square)](https://entity-gym.readthedocs.io/en/latest/?badge=latest) 6 | [![Discord](https://img.shields.io/discord/913497968701747270?style=flat-square)](https://discord.gg/SjVqhSW4Qf) 7 | 8 | 9 | Entity Gym is an open source Python library that defines an _entity based_ API for reinforcement learning environments. 10 | Entity Gym extends the standard paradigm of fixed-size observation spaces by allowing observations to contain dynamically-sized lists of entities. 11 | This enables a seamless and highly efficient interface with simulators, games, and other complex environments whose state can be naturally expressed as a collection of entities. 12 | 13 | The [enn-trainer library](https://github.com/entity-neural-network/enn-trainer) can be used to train agents for Entity Gym environments. 14 | 15 | ## Installation 16 | 17 | ``` 18 | pip install entity-gym 19 | ``` 20 | 21 | ## Usage 22 | 23 | You can find tutorials, guides, and an API reference on the [Entity Gym documentation website](https://entity-gym.readthedocs.io/en/latest/index.html). 24 | 25 | ## Examples 26 | 27 | A number of simple example environments can be found in [entity_gym/examples](https://github.com/entity-neural-network/entity-gym/tree/main/entity_gym/examples). More complex examples can be found in the [ENN-Zoo](https://github.com/entity-neural-network/incubator/tree/main/enn_zoo/enn_zoo) project, which contains Entity Gym bindings for [Procgen](https://github.com/openai/procgen), [Griddly](https://github.com/Bam4d/Griddly), [MicroRTS](https://github.com/santiontanon/microrts), [VizDoom](https://github.com/mwydmuch/ViZDoom), and [CodeCraft](https://github.com/cswinter/DeepCodeCraft). 
28 | -------------------------------------------------------------------------------- /entity_gym/simple_trace.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from typing import DefaultDict, Dict, Generator, List 5 | 6 | 7 | class Tracer: 8 | def __init__(self, cuda: bool = True) -> None: 9 | if cuda: 10 | try: 11 | import torch 12 | 13 | cuda = torch.cuda.is_available() 14 | except ImportError: 15 | cuda = False 16 | 17 | self.start_time: List[float] = [] 18 | self.callstack: List[str] = [] 19 | self.total_time: DefaultDict[str, float] = defaultdict(float) 20 | self.cuda = cuda 21 | 22 | def start(self, name: str) -> None: 23 | self.callstack.append(name) 24 | self.start_time.append(time.time()) 25 | 26 | def end(self, name: str) -> None: 27 | if self.cuda: 28 | import torch 29 | 30 | torch.cuda.synchronize() 31 | self.total_time[self.stack] += time.time() - self.start_time.pop() 32 | actual_name = self.callstack.pop() 33 | assert ( 34 | actual_name == name 35 | ), f"Expected to complete {name}, but currently active span is {actual_name}" 36 | 37 | def finish(self) -> Dict[str, float]: 38 | assert ( 39 | len(self.callstack) == 0 40 | ), f"Cannot finish when there are open traces: {self.stack}" 41 | self_times: Dict[str, float] = {} 42 | # Traverse the tree depth-first 43 | for name in reversed(sorted(self.total_time.keys())): 44 | time_in_children = sum( 45 | t for child, t in self_times.items() if child.startswith(name) 46 | ) 47 | self_times[f"{name}[self]"] = self.total_time[name] - time_in_children 48 | 49 | self_times.update(self.total_time) 50 | self.total_time = defaultdict(float) 51 | return self_times 52 | 53 | @contextmanager 54 | def span(self, name: str) -> Generator[None, None, None]: 55 | self.start(name) 56 | yield 57 | self.end(name) 58 | 59 | @property 60 | def stack(self) -> str: 61 | return ".".join(self.callstack) 62 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | defaults: 13 | run: 14 | working-directory: entity_gym 15 | strategy: 16 | matrix: 17 | python-version: [3.8] 18 | fail-fast: false 19 | 20 | environment: PyPI 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | #---------------------------------------------- 28 | # ----- install & configure poetry ----- 29 | #---------------------------------------------- 30 | - name: Install Poetry 31 | uses: snok/install-poetry@v1 32 | with: 33 | virtualenvs-create: true 34 | virtualenvs-in-project: true 35 | installer-parallel: true 36 | #---------------------------------------------- 37 | # load cached venv if cache exists 38 | #---------------------------------------------- 39 | - name: Load cached venv 40 | id: cached-poetry-dependencies 41 | uses: actions/cache@v2 42 | with: 43 | path: .venv 44 | key: venv-${{ runner.os }}-${{ hashFiles('poetry.lock') }} 45 | #---------------------------------------------- 46 | # install dependencies if cache does not exist 47 | #---------------------------------------------- 48 | - name: Install 
dependencies
49 |         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
50 |         continue-on-error: true
51 |         run: poetry install --no-interaction --no-root
52 |       #----------------------------------------------
53 |       # install your root project, if required
54 |       #----------------------------------------------
55 |       - name: Install library
56 |         run: poetry install --no-interaction
57 | 
58 |       - name: Publish entity-gym to PyPI
59 |         run: |
60 |           poetry build
61 |           poetry publish --username __token__ --password ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/entity_gym/examples/xor.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import Dict, Mapping
4 | 
5 | from entity_gym.env import Action, ActionSpace, Environment, Observation, ObsSpace
6 | from entity_gym.env.environment import (
7 |     Entity,
8 |     GlobalCategoricalAction,
9 |     GlobalCategoricalActionMask,
10 |     GlobalCategoricalActionSpace,
11 | )
12 | 
13 | 
14 | @dataclass
15 | class Input:
16 |     is_set: float
17 | 
18 | 
19 | class Xor(Environment):
20 |     """
21 |     There are two "Input" entities, one for each bit, which are randomly set to 0 or 1 on each timestep.
22 |     A global "negate" feature is also randomly set to 0 or 1.
23 |     The agent has a single global "output" action that should be set to the XOR of the two bits, inverted whenever "negate" is 1.
24 |     """
25 | 
26 |     def obs_space(self) -> ObsSpace:
27 |         return ObsSpace(
28 |             global_features=["negate"],
29 |             entities={"Input": Entity(["is_set"])},
30 |         )
31 | 
32 |     def action_space(self) -> Dict[str, ActionSpace]:
33 |         return {"output": GlobalCategoricalActionSpace(["0", "1"])}
34 | 
35 |     def reset_filter(self, obs_space: ObsSpace) -> Observation:
36 |         self.bit1 = random.choice([0.0, 1.0])
37 |         self.bit2 = random.choice([0.0, 1.0])
38 |         self.negate = random.choice([0.0, 1.0])
39 |         return self.observe(obs_space)
40 | 
41 |     def reset(self) -> Observation:
42 |         return self.reset_filter(self.obs_space())
43 | 
44 |     def act_filter(
45 |         self, action: Mapping[str, Action], obs_filter: ObsSpace
46 |     ) -> Observation:
47 |         reward = 0.0
48 |         a = action["output"]
49 |         assert isinstance(a, GlobalCategoricalAction)
50 |         if a.index == self.negate and self.bit1 == self.bit2:
51 |             reward = 1.0
52 |         elif a.index == 1.0 - self.negate and self.bit1 != self.bit2:
53 |             reward = 1.0
54 | 
55 |         return self.observe(obs_filter, done=True, reward=reward)
56 | 
57 |     def act(self, actions: Mapping[str, Action]) -> Observation:
58 |         return self.act_filter(
59 |             actions,
60 |             self.obs_space(),
61 |         )
62 | 
63 |     def observe(
64 |         self, obs_filter: ObsSpace, done: bool = False, reward: float = 0.0
65 |     ) -> Observation:
66 |         return Observation(
67 |             reward=reward,
68 |             done=done,
69 |             features={"Input": [[self.bit1], [self.bit2]]},
70 |             global_features=[self.negate],
71 |             actions={
72 |                 "output": GlobalCategoricalActionMask(),
73 |             },
74 |         )
75 | 
--------------------------------------------------------------------------------
/entity_gym/env/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The environment module defines the core interfaces that make up an `Environment <#entity_gym.env.Environment>`_.
3 | 
4 | Actions
5 | -------
6 | 
7 | Actions are how agents interact with the environment.
8 | There are three parts to every action:
9 | 
10 | * :code:`ActionSpace` defines the shape of the action. For example, a categorical action space consisting of the 4 discrete choices "up", "down", "left", and "right".
11 | * :code:`ActionMask` is used to further constrain the available actions on a specific timestep. For example, only "up" and "down" may be available on some timestep.
12 | * :code:`Action` represents the actual action that is chosen by an agent. For example, the "down" action may have been chosen.
13 | 
14 | There are currently three different action spaces:
15 | 
16 | * `GlobalCategoricalActionSpace <#entity_gym.env.GlobalCategoricalActionSpace>`_ allows the agent to choose a single option from a discrete set of actions.
17 | * `CategoricalActionSpace <#entity_gym.env.CategoricalActionSpace>`_ allows multiple entities to choose a single option from a discrete set of actions.
18 | * `SelectEntityActionSpace <#entity_gym.env.SelectEntityActionSpace>`_ allows multiple entities to choose another entity.
19 | 
20 | Observations
21 | ------------
22 | 
23 | Observations are how agents receive information from the environment.
24 | Each `Environment <#entity_gym.env.Environment>`_ must define an `ObsSpace <#entity_gym.env.ObsSpace>`_, which specifies the shape of the observations returned by this environment.
25 | On each timestep, the environment returns an `Observation <#entity_gym.env.Observation>`_ object, which contains all the entities and features that are visible to the agent.
26 | """
27 | from .action import *
28 | from .add_metrics_wrapper import AddMetricsWrapper
29 | from .env_list import *
30 | from .environment import *
31 | from .parallel_env_list import *
32 | from .validator import ValidatingEnv
33 | from .vec_env import *
34 | 
35 | __all__ = [
36 |     "Environment",
37 |     # Observation
38 |     "ObsSpace",
39 |     "Entity",
40 |     "Observation",
41 |     "EntityName",
42 |     "ActionName",
43 |     "EntityID",
44 |     # Action
45 |     "Action",
46 |     "ActionSpace",
47 |     "ActionMask",
48 |     "CategoricalActionSpace",
49 |     "CategoricalAction",
50 |     "CategoricalActionMask",
51 |     "GlobalCategoricalActionSpace",
52 |     "GlobalCategoricalAction",
53 |     "GlobalCategoricalActionMask",
54 |     "SelectEntityAction",
55 |     "SelectEntityActionSpace",
56 |     "SelectEntityActionMask",
57 |     # VecEnv
58 |     "VecEnv",
59 |     "EnvList",
60 |     "ParallelEnvList",
61 |     "VecActionMask",
62 |     "VecObs",
63 |     "VecCategoricalActionMask",
64 |     "VecSelectEntityActionMask",
65 |     # Wrappers
66 |     "ValidatingEnv",
67 |     "AddMetricsWrapper",
68 | ]
69 | 
--------------------------------------------------------------------------------
/entity_gym/examples/tutorial.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Dict, Mapping, Tuple
3 | 
4 | from entity_gym.env import *
5 | from entity_gym.runner import CliRunner
6 | 
7 | 
8 | class TreasureHunt(Environment):
9 |     def reset(self) -> Observation:
10 |         self.x_pos = 0
11 |         self.y_pos = 0
12 |         self.game_over = False
13 |         self.traps = []
14 |         self.treasure = []
15 |         for _ in range(5):
16 |             self.traps.append(self._random_empty_pos())
17 |         for _ in range(5):
18 |             self.treasure.append(self._random_empty_pos())
19 |         return self.observe()
20 | 
21 |     def obs_space(self) -> ObsSpace:
22 |         return ObsSpace(
23 |             global_features=["x_pos", "y_pos"],
24 |             entities={
25 |                 "Trap": Entity(features=["x_pos", "y_pos"]),
26 |                 "Treasure": Entity(features=["x_pos", "y_pos"]),
27 |             },
28 |         )
29 | 
30 |     def action_space(self) -> Dict[str, ActionSpace]:
31 |         # The `GlobalCategoricalActionSpace` allows the agent to choose from a set of discrete actions.
32 |         return {
33 |             "move": GlobalCategoricalActionSpace(
34 |                 index_to_label=["up", "down", "left", "right"]
35 |             )
36 |         }
37 | 
38 |     def _random_empty_pos(self) -> Tuple[int, int]:
39 |         # Generate a random position on the grid that is not occupied by a trap, treasure, or player.
40 |         while True:
41 |             x = random.randint(-5, 5)
42 |             y = random.randint(-5, 5)
43 |             if (x, y) not in (self.traps + self.treasure + [(self.x_pos, self.y_pos)]):
44 |                 return x, y
45 | 
46 |     def act(self, actions: Mapping[ActionName, Action]) -> Observation:
47 |         action = actions["move"]
48 |         assert isinstance(action, GlobalCategoricalAction)
49 |         if action.label == "up" and self.y_pos < 10:
50 |             self.y_pos += 1
51 |         elif action.label == "down" and self.y_pos > -10:
52 |             self.y_pos -= 1
53 |         elif action.label == "left" and self.x_pos > -10:
54 |             self.x_pos -= 1
55 |         elif action.label == "right" and self.x_pos < 10:
56 |             self.x_pos += 1
57 | 
58 |         reward = 0.0
59 |         if (self.x_pos, self.y_pos) in self.treasure:
60 |             reward = 1.0
61 |             self.treasure.remove((self.x_pos, self.y_pos))
62 |         if (self.x_pos, self.y_pos) in self.traps or len(self.treasure) == 0:
63 |             self.game_over = True
64 | 
65 |         return self.observe(reward)
66 | 
67 |     def observe(self, reward: float = 0.0) -> Observation:
68 |         return Observation(
69 |             global_features=[self.x_pos, self.y_pos],
70 |             features={
71 |                 "Trap": self.traps,
72 |                 "Treasure": self.treasure,
73 |             },
74 |             done=self.game_over,
75 |             reward=reward,
76 |             actions={"move": GlobalCategoricalActionMask()},
77 |         )
78 | 
79 | 
80 | if __name__ == "__main__":
81 |     env = TreasureHunt()
82 |     CliRunner(env).run()
83 | 
--------------------------------------------------------------------------------
/entity_gym/examples/not_hotdog.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import Dict, Mapping
4 | 
5 | import numpy as np
6 | 
7 | from entity_gym.env import (
8 |     Action,
9 |     ActionSpace,
10 |     CategoricalAction,
11 |     CategoricalActionMask,
12 |     CategoricalActionSpace,
13 |     Entity,
14 |     Environment,
15 |     Observation,
16 |     ObsSpace,
17 | )
18 | 
19 | 
20 | @dataclass
21 | class NotHotdog(Environment):
22 |     """
23 |     On each timestep, there is either a generic "Object" entity with an `is_hotdog` property, or a "Hotdog" object.
24 |     The "Player" entity is always present, and has an action to classify the other entity as hotdog or not hotdog.
25 | """ 26 | 27 | def obs_space(self) -> ObsSpace: 28 | return ObsSpace( 29 | entities={ 30 | "Player": Entity(["step"]), 31 | "Object": Entity(["is_hotdog"]), 32 | "Hotdog": Entity([]), 33 | } 34 | ) 35 | 36 | def action_space(self) -> Dict[str, ActionSpace]: 37 | return { 38 | "classify": CategoricalActionSpace(["hotdog", "not_hotdog"]), 39 | "unused_action": CategoricalActionSpace(["0", "1"]), 40 | } 41 | 42 | def reset(self) -> Observation: 43 | self.step = 0 44 | self.is_hotdog = random.randint(0, 1) 45 | self.hotdog_object = random.randint(0, 1) == 1 46 | return self.observe() 47 | 48 | def act(self, actions: Mapping[str, Action]) -> Observation: 49 | self.step += 1 50 | 51 | a = actions["classify"] 52 | assert isinstance(a, CategoricalAction), f"{a} is not a CategoricalAction" 53 | if a.indices[0] == self.is_hotdog: 54 | reward = 1 55 | else: 56 | reward = 0 57 | done = True 58 | return self.observe(done, reward) 59 | 60 | def observe(self, done: bool = False, reward: float = 0) -> Observation: 61 | return Observation( 62 | features={ 63 | "Player": np.array( 64 | [ 65 | [ 66 | self.step, 67 | ] 68 | ], 69 | dtype=np.float32, 70 | ), 71 | "Object": np.array( 72 | [ 73 | [ 74 | self.is_hotdog, 75 | ] 76 | ], 77 | dtype=np.float32, 78 | ) 79 | if (self.hotdog_object and self.is_hotdog == 0) 80 | or not self.hotdog_object 81 | else np.zeros((0, 1), dtype=np.float32).reshape(0, 1), 82 | "Hotdog": np.zeros((1, 0), dtype=np.float32) 83 | if self.hotdog_object and self.is_hotdog == 1 84 | else np.zeros((0, 0), dtype=np.float32), 85 | }, 86 | actions={ 87 | "classify": CategoricalActionMask(actor_ids=[0]), 88 | "unused_action": CategoricalActionMask(actor_ids=[]), 89 | }, 90 | ids={"Player": [0]}, 91 | reward=reward, 92 | done=done, 93 | ) 94 | -------------------------------------------------------------------------------- /entity_gym/examples/cherry_pick.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, List, Mapping 3 | 4 | import numpy as np 5 | 6 | from entity_gym.env import ( 7 | Action, 8 | ActionSpace, 9 | Entity, 10 | EntityID, 11 | Environment, 12 | Observation, 13 | ObsSpace, 14 | SelectEntityAction, 15 | SelectEntityActionMask, 16 | SelectEntityActionSpace, 17 | ) 18 | from entity_gym.env.environment import EntityName 19 | 20 | 21 | @dataclass 22 | class CherryPick(Environment): 23 | """ 24 | The CherryPick environment is initialized with a list of 32 cherries of random quality. 25 | On each timestep, the player can pick up one of the cherries. 26 | The player receives a reward of the quality of the cherry picked. 27 | The environment ends after 16 steps. 28 | The quality of the top 16 cherries is normalized so that the maximum total achievable reward is 1.0. 
29 | """ 30 | 31 | num_cherries: int = 32 32 | cherries: List[float] = field(default_factory=list) 33 | last_reward: float = 0.0 34 | step: int = 0 35 | 36 | def obs_space(self) -> ObsSpace: 37 | return ObsSpace( 38 | entities={ 39 | "Cherry": Entity(["quality"]), 40 | "Player": Entity([]), 41 | } 42 | ) 43 | 44 | def action_space(self) -> Dict[str, ActionSpace]: 45 | return {"Pick Cherry": SelectEntityActionSpace()} 46 | 47 | def reset(self) -> Observation: 48 | cherries = [np.random.normal() for _ in range(self.num_cherries)] 49 | # Normalize so that the sum of the top half is 1.0 50 | top_half = sorted(cherries, reverse=True)[: self.num_cherries // 2] 51 | sum_top_half = sum(top_half) 52 | add = 2.0 * (1.0 - sum_top_half) / self.num_cherries 53 | self.cherries = [c + add for c in cherries] 54 | self.last_reward = 0.0 55 | self.step = 0 56 | self.total_reward = 0.0 57 | return self.observe() 58 | 59 | def observe(self) -> Observation: 60 | done = self.step == self.num_cherries // 2 61 | ids: Dict[EntityName, List[EntityID]] = { 62 | "Cherry": [("Cherry", a) for a in range(len(self.cherries))], 63 | "Player": ["Player"], 64 | } 65 | return Observation( 66 | features={ 67 | "Cherry": np.array(self.cherries, dtype=np.float32).reshape(-1, 1), 68 | "Player": np.zeros([1, 0], dtype=np.float32), 69 | }, 70 | ids=ids, 71 | actions={ 72 | "Pick Cherry": SelectEntityActionMask( 73 | actor_ids=["Player"], 74 | actee_ids=ids["Cherry"], 75 | ), 76 | }, 77 | reward=self.last_reward, 78 | done=done, 79 | ) 80 | 81 | def act(self, actions: Mapping[str, Action]) -> Observation: 82 | assert len(actions) == 1, actions 83 | a = actions["Pick Cherry"] 84 | assert isinstance(a, SelectEntityAction) 85 | _, chosen_cherry_idx = a.actees[0] 86 | self.last_reward = self.cherries.pop(chosen_cherry_idx) 87 | self.total_reward += self.last_reward 88 | self.step += 1 89 | return self.observe() 90 | -------------------------------------------------------------------------------- /entity_gym/serialization/msgpack_ragged.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any 3 | 4 | import msgpack_numpy 5 | from ragged_buffer import RaggedBufferBool, RaggedBufferF32, RaggedBufferI64 6 | 7 | from entity_gym.env.environment import * 8 | from entity_gym.env.vec_env import * 9 | 10 | # For security reasons we don't want to deserialize classes that are not in this list. 
11 | WHITELIST = { 12 | "ObsSpace": ObsSpace, 13 | "VecObs": VecObs, 14 | "VecCategoricalActionMask": VecCategoricalActionMask, 15 | "VecSelectEntityActionMask": VecSelectEntityActionMask, 16 | "SelectEntityAction": SelectEntityAction, 17 | "CategoricalAction": CategoricalAction, 18 | "GlobalCategoricalAction": GlobalCategoricalAction, 19 | "SelectEntityActionSpace": SelectEntityActionSpace, 20 | "CategoricalActionSpace": CategoricalActionSpace, 21 | "GlobalCategoricalActionSpace": GlobalCategoricalActionSpace, 22 | "Entity": Entity, 23 | "Metric": Metric, 24 | } 25 | 26 | 27 | def ragged_buffer_encode(obj: Any) -> Any: 28 | # Suppress "msgpack_numpy.py:96: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison" 29 | with warnings.catch_warnings(): 30 | warnings.simplefilter(action="ignore", category=FutureWarning) 31 | 32 | if isinstance(obj, RaggedBufferF32) or isinstance(obj, RaggedBufferI64) or isinstance(obj, RaggedBufferBool): # type: ignore 33 | flattened = obj.as_array() 34 | lengths = obj.size1() 35 | return { 36 | "__flattened__": msgpack_numpy.encode(flattened), 37 | "__lengths__": msgpack_numpy.encode(lengths), 38 | } 39 | elif hasattr(obj, "__dict__"): 40 | return {"__classname__": obj.__class__.__name__, "data": vars(obj)} 41 | else: 42 | return obj 43 | 44 | 45 | def ragged_buffer_decode(obj: Any) -> Any: 46 | # Suppress "msgpack_numpy.py:96: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison" 47 | with warnings.catch_warnings(): 48 | warnings.simplefilter(action="ignore", category=FutureWarning) 49 | 50 | if "__flattened__" in obj: 51 | flattened = msgpack_numpy.decode(obj["__flattened__"]) 52 | lengths = msgpack_numpy.decode(obj["__lengths__"]) 53 | 54 | dtype = flattened.dtype 55 | 56 | if dtype == np.float32: 57 | return RaggedBufferF32.from_flattened(flattened, lengths) 58 | elif dtype == int or dtype == np.int64: 59 | return RaggedBufferI64.from_flattened(flattened, lengths) 60 | elif dtype == bool or dtype == np.bool8: 61 | return RaggedBufferBool.from_flattened(flattened, lengths) 62 | else: 63 | raise ValueError(f"Unsupported RaggedBuffer dtype: {dtype}") 64 | elif "__classname__" in obj: 65 | classname = obj["__classname__"] 66 | if classname in WHITELIST: 67 | cls_name = globals()[classname] 68 | return cls_name(**obj["data"]) 69 | else: 70 | raise RuntimeError( 71 | f"Attempt to deserialize class {classname} outside whitelist." 72 | ) 73 | else: 74 | return obj 75 | -------------------------------------------------------------------------------- /entity_gym/examples/count.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Dict, Mapping, Optional 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | 8 | from entity_gym.dataclass_utils import extract_features, obs_space_from_dataclasses 9 | from entity_gym.env import ( 10 | Action, 11 | ActionSpace, 12 | CategoricalAction, 13 | CategoricalActionMask, 14 | CategoricalActionSpace, 15 | Environment, 16 | Observation, 17 | ObsSpace, 18 | ) 19 | 20 | 21 | @dataclass 22 | class Player: 23 | pass 24 | 25 | 26 | @dataclass 27 | class Bean: 28 | pass 29 | 30 | 31 | class Count(Environment): 32 | """ 33 | There are between 0 and 10 "Bean" entities. 34 | The "Player" entity gets 1 reward for counting the correct number of beans and 0 otherwise. 
35 | 36 | This environment also randomly masks off some of the incorrect answers. 37 | 38 | Masking by default allows all actions, which is equivalent to disabling masking. 39 | """ 40 | 41 | def __init__(self, masked_choices: int = 10): 42 | assert ( 43 | masked_choices >= 1 and masked_choices <= 10 44 | ), "masked_choices must be between 1 and 10" 45 | self.masked_choices = masked_choices 46 | 47 | def obs_space(self) -> ObsSpace: 48 | return obs_space_from_dataclasses(Player, Bean) 49 | 50 | def action_space(self) -> Dict[str, ActionSpace]: 51 | return { 52 | "count": CategoricalActionSpace( 53 | ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 54 | ) 55 | } 56 | 57 | def reset_filter(self, obs_space: ObsSpace) -> Observation: 58 | self.count = random.randint(0, self.masked_choices - 1) 59 | possible_counts = { 60 | self.count, 61 | *random.sample( 62 | range(0, self.masked_choices), random.randint(0, self.masked_choices) 63 | ), 64 | } 65 | mask = np.zeros((1, 10), dtype=np.bool_) 66 | mask[:, list(possible_counts)] = True 67 | return self.observe(obs_space, mask) 68 | 69 | def reset(self) -> Observation: 70 | return self.reset_filter(self.obs_space()) 71 | 72 | def act_filter( 73 | self, action: Mapping[str, Action], obs_filter: ObsSpace 74 | ) -> Observation: 75 | reward = 0.0 76 | assert len(action) == 1 77 | a = action["count"] 78 | assert isinstance(a, CategoricalAction) 79 | assert len(a.indices) == 1 80 | choice = a.indices[0] 81 | if choice == self.count: 82 | reward = 1.0 83 | return self.observe(obs_filter, None, done=True, reward=reward) 84 | 85 | def act(self, actions: Mapping[str, Action]) -> Observation: 86 | return self.act_filter( 87 | actions, 88 | self.obs_space(), 89 | ) 90 | 91 | def observe( 92 | self, 93 | obs_filter: ObsSpace, 94 | mask: Optional[npt.NDArray[np.bool_]], 95 | done: bool = False, 96 | reward: float = 0.0, 97 | ) -> Observation: 98 | return Observation( 99 | features=extract_features( 100 | { 101 | "Player": [Player()], 102 | "Bean": [Bean()] * self.count, 103 | }, 104 | obs_filter, 105 | ), 106 | actions={ 107 | "count": CategoricalActionMask(actor_ids=["Player"], mask=mask), 108 | }, 109 | ids={ 110 | "Player": ["Player"], 111 | "Bean": [f"Bean{i}" for i in range(1, self.count + 1)], 112 | }, 113 | reward=reward, 114 | done=done, 115 | ) 116 | -------------------------------------------------------------------------------- /entity_gym/examples/rock_paper_scissors.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Dict, Mapping 4 | 5 | from entity_gym.dataclass_utils import extract_features, obs_space_from_dataclasses 6 | from entity_gym.env import ( 7 | Action, 8 | ActionSpace, 9 | CategoricalAction, 10 | CategoricalActionMask, 11 | CategoricalActionSpace, 12 | Environment, 13 | Observation, 14 | ObsSpace, 15 | ) 16 | 17 | 18 | @dataclass 19 | class Player: 20 | pass 21 | 22 | 23 | @dataclass 24 | class Opponent: 25 | rock: float = 0.0 26 | paper: float = 0.0 27 | scissors: float = 0.0 28 | 29 | 30 | class RockPaperScissors(Environment): 31 | """ 32 | This environment tests giving additional information to the value function 33 | which can not be observed by the policy. 34 | 35 | On each timestep, the opponent randomly chooses rock, paper or scissors with 36 | probability of 50%, 30% and 20% respectively. The value function can observe 37 | the opponent's choice, but the policy can not. 
38 | The agent must choose either rock, paper or scissors. If the agent beats the 39 | opponent, the agent receives a reward of 2.0, otherwise it receives a reward of 0.0. 40 | The optimal strategy is to always choose paper for an average reward of 1.0. 41 | Since the value function can observe the opponent's choice, it can perfectly 42 | predict reward. 43 | """ 44 | 45 | def __init__(self, cheat: bool = False) -> None: 46 | self.cheat = cheat 47 | self.reset() 48 | 49 | def obs_space(self) -> ObsSpace: 50 | return obs_space_from_dataclasses(Player, Opponent) 51 | 52 | def action_space(self) -> Dict[str, ActionSpace]: 53 | return {"throw": CategoricalActionSpace(["rock", "paper", "scissors"])} 54 | 55 | def reset_filter(self, obs_space: ObsSpace) -> Observation: 56 | rand = random.random() 57 | if rand < 0.5: 58 | self.opponent = Opponent(rock=1.0) 59 | elif rand < 0.8: 60 | self.opponent = Opponent(paper=1.0) 61 | else: 62 | self.opponent = Opponent(scissors=1.0) 63 | return self.observe(obs_space) 64 | 65 | def reset(self) -> Observation: 66 | return self.reset_filter(self.obs_space()) 67 | 68 | def act_filter( 69 | self, action: Mapping[str, Action], obs_filter: ObsSpace 70 | ) -> Observation: 71 | reward = 0.0 72 | for action_name, a in action.items(): 73 | assert isinstance(a, CategoricalAction) 74 | if action_name == "throw": 75 | if ( 76 | (a.indices[0] == 0 and self.opponent.scissors == 1.0) 77 | or (a.indices[0] == 1 and self.opponent.rock == 1.0) 78 | or (a.indices[0] == 2 and self.opponent.paper == 1.0) 79 | ): 80 | reward = 2.0 81 | return self.observe(obs_filter, done=True, reward=reward) 82 | 83 | def act(self, actions: Mapping[str, Action]) -> Observation: 84 | return self.act_filter( 85 | actions, 86 | self.obs_space(), 87 | ) 88 | 89 | def observe( 90 | self, obs_filter: ObsSpace, done: bool = False, reward: float = 0.0 91 | ) -> Observation: 92 | return Observation( 93 | features=extract_features( 94 | { 95 | "Player": [Player()], 96 | "Opponent": [self.opponent], 97 | }, 98 | obs_filter, 99 | ), 100 | actions={ 101 | "throw": CategoricalActionMask(actor_ids=[0]), 102 | }, 103 | ids={"Player": [0]}, 104 | visible={"Opponent": [self.cheat]}, 105 | reward=reward, 106 | done=done, 107 | ) 108 | -------------------------------------------------------------------------------- /entity_gym/examples/floor_is_lava.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Dict, Mapping 4 | 5 | from entity_gym.dataclass_utils import extract_features, obs_space_from_dataclasses 6 | from entity_gym.env import ( 7 | Action, 8 | ActionSpace, 9 | CategoricalAction, 10 | CategoricalActionMask, 11 | CategoricalActionSpace, 12 | Environment, 13 | Observation, 14 | ObsSpace, 15 | ) 16 | 17 | 18 | @dataclass 19 | class Lava: 20 | x: float 21 | y: float 22 | 23 | 24 | @dataclass 25 | class HighGround: 26 | x: float 27 | y: float 28 | 29 | 30 | @dataclass 31 | class Player: 32 | x: float 33 | y: float 34 | 35 | 36 | class FloorIsLava(Environment): 37 | """ 38 | The player is surrounded by 8 tiles, 7 of which are lava and 1 of which is high ground. 39 | The player must move to one of the tiles. 40 | The player receives a reward of 1 if they move to the high ground, and 0 otherwise. 
41 | """ 42 | 43 | def obs_space(self) -> ObsSpace: 44 | return obs_space_from_dataclasses(Lava, HighGround, Player) 45 | 46 | def action_space(self) -> Dict[str, ActionSpace]: 47 | return { 48 | "move": CategoricalActionSpace(["n", "ne", "e", "se", "s", "sw", "w", "nw"]) 49 | } 50 | 51 | def reset_filter(self, obs_space: ObsSpace) -> Observation: 52 | width = 1000 53 | x = random.randint(-width, width) 54 | y = random.randint(-width, width) 55 | self.player = Player(x, y) 56 | self.lava = random.sample( 57 | [ 58 | Lava(x + i, y + j) 59 | for i in range(-1, 2) 60 | for j in range(-1, 2) 61 | if not (i == 0 and j == 0) 62 | ], 63 | random.randint(1, 8), 64 | ) 65 | safe = random.randint(0, len(self.lava) - 1) 66 | self.high_ground = HighGround(self.lava[safe].x, self.lava[safe].y) 67 | self.lava.pop(safe) 68 | obs = self.observe(obs_space) 69 | return obs 70 | 71 | def reset(self) -> Observation: 72 | return self.reset_filter(self.obs_space()) 73 | 74 | def act_filter( 75 | self, action: Mapping[str, Action], obs_filter: ObsSpace 76 | ) -> Observation: 77 | for action_name, a in action.items(): 78 | assert isinstance(a, CategoricalAction) and action_name == "move" 79 | dx, dy = [ 80 | (0, 1), 81 | (1, 1), 82 | (1, 0), 83 | (1, -1), 84 | (0, -1), 85 | (-1, -1), 86 | (-1, 0), 87 | (-1, 1), 88 | ][a.indices[0]] 89 | self.player.x += dx 90 | self.player.y += dy 91 | obs = self.observe(obs_filter, done=True) 92 | return obs 93 | 94 | def act(self, actions: Mapping[str, Action]) -> Observation: 95 | return self.act_filter( 96 | actions, 97 | self.obs_space(), 98 | ) 99 | 100 | def observe(self, obs_filter: ObsSpace, done: bool = False) -> Observation: 101 | if ( 102 | done 103 | and self.player.x == self.high_ground.x 104 | and self.player.y == self.high_ground.y 105 | ): 106 | reward = 1.0 107 | else: 108 | reward = 0.0 109 | return Observation( 110 | features=extract_features( 111 | { 112 | "Player": [self.player], 113 | "Lava": self.lava, 114 | "HighGround": [self.high_ground], 115 | }, 116 | obs_filter, 117 | ), 118 | actions={ 119 | "move": CategoricalActionMask(actor_types=["Player"]), 120 | }, 121 | ids={"Player": [0]}, 122 | reward=reward, 123 | done=done, 124 | ) 125 | -------------------------------------------------------------------------------- /entity_gym/env/add_metrics_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Mapping, Optional 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | from ragged_buffer import RaggedBufferI64 6 | 7 | from entity_gym.env.environment import ActionName, ActionSpace, ObsSpace 8 | from entity_gym.env.vec_env import Metric, VecEnv, VecObs 9 | 10 | 11 | class AddMetricsWrapper(VecEnv): 12 | def __init__( 13 | self, env: VecEnv, filter: Optional[npt.NDArray[np.bool8]] = None 14 | ) -> None: 15 | """ 16 | Wrap a VecEnv to track and add metrics for (episodic) rewards and episode lengths. 17 | 18 | Args: 19 | env: The VecEnv to wrap. 20 | filter: A boolean array of length len(env) indicating which environments to 21 | track metrics for. If filter[i] is True, then the metrics for the i-th 22 | environment will be tracked. 
23 | """ 24 | self.env = env 25 | self.entity_types = list(env.obs_space().entities.keys()) 26 | if env.has_global_entity(): 27 | self.entity_types.append("__global__") 28 | self.total_reward = np.zeros(len(env), dtype=np.float32) 29 | self.total_steps = np.zeros(len(env), dtype=np.int64) 30 | self.filter = np.ones(len(env), dtype=np.bool8) if filter is None else filter 31 | 32 | def reset(self, obs_config: ObsSpace) -> VecObs: 33 | return self.track_metrics(self.env.reset(obs_config)) 34 | 35 | def act( 36 | self, actions: Mapping[ActionName, RaggedBufferI64], obs_filter: ObsSpace 37 | ) -> VecObs: 38 | return self.track_metrics(self.env.act(actions, obs_filter)) 39 | 40 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 41 | return self.env.render(**kwargs) 42 | 43 | def __len__(self) -> int: 44 | return len(self.env) 45 | 46 | def close(self) -> None: 47 | self.env.close() 48 | 49 | def track_metrics(self, obs: VecObs) -> VecObs: 50 | self.total_reward += obs.reward 51 | self.total_steps += 1 52 | episodic_reward = Metric() 53 | episodic_length = Metric() 54 | count = len(self.total_steps) 55 | obs.metrics["step"] = Metric( 56 | sum=self.total_steps.sum(), 57 | count=count, 58 | min=self.total_steps.min(), 59 | max=self.total_steps.max(), 60 | ) 61 | 62 | for entity in self.entity_types: 63 | if entity in obs.features: 64 | _sum = obs.features[entity].items() 65 | counts = obs.features[entity].size1() 66 | _min = counts.min() 67 | _max = counts.max() 68 | else: 69 | _sum = 0 70 | _min = 0 71 | _max = 0 72 | obs.metrics[f"entity_count/{entity}"] = Metric( 73 | sum=_sum, count=count, min=_min, max=_max 74 | ) 75 | if len(obs.features) > 0: 76 | combined_counts: Any = sum( 77 | features.size1() for features in obs.features.values() 78 | ) 79 | else: 80 | combined_counts = np.zeros(count, dtype=np.int64) 81 | obs.metrics["entity_count"] = Metric( 82 | sum=combined_counts.sum(), 83 | count=count, 84 | min=combined_counts.min(), 85 | max=combined_counts.max(), 86 | ) 87 | 88 | for i in np.arange(len(self))[obs.done & self.filter]: 89 | episodic_reward.push(self.total_reward[i]) 90 | episodic_length.push(self.total_steps[i]) 91 | self.total_reward[i] = 0.0 92 | self.total_steps[i] = 0 93 | obs.metrics["episodic_reward"] = episodic_reward 94 | obs.metrics["episode_length"] = episodic_length 95 | obs.metrics["reward"] = Metric( 96 | sum=obs.reward.sum(), 97 | count=obs.reward.size, 98 | min=obs.reward.min(), 99 | max=obs.reward.max(), 100 | ) 101 | return obs 102 | 103 | def action_space(self) -> Dict[ActionName, ActionSpace]: 104 | return self.env.action_space() 105 | 106 | def obs_space(self) -> ObsSpace: 107 | return self.env.obs_space() 108 | -------------------------------------------------------------------------------- /entity_gym/tests/test_sample_recorder.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import numpy as np 4 | from ragged_buffer import RaggedBufferBool, RaggedBufferF32, RaggedBufferI64 5 | 6 | from entity_gym.env import VecCategoricalActionMask, VecObs 7 | from entity_gym.serialization import Sample, SampleRecorder, Trace 8 | 9 | 10 | def test_serde_sample() -> None: 11 | sample = Sample( 12 | obs=VecObs( 13 | features={ 14 | "hero": RaggedBufferF32.from_array( 15 | np.array([[[1.0, 2.0, 0.3, 100.0, 10.0]]], dtype=np.float32), 16 | ), 17 | "enemy": RaggedBufferF32.from_array( 18 | np.array( 19 | [ 20 | [ 21 | [4.0, -2.0, 0.3, 100.0], 22 | [5.0, -2.0, 0.3, 100.0], 23 | [6.0, -2.0, 0.3, 100.0], 24 | ] 
25 |                         ],
26 |                         dtype=np.float32,
27 |                     ),
28 |                 ),
29 |                 "box": RaggedBufferF32.from_array(
30 |                     np.array(
31 |                         [
32 |                             [
33 |                                 [0.0, 0.0, 0.3, 100.0],
34 |                                 [1.0, 0.0, 0.3, 100.0],
35 |                                 [2.0, 0.0, 0.3, 100.0],
36 |                             ]
37 |                         ],
38 |                         dtype=np.float32,
39 |                     ),
40 |                 ),
41 |             },
42 |             visible={},
43 |             action_masks={
44 |                 "move": VecCategoricalActionMask(
45 |                     actors=RaggedBufferI64.from_array(np.array([[[0]]])),
46 |                     mask=RaggedBufferBool.from_array(np.array([[[True, False, True]]])),
47 |                 ),
48 |                 "shoot": VecCategoricalActionMask(
49 |                     actors=RaggedBufferI64.from_array(np.array([[[0]]])), mask=None
50 |                 ),
51 |                 "explode": VecCategoricalActionMask(
52 |                     actors=RaggedBufferI64.from_array(np.array([[[4], [5], [6]]])),
53 |                     mask=None,
54 |                 ),
55 |             },
56 |             reward=np.array([0.3124125987123489]),
57 |             done=np.array([False]),
58 |             metrics={},
59 |         ),
60 |         probs={
61 |             "move": RaggedBufferF32.from_array(
62 |                 np.array([[[0.5], [0.2], [0.3], [0.0]]], dtype=np.float32)
63 |             ),
64 |             "shoot": RaggedBufferF32.from_array(
65 |                 np.array([[[0.9], [0.1]]], dtype=np.float32)
66 |             ),
67 |             "explode": RaggedBufferF32.from_array(
68 |                 np.array(
69 |                     [[[0.3], [0.7]], [[0.2], [0.8]], [[0.1], [0.9]]], dtype=np.float32
70 |                 )
71 |             ),
72 |         },
73 |         logits=None,
74 |         actions={},
75 |         step=[13],
76 |         episode=[4213],
77 |     )
78 | 
79 |     with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
80 |         sample_recorder = SampleRecorder(f.name, act_space=None, obs_space=None, subsample=1)  # type: ignore
81 |         sample_recorder.record(sample)
82 |         # modify the sample
83 |         sample.obs.reward = np.array([1.0])
84 |         sample.obs.features["hero"] = RaggedBufferF32.from_array(
85 |             np.array([[[1.0, 2.0, 0.3, 200.0, 10.0]]], dtype=np.float32),
86 |         )
87 |         sample_recorder.record(sample)
88 |         sample_recorder.close()
89 | 
90 |     with open(f.name, "rb") as f2:
91 |         trace = Trace.deserialize(f2.read())
92 |         assert len(trace.samples) == 2
93 |         assert trace.samples[0].obs.reward[0] == 0.3124125987123489
94 |         assert trace.samples[1].obs.reward[0] == 1.0
95 |         assert (
96 |             trace.samples[0].obs.action_masks["move"]
97 |             == sample.obs.action_masks["move"]
98 |         )
99 |         np.testing.assert_equal(
100 |             trace.samples[0].obs.features["hero"][0].as_array(),
101 |             np.array([[1.0, 2.0, 0.3, 100.0, 10.0]], dtype=np.float32),
102 |         )
103 |         np.testing.assert_equal(
104 |             trace.samples[1].obs.features["hero"][0].as_array(),
105 |             np.array([[1.0, 2.0, 0.3, 200.0, 10.0]], dtype=np.float32),
106 |         )
107 | 
--------------------------------------------------------------------------------
/entity_gym/examples/pick_matching_balls.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass, field
3 | from typing import Dict, List, Mapping
4 | 
5 | import numpy as np
6 | 
7 | from entity_gym.env import (
8 |     Action,
9 |     ActionSpace,
10 |     Entity,
11 |     Environment,
12 |     Observation,
13 |     ObsSpace,
14 |     SelectEntityAction,
15 |     SelectEntityActionMask,
16 |     SelectEntityActionSpace,
17 | )
18 | from entity_gym.env.environment import ActionName
19 | 
20 | 
21 | @dataclass
22 | class Ball:
23 |     color: int
24 |     selected: bool = False
25 | 
26 | 
27 | @dataclass
28 | class PickMatchingBalls(Environment):
29 |     """
30 |     The PickMatchingBalls environment is initialized with a list of 32 balls of different colors.
31 |     On each timestep, the player can pick up one of the balls.
32 |     The episode ends when the player picks up a ball of a different color from the last one.
33 | The player receives a reward equal to the number of balls picked up divided by the maximum number of balls of the same color. 34 | """ 35 | 36 | max_balls: int = 32 37 | balls: List[Ball] = field(default_factory=list) 38 | one_hot: bool = False # use one-hot encoding for the ball color feature 39 | randomize: bool = ( 40 | False # randomize the number of balls to be between 3 and max_balls 41 | ) 42 | 43 | def obs_space(self) -> ObsSpace: 44 | return ObsSpace( 45 | entities={ 46 | "Ball": Entity( 47 | # TODO: better support for categorical features 48 | [ 49 | "color0", 50 | "color1", 51 | "color2", 52 | "color3", 53 | "color4", 54 | "color5", 55 | "selected", 56 | ], 57 | ), 58 | "Player": Entity([]), 59 | } 60 | ) 61 | 62 | def action_space(self) -> Dict[ActionName, ActionSpace]: 63 | return {"Pick Ball": SelectEntityActionSpace()} 64 | 65 | def reset(self) -> Observation: 66 | num_balls = ( 67 | self.max_balls if not self.randomize else random.randint(3, self.max_balls) 68 | ) 69 | self.balls = [Ball(color=random.randint(0, 5)) for _ in range(num_balls)] 70 | return self.observe() 71 | 72 | def observe(self) -> Observation: 73 | done = len({b.color for b in self.balls if b.selected}) > 1 or all( 74 | b.selected for b in self.balls 75 | ) 76 | if done: 77 | if all(b.selected for b in self.balls): 78 | reward = 1.0 79 | else: 80 | reward = (sum(b.selected for b in self.balls) - 1) / max( 81 | len([b for b in self.balls if b.color == color]) 82 | for color in range(6) 83 | ) 84 | else: 85 | reward = 0.0 86 | 87 | return Observation( 88 | features={ 89 | "Ball": np.array( 90 | [ 91 | [float(b.color == c) for c in range(6)] + [float(b.selected)] 92 | for b in self.balls 93 | ] 94 | if self.one_hot 95 | else [ 96 | [float(b.color) for _ in range(6)] + [float(b.selected)] 97 | for b in self.balls 98 | ], 99 | dtype=np.float32, 100 | ), 101 | "Player": np.zeros([1, 0], dtype=np.float32), 102 | }, 103 | ids={ 104 | "Ball": list(range(len(self.balls))), 105 | "Player": [len(self.balls)], 106 | }, 107 | actions={ 108 | "Pick Ball": SelectEntityActionMask( 109 | actor_ids=[len(self.balls)], 110 | actee_ids=[i for i, b in enumerate(self.balls) if not b.selected], 111 | ), 112 | }, 113 | reward=reward, 114 | done=done, 115 | ) 116 | 117 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 118 | action = actions["Pick Ball"] 119 | assert isinstance(action, SelectEntityAction) 120 | for selected_ball in action.actees: 121 | assert not self.balls[selected_ball].selected 122 | self.balls[selected_ball].selected = True 123 | return self.observe() 124 | -------------------------------------------------------------------------------- /entity_gym/examples/move_to_origin.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Dict, Mapping 4 | 5 | import numpy as np 6 | 7 | from entity_gym.env import ( 8 | Action, 9 | ActionSpace, 10 | CategoricalAction, 11 | CategoricalActionMask, 12 | CategoricalActionSpace, 13 | Entity, 14 | Environment, 15 | Observation, 16 | ObsSpace, 17 | ) 18 | 19 | 20 | @dataclass 21 | class MoveToOrigin(Environment): 22 | """ 23 | Task with a single Spaceship that is rewarded for moving as close to the origin as possible. 24 | The Spaceship has two actions for accelerating the Spaceship in the x and y directions. 
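
    A minimal interaction sketch (illustrative only, not part of the task definition;
    the labels are the ones defined in ``action_space`` below, and ``CategoricalAction``
    is the action type this file already imports from ``entity_gym.env``):

    .. code-block:: python

        import numpy as np
        from entity_gym.env import CategoricalAction

        env = MoveToOrigin()
        obs = env.reset()
        # Fire both thrusters at 100%; index 0 selects the first label.
        obs = env.act(
            {
                "horizontal_thruster": CategoricalAction(
                    actors=[0],
                    indices=np.array([0]),
                    index_to_label=["100% right", "10% right", "hold", "10% left", "100% left"],
                ),
                "vertical_thruster": CategoricalAction(
                    actors=[0],
                    indices=np.array([0]),
                    index_to_label=["100% up", "10% up", "hold", "10% down", "100% down"],
                ),
            }
        )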
25 | """ 26 | 27 | x_pos: float = 0.0 28 | y_pos: float = 0.0 29 | x_velocity: float = 0.0 30 | y_velocity: float = 0.0 31 | last_x_pos = 0.0 32 | last_y_pos = 0.0 33 | step: int = 0 34 | 35 | def obs_space(cls) -> ObsSpace: 36 | return ObsSpace( 37 | entities={ 38 | "Spaceship": Entity( 39 | ["x_pos", "y_pos", "x_velocity", "y_velocity", "step"] 40 | ), 41 | } 42 | ) 43 | 44 | def action_space(cls) -> Dict[str, ActionSpace]: 45 | return { 46 | "horizontal_thruster": CategoricalActionSpace( 47 | [ 48 | "100% right", 49 | "10% right", 50 | "hold", 51 | "10% left", 52 | "100% left", 53 | ], 54 | ), 55 | "vertical_thruster": CategoricalActionSpace( 56 | ["100% up", "10% up", "hold", "10% down", "100% down"], 57 | ), 58 | } 59 | 60 | def reset(self) -> Observation: 61 | angle = random.uniform(0, 2 * np.pi) 62 | self.x_pos = np.cos(angle) 63 | self.y_pos = np.sin(angle) 64 | self.last_x_pos = self.x_pos 65 | self.last_y_pos = self.y_pos 66 | self.x_velocity = 0 67 | self.y_velocity = 0 68 | self.step = 0 69 | return self.observe() 70 | 71 | def act(self, actions: Mapping[str, Action]) -> Observation: 72 | self.step += 1 73 | 74 | for action_name, a in actions.items(): 75 | assert isinstance(a, CategoricalAction), f"{a} is not a CategoricalAction" 76 | if action_name == "horizontal_thruster": 77 | for label in a.labels: 78 | if label == "100% right": 79 | self.x_velocity += 0.01 80 | elif label == "10% right": 81 | self.x_velocity += 0.001 82 | elif label == "hold": 83 | pass 84 | elif label == "10% left": 85 | self.x_velocity -= 0.001 86 | elif label == "100% left": 87 | self.x_velocity -= 0.01 88 | else: 89 | raise ValueError(f"Invalid choice id {label}") 90 | elif action_name == "vertical_thruster": 91 | for label in a.labels: 92 | if label == "100% up": 93 | self.y_velocity += 0.01 94 | elif label == "10% up": 95 | self.y_velocity += 0.001 96 | elif label == "hold": 97 | pass 98 | elif label == "10% down": 99 | self.y_velocity -= 0.001 100 | elif label == "100% down": 101 | self.y_velocity -= 0.01 102 | else: 103 | raise ValueError(f"Invalid choice id {label}") 104 | else: 105 | raise ValueError(f"Unknown action type {action_name}") 106 | 107 | self.last_x_pos = self.x_pos 108 | self.last_y_pos = self.y_pos 109 | 110 | self.x_pos += self.x_velocity 111 | self.y_pos += self.y_velocity 112 | 113 | done = self.step >= 32 114 | return self.observe(done) 115 | 116 | def observe(self, done: bool = False) -> Observation: 117 | return Observation( 118 | ids={ 119 | "Spaceship": [0], 120 | }, 121 | features={ 122 | "Spaceship": np.array( 123 | [ 124 | [ 125 | self.x_pos, 126 | self.y_pos, 127 | self.x_velocity, 128 | self.y_velocity, 129 | self.step, 130 | ] 131 | ], 132 | dtype=np.float32, 133 | ), 134 | }, 135 | actions={ 136 | "horizontal_thruster": CategoricalActionMask(), 137 | "vertical_thruster": CategoricalActionMask(), 138 | }, 139 | reward=(self.last_x_pos**2 + self.last_y_pos**2) ** 0.5 140 | - (self.x_pos**2 + self.y_pos**2) ** 0.5, 141 | done=done, 142 | ) 143 | -------------------------------------------------------------------------------- /.github/workflows/checks.yaml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | pull_request: 7 | branches: [ main ] 8 | jobs: 9 | pre-commit: 10 | name: pre-commit 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.9, 3.8] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version 
}} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | #---------------------------------------------- 23 | # ----- install & configure poetry ----- 24 | #---------------------------------------------- 25 | - name: Install Poetry 26 | uses: snok/install-poetry@v1 27 | with: 28 | virtualenvs-create: true 29 | virtualenvs-in-project: true 30 | installer-parallel: true 31 | #---------------------------------------------- 32 | # load cached venv if cache exists 33 | #---------------------------------------------- 34 | - name: Load cached venv 35 | id: cached-poetry-dependencies 36 | uses: actions/cache@v2 37 | with: 38 | path: .venv 39 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 40 | #---------------------------------------------- 41 | # install dependencies if cache does not exist 42 | #---------------------------------------------- 43 | - name: Install dependencies 44 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 45 | continue-on-error: true 46 | run: poetry install --no-interaction --no-root 47 | #---------------------------------------------- 48 | # install cuda related dependencies 49 | #---------------------------------------------- 50 | - name: Install dependencies 51 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 52 | run: poetry run pip install torch 53 | - name: Install dependencies 54 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 55 | run: poetry run pip install torch-scatter 56 | #---------------------------------------------- 57 | # install your root project, if required 58 | #---------------------------------------------- 59 | - name: Install library 60 | run: poetry install --no-interaction 61 | #---------------------------------------------- 62 | # format, type check, and test your code 63 | #---------------------------------------------- 64 | 65 | - name: Run pre-commit utilities 66 | run: | 67 | poetry run pre-commit run --all-files 68 | 69 | unit_test: 70 | name: unit test 71 | runs-on: ubuntu-latest 72 | strategy: 73 | matrix: 74 | python-version: [3.9, 3.8] 75 | 76 | steps: 77 | - uses: actions/checkout@v2 78 | - name: Set up Python ${{ matrix.python-version }} 79 | uses: actions/setup-python@v2 80 | with: 81 | python-version: ${{ matrix.python-version }} 82 | #---------------------------------------------- 83 | # ----- install & configure poetry ----- 84 | #---------------------------------------------- 85 | - name: Install Poetry 86 | uses: snok/install-poetry@v1 87 | with: 88 | virtualenvs-create: true 89 | virtualenvs-in-project: true 90 | installer-parallel: true 91 | #---------------------------------------------- 92 | # load cached venv if cache exists 93 | #---------------------------------------------- 94 | - name: Load cached venv 95 | id: cached-poetry-dependencies 96 | uses: actions/cache@v2 97 | with: 98 | path: .venv 99 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 100 | #---------------------------------------------- 101 | # install dependencies if cache does not exist 102 | #---------------------------------------------- 103 | - name: Install dependencies 104 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 105 | continue-on-error: true 106 | run: poetry install --no-interaction --no-root 107 | #---------------------------------------------- 108 | # install cuda related dependencies 109 | #---------------------------------------------- 110 | - name: Install dependencies 111 | if: 
steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 112 | run: poetry run pip install torch 113 | - name: Install dependencies 114 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 115 | run: poetry run pip install torch-scatter 116 | #---------------------------------------------- 117 | # install your root project, if required 118 | #---------------------------------------------- 119 | - name: Install library 120 | run: poetry install --no-interaction 121 | #---------------------------------------------- 122 | # format, type check, and test your code 123 | #---------------------------------------------- 124 | 125 | - name: Run tests 126 | run: | 127 | poetry run pytest -v -------------------------------------------------------------------------------- /entity_gym/env/env_list.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Mapping, Optional 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | from ragged_buffer import RaggedBufferI64 6 | 7 | from entity_gym.env.environment import ( 8 | Action, 9 | ActionName, 10 | ActionSpace, 11 | CategoricalAction, 12 | CategoricalActionMask, 13 | CategoricalActionSpace, 14 | EntityID, 15 | Environment, 16 | GlobalCategoricalAction, 17 | GlobalCategoricalActionSpace, 18 | Observation, 19 | ObsSpace, 20 | SelectEntityAction, 21 | SelectEntityActionMask, 22 | SelectEntityActionSpace, 23 | ) 24 | from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs 25 | 26 | 27 | class EnvList(VecEnv): 28 | def __init__(self, create_env: Callable[[], Environment], num_envs: int): 29 | self.envs = [create_env() for _ in range(num_envs)] 30 | self.last_obs: List[Observation] = [] 31 | env = self.envs[0] if num_envs > 0 else create_env() 32 | self._obs_space = env.obs_space() 33 | self._action_space = env.action_space() 34 | 35 | def reset(self, obs_space: ObsSpace) -> VecObs: 36 | batch = self._batch_obs([e.reset_filter(obs_space) for e in self.envs]) 37 | return batch 38 | 39 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 40 | return np.stack([e.render(**kwargs) for e in self.envs]) 41 | 42 | def close(self) -> None: 43 | for env in self.envs: 44 | env.close() 45 | 46 | def act( 47 | self, actions: Mapping[str, RaggedBufferI64], obs_space: ObsSpace 48 | ) -> VecObs: 49 | observations = [] 50 | action_spaces = self.action_space() 51 | for i, env in enumerate(self.envs): 52 | _actions = action_index_to_actions( 53 | self._obs_space, action_spaces, actions, self.last_obs[i], i 54 | ) 55 | obs = env.act_filter(_actions, obs_space) 56 | if obs.done: 57 | new_obs = env.reset_filter(obs_space) 58 | new_obs.done = True 59 | new_obs.reward = obs.reward 60 | new_obs.metrics = obs.metrics 61 | observations.append(new_obs) 62 | else: 63 | observations.append(obs) 64 | return self._batch_obs(observations) 65 | 66 | def _batch_obs(self, obs: List[Observation]) -> VecObs: 67 | self.last_obs = obs 68 | return batch_obs(obs, self.obs_space(), self.action_space()) 69 | 70 | def __len__(self) -> int: 71 | return len(self.envs) 72 | 73 | def obs_space(self) -> ObsSpace: 74 | return self._obs_space 75 | 76 | def action_space(self) -> Dict[ActionName, ActionSpace]: 77 | return self._action_space 78 | 79 | 80 | def action_index_to_actions( 81 | obs_space: ObsSpace, 82 | action_spaces: Dict[ActionName, ActionSpace], 83 | actions: Mapping[ActionName, RaggedBufferI64], 84 | last_obs: Observation, 85 | index: int = 0, 86 | probs: Optional[Dict[ActionName, 
npt.NDArray[np.float32]]] = None, 87 | ) -> Dict[ActionName, Action]: 88 | _actions: Dict[ActionName, Action] = {} 89 | for atype, action in actions.items(): 90 | action_space = action_spaces[atype] 91 | if isinstance(action_space, GlobalCategoricalActionSpace): 92 | aindex = action[index].as_array()[0, 0] 93 | _actions[atype] = GlobalCategoricalAction( 94 | index=aindex, 95 | label=action_space.index_to_label[aindex], 96 | probs=probs[atype].reshape(-1) if probs is not None else None, 97 | ) 98 | continue 99 | mask = last_obs.actions[atype] 100 | assert isinstance(mask, SelectEntityActionMask) or isinstance( 101 | mask, CategoricalActionMask 102 | ) 103 | if mask.actor_ids is not None: 104 | actors = mask.actor_ids 105 | elif mask.actor_types is not None: 106 | actors = [] 107 | for etype in mask.actor_types: 108 | actors.extend(last_obs.ids[etype]) 109 | else: 110 | actors = [] 111 | for ids in last_obs.ids.values(): 112 | actors.extend(ids) 113 | aspace = action_spaces[atype] 114 | if isinstance(aspace, CategoricalActionSpace): 115 | _actions[atype] = CategoricalAction( 116 | actors=actors, 117 | indices=action[index].as_array().reshape(-1), 118 | index_to_label=aspace.index_to_label, 119 | probs=probs[atype] if probs is not None else None, 120 | ) 121 | elif isinstance(action_spaces[atype], SelectEntityActionSpace): 122 | assert isinstance(mask, SelectEntityActionMask) 123 | if mask.actee_types is not None: 124 | index_to_actee: List[EntityID] = [] 125 | for etype in mask.actee_types: 126 | index_to_actee.extend(last_obs.ids[etype]) 127 | actees = [ 128 | index_to_actee[a] for a in action[index].as_array().reshape(-1) 129 | ] 130 | elif mask.actee_ids is not None: 131 | actees = [ 132 | mask.actee_ids[e] for e in action[index].as_array().reshape(-1) 133 | ] 134 | else: 135 | index_to_id = last_obs.index_to_id(obs_space) 136 | actees = [index_to_id[e] for e in action[index].as_array().reshape(-1)] 137 | _actions[atype] = SelectEntityAction( 138 | actors=actors, 139 | actees=actees, 140 | probs=probs[atype] if probs is not None else None, 141 | ) 142 | else: 143 | raise NotImplementedError( 144 | f"Action space type {type(action_spaces[atype])} not supported" 145 | ) 146 | return _actions 147 | -------------------------------------------------------------------------------- /entity_gym/examples/minefield.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | from typing import Dict, List, Mapping, Tuple 4 | 5 | import numpy as np 6 | 7 | from entity_gym.dataclass_utils import extract_features, obs_space_from_dataclasses 8 | from entity_gym.env import ( 9 | Action, 10 | ActionSpace, 11 | CategoricalAction, 12 | CategoricalActionMask, 13 | CategoricalActionSpace, 14 | Environment, 15 | Observation, 16 | ObsSpace, 17 | ) 18 | 19 | 20 | @dataclass 21 | class Vehicle: 22 | x_pos: float = 0.0 23 | y_pos: float = 0.0 24 | direction: float = 0.0 25 | step: int = 0 26 | 27 | 28 | @dataclass 29 | class Target: 30 | x_pos: float = 0.0 31 | y_pos: float = 0.0 32 | 33 | 34 | @dataclass 35 | class Mine: 36 | x_pos: float = 0.0 37 | y_pos: float = 0.0 38 | 39 | 40 | @dataclass 41 | class Minefield(Environment): 42 | """ 43 | Task with a Vehicle entity that has to reach a target point, receiving a reward of 1. 44 | If the vehicle collides with any of the randomly placed mines, the episode ends without reward. 45 | The available actions either turn the vehicle left, right, or go straight. 
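
    A brief usage sketch (illustrative; assumes the ``CategoricalAction`` type that
    this file imports from ``entity_gym.env`` and the "move" action space defined below):

    .. code-block:: python

        import numpy as np
        from entity_gym.env import CategoricalAction

        env = Minefield()
        obs = env.reset()
        # Index 1 corresponds to the "move forward" label.
        obs = env.act(
            {
                "move": CategoricalAction(
                    actors=["Vehicle"],
                    indices=np.array([1]),
                    index_to_label=["turn left", "move forward", "turn right"],
                )
            }
        )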
46 | """ 47 | 48 | vehicle: Vehicle = field(default_factory=Vehicle) 49 | target: Target = field(default_factory=Target) 50 | mine: Mine = field(default_factory=Mine) 51 | max_mines: int = 10 52 | max_steps: int = 200 53 | translate: bool = False 54 | width: float = 200.0 55 | 56 | def obs_space(self) -> ObsSpace: 57 | return obs_space_from_dataclasses(Vehicle, Mine, Target) 58 | 59 | def action_space(self) -> Dict[str, ActionSpace]: 60 | return { 61 | "move": CategoricalActionSpace( 62 | ["turn left", "move forward", "turn right"], 63 | ) 64 | } 65 | 66 | def reset_filter(self, obs_space: ObsSpace) -> Observation: 67 | def randpos() -> Tuple[float, float]: 68 | return ( 69 | random.uniform(-self.width / 2, self.width / 2), 70 | random.uniform(-self.width / 2, self.width / 2), 71 | ) 72 | 73 | self.vehicle.x_pos, self.vehicle.y_pos = randpos() 74 | self.target.x_pos, self.target.y_pos = randpos() 75 | mines: List[Mine] = [] 76 | for _ in range(self.max_mines): 77 | x, y = randpos() 78 | # Check that the mine is not too close to the vehicle, target, or any other mine 79 | pos = [(m.x_pos, m.y_pos) for m in mines] + [ 80 | (self.vehicle.x_pos, self.vehicle.y_pos), 81 | (self.target.x_pos, self.target.y_pos), 82 | ] 83 | if any(map(lambda p: (x - p[0]) ** 2 + (y - p[1]) ** 2 < 15 * 15, pos)): 84 | continue 85 | mines.append(Mine(x, y)) 86 | self.vehicle.direction = random.uniform(0, 2 * np.pi) 87 | self.step = 0 88 | self.mines = mines 89 | return self.observe(obs_space) 90 | 91 | def reset(self) -> Observation: 92 | return self.reset_filter(self.obs_space()) 93 | 94 | def act_filter( 95 | self, action: Mapping[str, Action], obs_filter: ObsSpace 96 | ) -> Observation: 97 | for action_name, a in action.items(): 98 | assert isinstance(a, CategoricalAction) 99 | if action_name == "move": 100 | move = a.indices[0] 101 | if move == 0: 102 | self.vehicle.direction -= np.pi / 8 103 | elif move == 1: 104 | self.vehicle.x_pos += 3 * np.cos(self.vehicle.direction) 105 | self.vehicle.y_pos += 3 * np.sin(self.vehicle.direction) 106 | elif move == 2: 107 | self.vehicle.direction += np.pi / 8 108 | else: 109 | raise ValueError( 110 | f"Invalid action {move} for action space {action_name}" 111 | ) 112 | self.vehicle.direction %= 2 * np.pi 113 | else: 114 | raise ValueError(f"Unknown action type {action_name}") 115 | 116 | self.step += 1 117 | self.vehicle.step = self.step 118 | 119 | return self.observe(obs_filter) 120 | 121 | def act(self, actions: Mapping[str, Action]) -> Observation: 122 | return self.act_filter( 123 | actions, 124 | self.obs_space(), 125 | ) 126 | 127 | def observe(self, obs_filter: ObsSpace, done: bool = False) -> Observation: 128 | if (self.target.x_pos - self.vehicle.x_pos) ** 2 + ( 129 | self.target.y_pos - self.vehicle.y_pos 130 | ) ** 2 < 5 * 5: 131 | done = True 132 | reward = 1 133 | elif ( 134 | any( 135 | map( 136 | lambda m: (self.vehicle.x_pos - m.x_pos) ** 2 137 | + (self.vehicle.y_pos - m.y_pos) ** 2 138 | < 5 * 5, 139 | self.mines, 140 | ) 141 | ) 142 | or self.step >= self.max_steps 143 | ): 144 | done = True 145 | reward = 0 146 | else: 147 | done = False 148 | reward = 0 149 | 150 | if self.translate: 151 | ox = self.vehicle.x_pos 152 | oy = self.vehicle.y_pos 153 | else: 154 | ox = oy = 0 155 | return Observation( 156 | features=extract_features( 157 | { 158 | "Mine": [Mine(m.x_pos - ox, m.y_pos - oy) for m in self.mines], 159 | "Vehicle": [self.vehicle], 160 | "Target": [Target(self.target.x_pos - ox, self.target.y_pos - oy)], 161 | }, 162 | obs_filter, 163 | ), 164 
| actions={ 165 | "move": CategoricalActionMask(actor_types=["Vehicle"]), 166 | }, 167 | ids={"Vehicle": ["Vehicle"]}, 168 | reward=reward, 169 | done=done, 170 | ) 171 | -------------------------------------------------------------------------------- /entity_gym/examples/minesweeper.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, List, Mapping, Tuple 3 | 4 | from entity_gym.env import * 5 | 6 | 7 | class MineSweeper(Environment): 8 | """ 9 | The MineSweeper environment contains two types of objects, mines and robots. 10 | The player controls all robots in the environment. 11 | On every step, each robot may move in one of four cardinal directions, or stay in place and defuse all adjacent mines. 12 | If a robot defuses a mine, it is removed from the environment. 13 | If a robot steps on a mine, it is removed from the environment and the player loses the game. 14 | The player wins the game when all mines are defused. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | width: int = 6, 20 | height: int = 6, 21 | nmines: int = 5, 22 | nrobots: int = 2, 23 | orbital_cannon: bool = False, 24 | cooldown_period: int = 5, 25 | ): 26 | self.width = width 27 | self.height = height 28 | self.nmines = nmines 29 | self.nrobots = nrobots 30 | self.orbital_cannon = orbital_cannon 31 | self.cooldown_period = cooldown_period 32 | self.orbital_cannon_cooldown = cooldown_period 33 | # Positions of robots and mines 34 | self.robots: List[Tuple[int, int]] = [] 35 | self.mines: List[Tuple[int, int]] = [] 36 | 37 | def obs_space(cls) -> ObsSpace: 38 | return ObsSpace( 39 | entities={ 40 | "Mine": Entity(features=["x", "y"]), 41 | "Robot": Entity(features=["x", "y"]), 42 | "Orbital Cannon": Entity(["cooldown"]), 43 | } 44 | ) 45 | 46 | def action_space(cls) -> Dict[ActionName, ActionSpace]: 47 | return { 48 | "Move": CategoricalActionSpace( 49 | ["Up", "Down", "Left", "Right", "Defuse Mines"], 50 | ), 51 | "Fire Orbital Cannon": SelectEntityActionSpace(), 52 | } 53 | 54 | def reset(self) -> Observation: 55 | positions = random.sample( 56 | [(x, y) for x in range(self.width) for y in range(self.height)], 57 | self.nmines + self.nrobots, 58 | ) 59 | self.mines = positions[: self.nmines] 60 | self.robots = positions[self.nmines :] 61 | self.orbital_cannon_cooldown = self.cooldown_period 62 | return self.observe() 63 | 64 | def observe(self) -> Observation: 65 | done = len(self.mines) == 0 or len(self.robots) == 0 66 | reward = 1.0 if len(self.mines) == 0 else 0.0 67 | return Observation( 68 | entities={ 69 | "Mine": ( 70 | self.mines, 71 | [("Mine", i) for i in range(len(self.mines))], 72 | ), 73 | "Robot": ( 74 | self.robots, 75 | [("Robot", i) for i in range(len(self.robots))], 76 | ), 77 | "Orbital Cannon": ( 78 | [(self.orbital_cannon_cooldown,)], 79 | [("Orbital Cannon", 0)], 80 | ) 81 | if self.orbital_cannon 82 | else None, 83 | }, 84 | actions={ 85 | "Move": CategoricalActionMask( 86 | # Allow all robots to move 87 | actor_types=["Robot"], 88 | mask=[self.valid_moves(x, y) for x, y in self.robots], 89 | ), 90 | "Fire Orbital Cannon": SelectEntityActionMask( 91 | # Only the Orbital Cannon can fire, but not if cooldown > 0 92 | actor_types=["Orbital Cannon"] 93 | if self.orbital_cannon_cooldown == 0 94 | else [], 95 | # Both mines and robots can be fired at 96 | actee_types=["Mine", "Robot"], 97 | ), 98 | }, 99 | # The game is done once there are no more mines or robots 100 | done=done, 101 | # Give reward of 1.0 for defusing all 
mines
102 |             reward=reward,
103 |         )
104 | 
105 |     def act(self, actions: Mapping[ActionName, Action]) -> Observation:
106 |         fire = actions["Fire Orbital Cannon"]
107 |         assert isinstance(fire, SelectEntityAction)
108 |         remove_robot = None
109 |         for (entity_type, i) in fire.actees:
110 |             if entity_type == "Mine":
111 |                 self.mines.remove(self.mines[i])
112 |             elif entity_type == "Robot":
113 |                 # Don't remove yet to keep indices valid
114 |                 remove_robot = i
115 | 
116 |         move = actions["Move"]
117 |         assert isinstance(move, CategoricalAction)
118 |         for (_, i), choice in zip(move.actors, move.indices):
119 |             if self.robots[i] is None:
120 |                 continue
121 |             # Action space is ["Up", "Down", "Left", "Right", "Defuse Mines"]
122 |             x, y = self.robots[i]
123 |             if choice == 0 and y < self.height - 1:
124 |                 self.robots[i] = (x, y + 1)
125 |             elif choice == 1 and y > 0:
126 |                 self.robots[i] = (x, y - 1)
127 |             elif choice == 2 and x > 0:
128 |                 self.robots[i] = (x - 1, y)
129 |             elif choice == 3 and x < self.width - 1:
130 |                 self.robots[i] = (x + 1, y)
131 |             elif choice == 4:
132 |                 # Remove all mines adjacent to this robot
133 |                 rx, ry = self.robots[i]
134 |                 self.mines = [
135 |                     (x, y) for (x, y) in self.mines if abs(x - rx) + abs(y - ry) > 1
136 |                 ]
137 | 
138 |         if remove_robot is not None:
139 |             self.robots.pop(remove_robot)
140 |         # Remove all robots that stepped on a mine
141 |         self.robots = [r for r in self.robots if r not in self.mines]
142 | 
143 |         return self.observe()
144 | 
145 |     def valid_moves(self, x: int, y: int) -> List[bool]:
146 |         # Mask order must match the "Move" action space: Up, Down, Left, Right, Defuse Mines
147 |         return [
148 |             y < self.height - 1,
149 |             y > 0,
150 |             x > 0,
151 |             x < self.width - 1,
152 |             # Always allow staying in place and defusing mines
153 |             True,
154 |         ]
--------------------------------------------------------------------------------
/entity_gym/env/action.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Sequence, Union
3 | 
4 | import numpy as np
5 | import numpy.typing as npt
6 | 
7 | from .common import EntityID, EntityName
8 | 
9 | 
10 | @dataclass
11 | class CategoricalActionSpace:
12 |     """
13 |     Defines a discrete set of actions that can be taken by multiple entities.
14 |     """
15 | 
16 |     index_to_label: List[str]
17 |     """list of human-readable labels for each action"""
18 | 
19 |     def __len__(self) -> int:
20 |         return len(self.index_to_label)
21 | 
22 | 
23 | @dataclass
24 | class GlobalCategoricalActionSpace:
25 |     """
26 |     Defines a discrete set of actions that can be taken on each timestep.
27 | 
28 |     For example, the following action space allows the agent to choose between four actions "up", "down", "left", and "right":
29 | 
30 |     .. code-block:: python
31 | 
32 |         GlobalCategoricalActionSpace(["up", "down", "left", "right"])
33 |     """
34 | 
35 |     index_to_label: List[str]
36 |     """list of human-readable labels for each action"""
37 | 
38 |     def __len__(self) -> int:
39 |         return len(self.index_to_label)
40 | 
41 | 
42 | @dataclass
43 | class SelectEntityActionSpace:
44 |     """
45 |     Allows multiple entities to each select another entity.
46 |     """
47 | 
48 | 
49 | ActionSpace = Union[
50 |     CategoricalActionSpace, SelectEntityActionSpace, GlobalCategoricalActionSpace
51 | ]
52 | 
53 | 
54 | @dataclass
55 | class CategoricalActionMask:
56 |     """
57 |     Action mask for categorical action that specifies which agents can perform the action,
58 |     and includes a dense mask that further constrains the choices available to each agent.
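
    For example, the following mask (with hypothetical entity ids) lets entities 7 and 9
    act, restricting entity 7 to the first and third choice and entity 9 to the second:

    .. code-block:: python

        CategoricalActionMask(
            actor_ids=[7, 9],
            mask=[[True, False, True], [False, True, False]],
        )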
59 | """ 60 | 61 | actor_ids: Optional[Sequence[EntityID]] = None 62 | """ 63 | The ids of the entities that can perform the action. 64 | If ``None``, all entities can perform the action. 65 | Mutually exclusive with ``actor_types``. 66 | """ 67 | 68 | actor_types: Optional[Sequence[EntityName]] = None 69 | """ 70 | The types of the entities that can perform the action. 71 | If ``None``, all entities can perform the action. 72 | Mutually exclusive with ``actor_ids``. 73 | """ 74 | 75 | mask: Union[Sequence[Sequence[bool]], np.ndarray, None] = None 76 | """ 77 | A boolean array of shape ``(len(actor_ids), len(choices))`` that prevents specific actions from being available to certain entities. 78 | If ``mask[i, j]`` is ``True``, then the entity with id ``actor_ids[i]`` can perform action ``j``. 79 | """ 80 | 81 | def __post_init__(self) -> None: 82 | assert ( 83 | self.actor_ids is None or self.actor_types is None 84 | ), "Only one of actor_ids or actor_types can be specified" 85 | 86 | 87 | @dataclass 88 | class GlobalCategoricalActionMask: 89 | """ 90 | Action mask for global categorical action. 91 | """ 92 | 93 | mask: Union[Sequence[Sequence[bool]], np.ndarray, None] = None 94 | """ 95 | An optional boolean array of shape (len(choices),). If mask[i] is True, then 96 | action choice i can be performed. 97 | """ 98 | 99 | 100 | @dataclass 101 | class SelectEntityActionMask: 102 | """ 103 | Action mask for select entity action that specifies which agents can perform the action, 104 | and includes a dense mask that further constraints what other entities can be selected by 105 | each actor. 106 | """ 107 | 108 | actor_ids: Optional[Sequence[EntityID]] = None 109 | """ 110 | The ids of the entities that can perform the action. 111 | If None, all entities can perform the action. 112 | """ 113 | 114 | actor_types: Optional[Sequence[EntityName]] = None 115 | """ 116 | The types of the entities that can perform the action. 117 | If None, all entities can perform the action. 118 | """ 119 | 120 | actee_types: Optional[Sequence[EntityName]] = None 121 | """ 122 | The types of entities that can be selected by each actor. 123 | If None, all entities types can be selected by each actor. 124 | """ 125 | 126 | actee_ids: Optional[Sequence[EntityID]] = None 127 | """ 128 | The ids of the entities of each type that can be selected by each actor. 129 | If None, all entities can be selected by each actor. 130 | """ 131 | 132 | mask: Optional[npt.NDArray[np.bool_]] = None 133 | """ 134 | An boolean array of shape (len(actor_ids), len(actee_ids)). If mask[i, j] is True, then 135 | the agent with id actor_ids[i] can select entity with id actee_ids[j]. 136 | (NOT CURRENTLY IMPLEMENTED) 137 | """ 138 | 139 | def __post_init__(self) -> None: 140 | assert ( 141 | self.actor_ids is None or self.actor_types is None 142 | ), "Only one of actor_ids or actor_types can be specified" 143 | assert ( 144 | self.actee_types is None or self.actee_ids is None 145 | ), "Either actee_entity_types or actees can be specified, but not both." 146 | 147 | 148 | ActionMask = Union[ 149 | CategoricalActionMask, SelectEntityActionMask, GlobalCategoricalActionMask 150 | ] 151 | 152 | 153 | @dataclass 154 | class CategoricalAction: 155 | """ 156 | Outcome of a categorical action. 
157 | """ 158 | 159 | actors: Sequence[EntityID] 160 | """the ids of the entities that chose the actions""" 161 | 162 | indices: npt.NDArray[np.int64] 163 | """the indices of the actions that were chosen""" 164 | 165 | index_to_label: List[str] 166 | """mapping from action indices to human readable labels""" 167 | 168 | probs: Optional[npt.NDArray[np.float32]] = None 169 | """the probablity assigned to each action by each agent""" 170 | 171 | @property 172 | def labels(self) -> List[str]: 173 | """the human readable labels of the actions that were performed""" 174 | return [self.index_to_label[i] for i in self.indices] 175 | 176 | 177 | @dataclass 178 | class SelectEntityAction: 179 | """ 180 | Outcome of a select entity action. 181 | """ 182 | 183 | actors: Sequence[EntityID] 184 | """the ids of the entities that chose the action""" 185 | actees: Sequence[EntityID] 186 | """the ids of the entities that were selected by the actors""" 187 | probs: Optional[npt.NDArray[np.float32]] = None 188 | """the probablity assigned to each selection by each agent""" 189 | 190 | 191 | @dataclass 192 | class GlobalCategoricalAction: 193 | """Outcome of a global categorical action.""" 194 | 195 | index: int 196 | """the index of the action that was chosen""" 197 | label: str 198 | """the human readable label of the action that was chosen""" 199 | probs: Optional[npt.NDArray[np.float32]] = None 200 | """the probablity assigned to the action by each agent""" 201 | 202 | 203 | Action = Union[CategoricalAction, SelectEntityAction, GlobalCategoricalAction] 204 | -------------------------------------------------------------------------------- /entity_gym/serialization/sample_recorder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Mapping, Optional 3 | 4 | import msgpack_numpy 5 | import numpy as np 6 | from ragged_buffer import RaggedBufferF32, RaggedBufferI64 7 | 8 | from entity_gym.env import ActionSpace, ObsSpace, VecEnv, VecObs 9 | from entity_gym.env.environment import ActionName 10 | from entity_gym.serialization.msgpack_ragged import ( 11 | ragged_buffer_decode, 12 | ragged_buffer_encode, 13 | ) 14 | 15 | 16 | @dataclass 17 | class Sample: 18 | obs: VecObs 19 | step: List[int] 20 | episode: List[int] 21 | actions: Mapping[str, RaggedBufferI64] 22 | probs: Dict[str, RaggedBufferF32] 23 | logits: Optional[Dict[str, RaggedBufferF32]] 24 | 25 | def serialize(self) -> bytes: 26 | return msgpack_numpy.dumps( # type: ignore 27 | { 28 | "obs": self.obs, 29 | "step": self.step, 30 | "episode": self.episode, 31 | "actions": self.actions, 32 | "probs": self.probs, 33 | "logits": self.logits, 34 | }, 35 | default=ragged_buffer_encode, 36 | ) 37 | 38 | @classmethod 39 | def deserialize(cls, data: bytes) -> "Sample": 40 | return Sample( 41 | **msgpack_numpy.loads( 42 | data, object_hook=ragged_buffer_decode, strict_map_key=False 43 | ) 44 | ) 45 | 46 | 47 | class SampleRecorder: 48 | """ 49 | Writes samples to disk. 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | path: str, 55 | act_space: Dict[str, ActionSpace], 56 | obs_space: ObsSpace, 57 | subsample: int, 58 | ) -> None: 59 | self.path = path 60 | self.file = open(path, "wb") 61 | 62 | # Version 0 63 | self.file.write(np.uint64(1).tobytes()) 64 | 65 | bytes = msgpack_numpy.dumps( 66 | { 67 | "act_space": act_space, 68 | "obs_space": obs_space, 69 | "subsample": subsample, 70 | }, 71 | default=ragged_buffer_encode, 72 | ) 73 | self.file.write(np.uint64(len(bytes)).tobytes()) 74 | self.file.write(bytes) 75 | 76 | def record( 77 | self, 78 | sample: Sample, 79 | ) -> None: 80 | bytes = sample.serialize() 81 | # Write 8 bytes unsigned int for the size of the serialized sample 82 | self.file.write(np.uint64(len(bytes)).tobytes()) 83 | self.file.write(bytes) 84 | 85 | def close(self) -> None: 86 | self.file.close() 87 | 88 | 89 | class SampleRecordingVecEnv(VecEnv): 90 | def __init__( 91 | self, 92 | inner: VecEnv, 93 | out_path: str, 94 | subsample: int = 1, 95 | ) -> None: 96 | self.inner = inner 97 | self.out_path = out_path 98 | self.subsample = subsample 99 | self.sample_recorder = SampleRecorder( 100 | out_path, 101 | inner.action_space(), 102 | inner.obs_space(), 103 | subsample, 104 | ) 105 | self.last_obs: Optional[VecObs] = None 106 | self.episodes = list(range(len(inner))) 107 | self.curr_step = [0] * len(inner) 108 | self.next_episode = len(inner) 109 | self.rng = np.random.default_rng(0) 110 | 111 | def reset(self, obs_config: ObsSpace) -> VecObs: 112 | self.curr_step = [0] * len(self) 113 | self.last_obs = self.record_obs(self.inner.reset(obs_config)) 114 | return self.last_obs 115 | 116 | def record_obs(self, obs: VecObs) -> VecObs: 117 | for i, done in enumerate(obs.done): 118 | if done: 119 | self.episodes[i] = self.next_episode 120 | self.next_episode += 1 121 | self.curr_step[i] = 0 122 | else: 123 | self.curr_step[i] += 1 124 | self.last_obs = obs 125 | return obs 126 | 127 | def act( 128 | self, 129 | actions: Mapping[str, RaggedBufferI64], 130 | obs_filter: ObsSpace, 131 | probs: Optional[Dict[str, RaggedBufferF32]] = None, 132 | logits: Optional[Dict[str, RaggedBufferF32]] = None, 133 | ) -> VecObs: 134 | if probs is None: 135 | probs = {} 136 | # with tracer.span("record_samples"): 137 | assert self.last_obs is not None 138 | if self.subsample > 1: 139 | select = self.rng.integers(0, self.subsample, size=len(self.episodes)) == 0 140 | indices = np.arange(len(self.episodes))[select] 141 | if len(indices) > 0: 142 | last_obs = VecObs( 143 | features={k: v[indices] for k, v in self.last_obs.features.items()}, 144 | action_masks={ 145 | k: v[indices] for k, v in self.last_obs.action_masks.items() 146 | }, 147 | reward=self.last_obs.reward[select], 148 | done=self.last_obs.done[select], 149 | metrics=self.last_obs.metrics, 150 | visible={k: v[indices] for k, v in self.last_obs.visible.items()}, 151 | ) 152 | self.sample_recorder.record( 153 | Sample( 154 | obs=last_obs, 155 | step=[step for step, s in zip(self.curr_step, select) if s], 156 | episode=[e for e, s in zip(self.episodes, select) if s], 157 | actions={k: v[indices] for k, v in actions.items()}, 158 | probs={k: v[indices] for k, v in probs.items()} 159 | if probs is not None 160 | else None, 161 | logits=( 162 | {k: v[indices] for k, v in logits.items()} 163 | if logits is not None 164 | else None 165 | ), 166 | ) 167 | ) 168 | else: 169 | self.sample_recorder.record( 170 | Sample( 171 | self.last_obs, 172 | step=list(self.curr_step), 173 | episode=list(self.episodes), 174 
| actions=actions, 175 | probs=probs, 176 | logits=logits, 177 | ) 178 | ) 179 | return self.record_obs(self.inner.act(actions, obs_filter)) 180 | 181 | def render(self, **kwargs: Any) -> np.ndarray: 182 | return self.inner.render(**kwargs) 183 | 184 | def action_space(self) -> Dict[ActionName, ActionSpace]: 185 | return self.inner.action_space() 186 | 187 | def obs_space(self) -> ObsSpace: 188 | return self.inner.obs_space() 189 | 190 | def __len__(self) -> int: 191 | return len(self.inner) 192 | 193 | def close(self) -> None: 194 | self.sample_recorder.close() 195 | print("Recorded samples to: ", self.sample_recorder.path) 196 | self.inner.close() 197 | -------------------------------------------------------------------------------- /entity_gym/env/validator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Mapping 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | 6 | from entity_gym.env.environment import ( 7 | Action, 8 | ActionName, 9 | ActionSpace, 10 | CategoricalActionMask, 11 | CategoricalActionSpace, 12 | Environment, 13 | Observation, 14 | ObsSpace, 15 | SelectEntityActionMask, 16 | SelectEntityActionSpace, 17 | ) 18 | 19 | 20 | class ValidatingEnv(Environment): 21 | def __init__(self, env: Environment) -> None: 22 | self.env = env 23 | self._obs_space = env.obs_space() 24 | self._action_space = env.action_space() 25 | 26 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 27 | obs = self.env.act(actions) 28 | try: 29 | self._validate(obs) 30 | except AssertionError as e: 31 | print(f"Invalid observation:\n{e}") 32 | raise e 33 | return obs 34 | 35 | def reset(self) -> Observation: 36 | obs = self.env.reset() 37 | try: 38 | self._validate(obs) 39 | except AssertionError as e: 40 | print(f"Invalid observation:\n{e}") 41 | raise e 42 | return obs 43 | 44 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 45 | return self.env.render(**kwargs) 46 | 47 | def obs_space(self) -> ObsSpace: 48 | return self._obs_space 49 | 50 | def action_space(self) -> Dict[str, ActionSpace]: 51 | return self._action_space 52 | 53 | def _validate(self, obs: Observation) -> None: 54 | assert isinstance( 55 | obs, Observation 56 | ), f"Observation has invalid type: {type(obs)}" 57 | 58 | # Validate features 59 | for entity_type, entity_features in obs.features.items(): 60 | assert ( 61 | entity_type in self._obs_space.entities 62 | ), f"Features contain entity of type '{entity_type}' which is not in observation space: {list(self._obs_space.entities.keys())}" 63 | if isinstance(entity_features, np.ndarray): 64 | assert ( 65 | entity_features.dtype == np.float32 66 | ), f"Features of entity of type '{entity_type}' have invalid dtype: {entity_features.dtype}. Expected: {np.float32}" 67 | shape = entity_features.shape 68 | assert len(shape) == 2 and shape[1] == len( 69 | self._obs_space.entities[entity_type].features 70 | ), f"Features of entity of type '{entity_type}' have invalid shape: {shape}. Expected: (n, {len(self._obs_space.entities[entity_type].features)})" 71 | else: 72 | for i, entity in enumerate(entity_features): 73 | assert len(entity) == len( 74 | self._obs_space.entities[entity_type].features 75 | ), f"Features of {i}-th entity of type '{entity_type}' have invalid length: {len(entity)}. 
Expected: {len(self._obs_space.entities[entity_type].features)}"
76 | 
77 |             if entity_type in obs.ids:
78 |                 assert len(obs.ids[entity_type]) == len(
79 |                     entity_features
80 |                 ), f"Length of ids of entity of type '{entity_type}' does not match length of features: {len(obs.ids[entity_type])} != {len(entity_features)}"
81 | 
82 |         # Validate global features
83 |         if len(obs.global_features) != len(self._obs_space.global_features):
84 |             raise AssertionError(
85 |                 f"Length of global features does not match length of global features in observation space: {len(obs.global_features)} != {len(self._obs_space.global_features)}"
86 |             )
87 | 
88 |         # Validate ids
89 |         previous_ids = set()
90 |         for entity_type, entity_ids in obs.ids.items():
91 |             assert (
92 |                 entity_type in self._obs_space.entities
93 |             ), f"IDs contain entity of type '{entity_type}' which is not in observation space: {list(self._obs_space.entities.keys())}"
94 |             for id in entity_ids:
95 |                 assert id not in previous_ids, f"Observation has duplicate id '{id}'"
96 |                 previous_ids.add(id)
97 | 
98 |         # Validate actions
99 |         ids = obs.id_to_index(self._obs_space)
100 |         for action_type, action_mask in obs.actions.items():
101 |             assert (
102 |                 action_type in self._action_space
103 |             ), f"Actions contain action of type '{action_type}' which is not in action space: {list(self._action_space.keys())}"
104 |             space = self._action_space[action_type]
105 |             if isinstance(space, CategoricalActionSpace):
106 |                 assert isinstance(
107 |                     action_mask, CategoricalActionMask
108 |                 ), f"Action of type '{action_type}' has invalid type: {type(action_mask)}. Expected: CategoricalActionMask"
109 |                 if action_mask.actor_ids is not None:
110 |                     for id in action_mask.actor_ids:
111 |                         assert (
112 |                             id in ids
113 |                         ), f"Action of type '{action_type}' contains invalid actor id {id} which is not in ids: {obs.ids}"
114 |                 if action_mask.actor_types is not None:
115 |                     for actor_type in action_mask.actor_types:
116 |                         assert (
117 |                             actor_type in obs.ids
118 |                         ), f"Action of type '{action_type}' contains invalid actor type {actor_type} which is not in ids: {obs.ids.keys()}"
119 |                 mask = action_mask.mask
120 |                 actor_indices = obs._actor_indices(action_type, self._obs_space)
121 |                 if isinstance(mask, np.ndarray):
122 |                     assert (
123 |                         mask.dtype == np.bool_
124 |                     ), f"Action of type '{action_type}' has invalid dtype: {mask.dtype}. Expected: {np.bool_}"
125 |                     shape = mask.shape
126 |                     if shape[0] != 0:
127 |                         assert shape == (
128 |                             len(actor_indices),
129 |                             len(space.index_to_label),
130 |                         ), f"Action of type '{action_type}' has invalid shape: {shape}. Expected: ({len(actor_indices)}, {len(space.index_to_label)})"
131 |                     unmasked_count = mask.sum(axis=1)
132 |                     for i in range(len(unmasked_count)):
133 |                         assert (
134 |                             unmasked_count[i] > 0
135 |                         ), f"Action of type '{action_type}' contains invalid mask for {i}-th actor: {mask[i]}. Expected at least one possible action"
136 |                 elif mask is not None:
137 |                     assert len(mask) == len(
138 |                         actor_indices
139 |                     ), f"Action of type '{action_type}' has invalid length: {len(mask)}. Expected: {len(actor_indices)}"
140 |                     for i in range(len(mask)):
141 |                         assert len(mask[i]) == len(
142 |                             space.index_to_label
143 |                         ), f"Action of type '{action_type}' has invalid length of mask for {i}-th actor: {len(mask[i])}. Expected: {len(space.index_to_label)}"
144 |                         assert any(
145 |                             mask[i]
146 |                         ), f"Action of type '{action_type}' contains invalid mask for {i}-th actor: {mask[i]}.
Expected at least one possible action" 147 | 148 | elif isinstance(self._action_space[action_type], SelectEntityActionSpace): 149 | assert isinstance( 150 | action_mask, SelectEntityActionMask 151 | ), f"Action of type '{action_type}' has invalid type: {type(action_mask)}. Expected: SelectEntityActionMask" 152 | -------------------------------------------------------------------------------- /entity_gym/env/parallel_env_list.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import multiprocessing.connection as conn 3 | from multiprocessing.connection import Connection 4 | from typing import Any, Callable, Dict, Generator, List, Mapping, Optional 5 | 6 | import cloudpickle 7 | import msgpack_numpy 8 | import numpy as np 9 | import numpy.typing as npt 10 | from ragged_buffer import RaggedBufferI64 11 | 12 | from entity_gym.env.env_list import EnvList 13 | from entity_gym.env.environment import ( 14 | ActionName, 15 | ActionSpace, 16 | Environment, 17 | Observation, 18 | ObsSpace, 19 | ) 20 | from entity_gym.env.vec_env import VecEnv, VecObs, batch_obs 21 | from entity_gym.serialization.msgpack_ragged import ( 22 | ragged_buffer_decode, 23 | ragged_buffer_encode, 24 | ) 25 | 26 | 27 | class CloudpickleWrapper: 28 | def __init__(self, var: Any): 29 | self.var = var 30 | 31 | def __getstate__(self) -> Any: 32 | return cloudpickle.dumps(self.var) 33 | 34 | def __setstate__(self, var: Any) -> None: 35 | self.var = cloudpickle.loads(var) 36 | 37 | 38 | class MsgpackConnectionWrapper: 39 | """ 40 | Use msgpack instead of pickle to send and receive data from workers. 41 | """ 42 | 43 | def __init__(self, conn: Connection) -> None: 44 | self._conn = conn 45 | 46 | def close(self) -> None: 47 | self._conn.close() 48 | 49 | def send(self, data: Any) -> None: 50 | s = msgpack_numpy.dumps(data, default=ragged_buffer_encode) 51 | self._conn.send_bytes(s) 52 | 53 | def recv(self) -> Any: 54 | data_bytes = self._conn.recv_bytes() 55 | return msgpack_numpy.loads( 56 | data_bytes, 57 | object_hook=ragged_buffer_decode, 58 | strict_map_key=False, 59 | ) 60 | 61 | 62 | def _worker( 63 | remote: conn.Connection, 64 | parent_remote: conn.Connection, 65 | env_list_config: CloudpickleWrapper, 66 | ) -> None: 67 | parent_remote.close() 68 | env_args = env_list_config.var 69 | envs = EnvList(*env_args) 70 | while True: 71 | try: 72 | cmd, data = remote.recv() 73 | if cmd == "act": 74 | observation = envs.act(data[0], data[1]) 75 | remote.send(observation) 76 | elif cmd == "reset": 77 | observation = envs.reset(data) 78 | remote.send(observation) 79 | elif cmd == "render": 80 | rgb_pixels = envs.render(**data) 81 | remote.send(rgb_pixels) 82 | elif cmd == "close": 83 | envs.close() 84 | remote.close() 85 | break 86 | else: 87 | raise NotImplementedError(f"`{cmd}` is not implemented in the worker") 88 | except EOFError: 89 | break 90 | 91 | 92 | class ParallelEnvList(VecEnv): 93 | """ 94 | We fork the subprocessing from the stable-baselines implementation, but use RaggedBuffers for collecting batches 95 | 96 | Citation here: https://github.com/DLR-RM/stable-baselines3/blob/master/CITATION.bib 97 | """ 98 | 99 | def __init__( 100 | self, 101 | create_env: Callable[[], Environment], 102 | num_envs: int, 103 | num_processes: int, 104 | start_method: Optional[str] = None, 105 | ): 106 | 107 | if start_method is None: 108 | # Fork is not a thread safe method (see issue #217) 109 | # but is more user friendly (does not require to wrap the code in 110 | # 
a `if __name__ == "__main__":`) 111 | forkserver_available = "forkserver" in mp.get_all_start_methods() 112 | start_method = "forkserver" if forkserver_available else "spawn" 113 | ctx = mp.get_context(start_method) 114 | 115 | assert ( 116 | num_envs % num_processes == 0 117 | ), "The required number of environments can not be equally split into the number of specified processes." 118 | 119 | self.num_processes = num_processes 120 | self.num_envs = num_envs 121 | self.envs_per_process = int(num_envs / num_processes) 122 | 123 | env_list_configs = [ 124 | (create_env, self.envs_per_process) for _ in range(self.num_processes) 125 | ] 126 | 127 | self.remotes = [] 128 | self.work_remotes = [] 129 | for i in range(self.num_processes): 130 | pipe = ctx.Pipe() 131 | self.remotes.append(MsgpackConnectionWrapper(pipe[0])) 132 | self.work_remotes.append(MsgpackConnectionWrapper(pipe[1])) 133 | 134 | self.processes = [] 135 | for work_remote, remote, env_list_config in zip( 136 | self.work_remotes, self.remotes, env_list_configs 137 | ): 138 | # Have to use cloudpickle wrapper here to serialize the ABCMeta class reference 139 | # TODO: Can this be achieved with custom msgpack somehow? 140 | args = (work_remote, remote, CloudpickleWrapper(env_list_config)) 141 | # daemon=True: if the main process crashes, we should not cause things to hang 142 | process = ctx.Process( 143 | target=_worker, args=args, daemon=True 144 | ) # pytype:disable=attribute-error 145 | process.start() 146 | self.processes.append(process) 147 | work_remote.close() 148 | 149 | env = create_env() 150 | self._obs_space = env.obs_space() 151 | self._action_space = env.action_space() 152 | 153 | def reset(self, obs_space: ObsSpace) -> VecObs: 154 | for remote in self.remotes: 155 | remote.send(("reset", obs_space)) 156 | 157 | # Empty initialized observation batch 158 | observations = batch_obs([], self.obs_space(), self.action_space()) 159 | 160 | for remote in self.remotes: 161 | remote_obs_batch = remote.recv() 162 | observations.extend(remote_obs_batch) 163 | 164 | assert isinstance(observations, VecObs) 165 | return observations 166 | 167 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 168 | rgb_arrays = [] 169 | for remote in self.remotes: 170 | remote.send(("render", kwargs)) 171 | rgb_arrays.append(remote.recv()) 172 | 173 | np_rgb_arrays = np.concatenate(rgb_arrays) 174 | assert isinstance(np_rgb_arrays, np.ndarray) 175 | return np_rgb_arrays 176 | 177 | def close(self) -> None: 178 | for remote in self.remotes: 179 | remote.send(("close", None)) 180 | for process in self.processes: 181 | process.join() 182 | 183 | def _chunk_actions( 184 | self, actions: Mapping[str, RaggedBufferI64] 185 | ) -> Generator[Mapping[str, RaggedBufferI64], List[Observation], None]: 186 | for i in range(0, self.num_envs, self.envs_per_process): 187 | yield { 188 | atype: a[i : i + self.envs_per_process, :, :] 189 | for atype, a in actions.items() 190 | } 191 | 192 | def act( 193 | self, actions: Mapping[str, RaggedBufferI64], obs_space: ObsSpace 194 | ) -> VecObs: 195 | remote_actions = self._chunk_actions(actions) 196 | for remote, action in zip(self.remotes, remote_actions): 197 | remote.send(("act", (action, obs_space))) 198 | 199 | # Empty initialized observation batch 200 | observations = batch_obs([], self.obs_space(), self.action_space()) 201 | 202 | for remote in self.remotes: 203 | remote_obs_batch = remote.recv() 204 | observations.extend(remote_obs_batch) 205 | return observations 206 | 207 | def __len__(self) -> int: 208 
|         return self.num_envs
209 | 
210 |     def obs_space(self) -> ObsSpace:
211 |         return self._obs_space
212 | 
213 |     def action_space(self) -> Dict[ActionName, ActionSpace]:
214 |         return self._action_space
215 | 
--------------------------------------------------------------------------------
/entity_gym/examples/multi_snake.py:
--------------------------------------------------------------------------------
1 | import random
2 | from copy import deepcopy
3 | from dataclasses import dataclass
4 | from typing import Dict, List, Mapping, Tuple
5 | 
6 | from entity_gym.env import (
7 |     Action,
8 |     ActionSpace,
9 |     CategoricalAction,
10 |     CategoricalActionMask,
11 |     CategoricalActionSpace,
12 |     Entity,
13 |     Environment,
14 |     Observation,
15 |     ObsSpace,
16 | )
17 | 
18 | 
19 | @dataclass
20 | class Snake:
21 |     color: int
22 |     segments: List[Tuple[int, int]]
23 | 
24 | 
25 | @dataclass
26 | class Food:
27 |     color: int
28 |     position: Tuple[int, int]
29 | 
30 | 
31 | class MultiSnake(Environment):
32 |     """
33 |     Turn-based version of Snake with multiple snakes.
34 |     Each snake has a different color.
35 |     For each snake, Food of that color is placed randomly on the board.
36 |     Snakes can only eat Food of their color.
37 |     When a snake eats Food of the same color, it grows by one unit.
38 |     When a snake grows and its length is less than 11, the player receives a reward of 0.1 / num_snakes.
39 |     The game ends when a snake collides with another snake, runs into a wall, eats Food of another color, or all snakes reach a length of 11.
40 |     """
41 | 
42 |     def __init__(
43 |         self,
44 |         board_size: int = 10,
45 |         num_snakes: int = 2,
46 |         num_players: int = 1,
47 |         max_snake_length: int = 11,
48 |         max_steps: int = 180,
49 |     ):
50 |         """
51 |         :param num_players: number of players
52 |         :param board_size: size of the board
53 |         :param num_snakes: number of snakes per player
54 |         """
55 |         assert num_snakes < 10, f"num_snakes must be less than 10, got {num_snakes}"
56 |         self.board_size = board_size
57 |         self.num_snakes = num_snakes
58 |         self.num_players = num_players
59 |         self.max_snake_length = max_snake_length
60 |         self.snakes: List[Snake] = []
61 |         self.food: List[Food] = []
62 |         self.game_over = False
63 |         self.last_scores = [0] * self.num_players
64 |         self.scores = [0] * self.num_players
65 |         self.step = 0
66 |         self.max_steps = max_steps
67 | 
68 |     def obs_space(cls) -> ObsSpace:
69 |         return ObsSpace(
70 |             global_features=["step"],
71 |             entities={
72 |                 "SnakeHead": Entity(["x", "y", "color"]),
73 |                 "SnakeBody": Entity(["x", "y", "color"]),
74 |                 "Food": Entity(["x", "y", "color"]),
75 |             },
76 |         )
77 | 
78 |     def action_space(cls) -> Dict[str, ActionSpace]:
79 |         return {
80 |             "move": CategoricalActionSpace(["up", "down", "left", "right"]),
81 |         }
82 | 
83 |     def _spawn_snake(self, color: int) -> None:
84 |         while True:
85 |             x = random.randint(0, self.board_size - 1)
86 |             y = random.randint(0, self.board_size - 1)
87 |             if any(
88 |                 (x, y) == (sx, sy) for snake in self.snakes for sx, sy in snake.segments
89 |             ):
90 |                 continue
91 |             self.snakes.append(Snake(color, [(x, y)]))
92 |             break
93 | 
94 |     def _spawn_food(self, color: int) -> None:
95 |         while True:
96 |             x = random.randint(0, self.board_size - 1)
97 |             y = random.randint(0, self.board_size - 1)
98 |             if any((x, y) == (f.position[0], f.position[1]) for f in self.food) or any(
99 |                 (x, y) == (sx, sy) for snake in self.snakes for sx, sy in snake.segments
100 |             ):
101 |                 continue
102 |             self.food.append(Food(color, (x, y)))
103 |             break
104 | 
105 |     def reset(self) -> Observation:
106 |
self.snakes = [] 107 | self.food = [] 108 | self.game_over = False 109 | self.last_scores = [0] * self.num_players 110 | self.scores = [0] * self.num_players 111 | self.step = 0 112 | for i in range(self.num_snakes): 113 | self._spawn_snake(i) 114 | for i in range(self.num_snakes): 115 | self._spawn_food(i) 116 | return self.observe() 117 | 118 | def act(self, actions: Mapping[str, Action]) -> Observation: 119 | game_over = False 120 | self.step += 1 121 | move_action = actions["move"] 122 | self.last_scores = deepcopy(self.scores) 123 | food_to_spawn = [] 124 | assert isinstance(move_action, CategoricalAction) 125 | for id, move in zip(move_action.actors, move_action.indices): 126 | snake = self.snakes[id] 127 | x, y = snake.segments[-1] 128 | if move == 0: 129 | y += 1 130 | elif move == 1: 131 | y -= 1 132 | elif move == 2: 133 | x -= 1 134 | elif move == 3: 135 | x += 1 136 | if x < 0 or x >= self.board_size or y < 0 or y >= self.board_size: 137 | game_over = True 138 | if any((x, y) == (sx, sy) for s in self.snakes for sx, sy in s.segments): 139 | game_over = True 140 | ate_food = False 141 | snake.segments.append((x, y)) 142 | for i in range(len(self.food)): 143 | if self.food[i].position == (x, y): 144 | if self.food[i].color != snake.color: 145 | game_over = True 146 | elif len(snake.segments) <= self.max_snake_length: 147 | ate_food = True 148 | self.scores[id // self.num_snakes] += ( 149 | 1.0 / (self.max_snake_length - 1) / self.num_snakes 150 | ) 151 | self.food.pop(i) 152 | # Don't spawn food immediately since it might spawn in front of another snake that hasn't moved yet 153 | food_to_spawn.append(snake.color) 154 | break 155 | if not ate_food: 156 | snake.segments = snake.segments[1:] 157 | for player in range(self.num_players): 158 | snakes_per_player = self.num_snakes // self.num_players 159 | if all( 160 | len(s.segments) >= self.max_snake_length 161 | for s in self.snakes[ 162 | player * snakes_per_player : (player + 1) * snakes_per_player 163 | ] 164 | ): 165 | game_over = True 166 | for color in food_to_spawn: 167 | self._spawn_food(color) 168 | if self.step >= self.max_steps: 169 | game_over = True 170 | return self.observe(done=game_over) 171 | 172 | def observe(self, done: bool = False, player: int = 0) -> Observation: 173 | color_offset = player * (self.num_snakes // self.num_players) 174 | 175 | def cycle_color(color: int) -> int: 176 | return (color - color_offset) % self.num_snakes 177 | 178 | return Observation( 179 | global_features=[self.step], 180 | features={ 181 | "SnakeHead": [ 182 | ( 183 | s.segments[-1][0], 184 | s.segments[-1][1], 185 | cycle_color(s.color), 186 | ) 187 | for s in self.snakes 188 | ], 189 | "SnakeBody": [ 190 | (sx, sy, cycle_color(snake.color)) 191 | for snake in self.snakes 192 | for sx, sy in snake.segments[:-1] 193 | ], 194 | "Food": [ 195 | ( 196 | f.position[0], 197 | f.position[1], 198 | cycle_color(f.color), 199 | ) 200 | for f in self.food 201 | ], 202 | }, 203 | ids={ 204 | "SnakeHead": list(range(self.num_snakes)), 205 | }, 206 | actions={ 207 | "move": CategoricalActionMask( 208 | actor_types=["SnakeHead"], 209 | ), 210 | }, 211 | reward=self.scores[player] - self.last_scores[player], 212 | done=done, 213 | ) 214 | -------------------------------------------------------------------------------- /entity_gym/serialization/sample_loader.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Optional, Set, Tuple 3 | 4 | 
import msgpack_numpy 5 | import numpy as np 6 | import tqdm 7 | from ragged_buffer import RaggedBufferBool, RaggedBufferF32, RaggedBufferI64 8 | 9 | from entity_gym.env import ObsSpace 10 | from entity_gym.env.environment import ActionSpace 11 | from entity_gym.env.vec_env import VecActionMask 12 | from entity_gym.ragged_dict import RaggedActionDict, RaggedBatchDict 13 | from entity_gym.serialization.msgpack_ragged import ragged_buffer_decode 14 | from entity_gym.serialization.sample_recorder import Sample 15 | 16 | 17 | @dataclass 18 | class Episode: 19 | number: int 20 | steps: int 21 | entities: Dict[str, RaggedBufferF32] 22 | visible: Dict[str, RaggedBufferBool] 23 | actions: Dict[str, RaggedBufferI64] 24 | masks: Dict[str, VecActionMask] 25 | logprobs: Dict[str, RaggedBufferF32] 26 | logits: Dict[str, RaggedBufferF32] 27 | total_reward: float 28 | complete: bool = False 29 | 30 | 31 | @dataclass 32 | class MergedSamples: 33 | entities: RaggedBatchDict[np.float32] 34 | visible: RaggedBatchDict[np.bool_] 35 | actions: RaggedBatchDict[np.int64] 36 | logprobs: RaggedBatchDict[np.float32] 37 | masks: RaggedActionDict 38 | logits: Optional[RaggedBatchDict[np.float32]] 39 | frames: int 40 | 41 | @classmethod 42 | def empty(clz) -> "MergedSamples": 43 | return MergedSamples( 44 | entities=RaggedBatchDict(RaggedBufferF32), 45 | visible=RaggedBatchDict(RaggedBufferBool), 46 | actions=RaggedBatchDict(RaggedBufferI64), 47 | logprobs=RaggedBatchDict(RaggedBufferF32), 48 | logits=None, 49 | masks=RaggedActionDict(), 50 | frames=0, 51 | ) 52 | 53 | def push_sample(self, sample: Sample) -> None: 54 | self.entities.extend(sample.obs.features) 55 | self.visible.extend(sample.obs.visible) 56 | self.actions.extend(sample.actions) 57 | self.logprobs.extend(sample.probs) 58 | if sample.logits is not None: 59 | if self.logits is None: 60 | self.logits = RaggedBatchDict(RaggedBufferF32) 61 | self.logits.extend(sample.logits) 62 | self.masks.extend(sample.obs.action_masks) 63 | self.frames += len(sample.episode) 64 | 65 | 66 | @dataclass 67 | class Trace: 68 | action_space: Dict[str, ActionSpace] 69 | obs_space: ObsSpace 70 | samples: List[Sample] 71 | subsample: int = 1 72 | 73 | @classmethod 74 | def load(cls, path: str, progress_bar: bool = False) -> "Trace": 75 | with open(path, "rb") as f: 76 | return cls.deserialize(f.read(), progress_bar=progress_bar) 77 | 78 | @classmethod 79 | def deserialize(cls, data: bytes, progress_bar: bool = False) -> "Trace": 80 | samples: List[Sample] = [] 81 | if progress_bar: 82 | pbar = tqdm.tqdm(total=len(data)) 83 | 84 | offset = 0 85 | # Read version 86 | version = int(np.frombuffer(data[:8], dtype=np.uint64)[0]) 87 | assert version == 0 or version == 1 88 | header_len = int(np.frombuffer(data[8:16], dtype=np.uint64)[0]) 89 | header = msgpack_numpy.loads( 90 | data[16 : 16 + header_len], 91 | object_hook=ragged_buffer_decode, 92 | strict_map_key=False, 93 | ) 94 | action_space = header["act_space"] 95 | obs_space = header["obs_space"] 96 | subsample = header.get("subsample", 1) 97 | 98 | offset = 16 + header_len 99 | while offset < len(data): 100 | size = int(np.frombuffer(data[offset : offset + 8], dtype=np.uint64)[0]) 101 | offset += 8 102 | samples.append(Sample.deserialize(data[offset : offset + size])) 103 | offset += size 104 | if progress_bar: 105 | pbar.update(size + 8) 106 | return Trace(action_space, obs_space, samples, subsample=subsample) 107 | 108 | def episodes( 109 | self, include_incomplete: bool = False, progress_bar: bool = False 110 | ) -> 
List[Episode]: 111 | episodes = {} 112 | prev_episodes: Optional[List[int]] = None 113 | if progress_bar: 114 | samples = tqdm.tqdm(self.samples) 115 | else: 116 | samples = self.samples 117 | for sample in samples: 118 | for i, e in enumerate(sample.episode): 119 | if e not in episodes: 120 | episodes[e] = Episode( 121 | e, 122 | 0, 123 | {}, 124 | {}, 125 | {}, 126 | {}, 127 | {}, 128 | {}, 129 | 0.0, 130 | ) 131 | 132 | episodes[e].steps += 1 133 | episodes[e].total_reward += sample.obs.reward[i] 134 | if sample.obs.done[i] and prev_episodes is not None: 135 | episodes[prev_episodes[i]].complete = True 136 | 137 | for name, feats in sample.obs.features.items(): 138 | if name not in episodes[e].entities: 139 | episodes[e].entities[name] = feats[i] 140 | else: 141 | episodes[e].entities[name].extend(feats[i]) 142 | for name, vis in sample.obs.visible.items(): 143 | if name not in episodes[e].visible: 144 | episodes[e].visible[name] = vis[i] 145 | else: 146 | episodes[e].visible[name].extend(vis[i]) 147 | for name, acts in sample.actions.items(): 148 | if name not in episodes[e].actions: 149 | episodes[e].actions[name] = acts[i] 150 | else: 151 | episodes[e].actions[name].extend(acts[i]) 152 | for name, mask in sample.obs.action_masks.items(): 153 | if name not in episodes[e].masks: 154 | episodes[e].masks[name] = mask[i] 155 | else: 156 | episodes[e].masks[name].extend(mask[i]) 157 | for name, logprobs in sample.probs.items(): 158 | if name not in episodes[e].logprobs: 159 | episodes[e].logprobs[name] = logprobs[i] 160 | else: 161 | episodes[e].logprobs[name].extend(logprobs[i]) 162 | if sample.logits is not None: 163 | for name, logits in sample.logits.items(): 164 | if name not in episodes[e].logits: 165 | episodes[e].logits[name] = logits[i] 166 | else: 167 | episodes[e].logits[name].extend(logits[i]) 168 | prev_episodes = sample.episode 169 | return sorted( 170 | (e for e in episodes.values() if e.complete or include_incomplete), 171 | key=lambda e: e.number, 172 | ) 173 | 174 | def train_test_split( 175 | self, test_frac: float = 0.1, progress_bar: bool = False 176 | ) -> Tuple[MergedSamples, MergedSamples]: 177 | if self.subsample == 1: 178 | total_frames = len(self.samples) * len(self.samples[0].episode) 179 | else: 180 | total_frames = sum(len(s.episode) for s in self.samples) 181 | if progress_bar: 182 | pbar = tqdm.tqdm(total=len(self.samples)) 183 | 184 | test = MergedSamples.empty() 185 | test_episodes: Set[int] = set() 186 | i = 0 187 | while test.frames < total_frames * test_frac: 188 | sample = self.samples[i] 189 | test_episodes.update(sample.episode) 190 | test.push_sample(sample) 191 | i += 1 192 | if progress_bar: 193 | pbar.update(1) 194 | 195 | train = MergedSamples.empty() 196 | for sample in self.samples[i:]: 197 | # TODO: could be more efficient 198 | if any(e in test_episodes for e in sample.episode): 199 | continue 200 | train.push_sample(sample) 201 | if progress_bar: 202 | i += 1 203 | pbar.update(1) 204 | 205 | return train, test 206 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. -------------------------------------------------------------------------------- /docs/source/quick-start-guide.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Quick Start Guide 3 | ================= 4 | 5 | This tutorial will guide you through the process of creating a simple entity gym environment. 6 | 7 | .. toctree:: 8 | 9 | Installation 10 | ============ 11 | 12 | .. code-block:: console 13 | 14 | $ pip install entity_gym 15 | 16 | The ``Environment`` class 17 | ========================= 18 | 19 | Create a new file ``treasure_hunt.py`` with the following contents: 20 | 21 | .. code-block:: python 22 | 23 | from typing import Dict, Mapping 24 | from entity_gym.runner import CliRunner 25 | from entity_gym.env import * 26 | 27 | # The `Environment` class defines the interface that all entity gym environments must implement. 
28 | class TreasureHunt(Environment): 29 | # The `obs_space` specifies the shape of observations returned by the environment. 30 | def obs_space(self) -> ObsSpace: 31 | return ObsSpace() 32 | 33 | # The `action_space` specifies which actions can be performed by the agent. 34 | def action_space(self) -> Dict[str, ActionSpace]: 35 | return {} 36 | 37 | # `reset` should initialize the environment and return the initial observation. 38 | def reset(self) -> Observation: 39 | return Observation.empty() 40 | 41 | # `act` performs the chosen actions and returns the new observation. 42 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 43 | return Observation.empty() 44 | 45 | 46 | if __name__ == "__main__": 47 | env = TreasureHunt() 48 | # The `CliRunner` can run any environment with a command line interface. 49 | CliRunner(env).run() 50 | 51 | Try it out by running the following command: 52 | 53 | .. code-block:: console 54 | 55 | $ python treasure_hunt.py 56 | 57 | Since we haven't implemented any functionality for our environment, this won't do much yet. 58 | However, you should still see something like the following output: 59 | 60 | .. code-block:: text 61 | 62 | Environment: TreasureHunt 63 | 64 | Step 0 65 | Reward: 0.0 66 | Total: 0.0 67 | Press ENTER to continue, CTRL-C to exit 68 | 69 | Adding global features 70 | ====================== 71 | 72 | Let's add some logic to keep track of the player's position and expose it in observations: 73 | 74 | .. code-block:: python 75 | 76 | class TreasureHunt(Environment): 77 | def obs_space(self) -> ObsSpace: 78 | # `global_features` adds a fixed-length vector of features to the observation. 79 | return ObsSpace(global_features=["x_pos", "y_pos"]) 80 | 81 | def reset(self) -> Observation: 82 | self.x_pos = 0 83 | self.y_pos = 0 84 | return self.observe() 85 | 86 | def observe(self) -> Observation: 87 | return Observation( 88 | global_features=[self.x_pos, self.y_pos], done=False, reward=0 89 | ) 90 | 91 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 92 | return self.observe() 93 | 94 | def action_space(self) -> Dict[str, ActionSpace]: 95 | return {} 96 | 97 | If you run the environment again, you should now see it print out the player's position: 98 | 99 | .. code-block:: text 100 | 101 | Environment: TreasureHunt 102 | Global features: x_pos, y_pos 103 | 104 | Step 0 105 | Reward: 0 106 | Total: 0 107 | Global features: x_pos=0, y_pos=0 108 | Press ENTER to continue, CTRL-C to exit 109 | 110 | Implementing a "move" action 111 | ============================ 112 | 113 | Now that the player has a position, we can add an action that moves the player. 114 | We change the ``action_space`` method to define ``"move"`` as a global categorical action with 4 choices. 115 | We implement the logic for the action in the ``act`` method. 116 | Finally, we include a ``GlobalCategoricalActionMask`` for the ``"move"`` action in the ``Observation`` returned by ``observe``. 117 | If we wanted the ``"move"`` action to be unavailable on some timestep, we could omit the mask from the corresponding observation. 118 | 119 | .. code-block:: python 120 | 121 | class TreasureHunt(Environment): 122 | ... 123 | 124 | def action_space(self) -> Dict[str, ActionSpace]: 125 | # The `GlobalCategoricalActionSpace` allows the agent to choose from a set of discrete actions.
126 | return { 127 | "move": GlobalCategoricalActionSpace(["up", "down", "left", "right"]) 128 | } 129 | 130 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 131 | # Adjust the player's position according to the chosen action. 132 | action = actions["move"] 133 | assert isinstance(action, GlobalCategoricalAction) 134 | if action.label == "up" and self.y_pos < 10: 135 | self.y_pos += 1 136 | elif action.label == "down" and self.y_pos > -10: 137 | self.y_pos -= 1 138 | elif action.label == "left" and self.x_pos > -10: 139 | self.x_pos -= 1 140 | elif action.label == "right" and self.x_pos < 10: 141 | self.x_pos += 1 142 | return self.observe() 143 | 144 | def observe(self) -> Observation: 145 | return Observation( 146 | global_features=[self.x_pos, self.y_pos], 147 | done=False, 148 | reward=0, 149 | # Each `Observation` must specify which actions are available on the current step. 150 | actions={"move": GlobalCategoricalActionMask()}, 151 | ) 152 | 153 | It is now possible to move the player: 154 | 155 | .. code-block:: text 156 | 157 | Environment: TreasureHunt 158 | Global features: x_pos, y_pos 159 | Categorical move: up, down, left, right 160 | 161 | Step 0 162 | Reward: 0 163 | Total: 0 164 | Global features: x_pos=0, y_pos=0 165 | Choose move (0/up 1/down 2/left 3/right) 166 | 0 167 | Step 1 168 | Reward: 0 169 | Total: 0 170 | Global features: x_pos=0, y_pos=1 171 | Choose move (0/up 1/down 2/left 3/right) 172 | 3 173 | Step 2 174 | Reward: 0 175 | Total: 0 176 | Global features: x_pos=1, y_pos=1 177 | Choose move (0/up 1/down 2/left 3/right) 178 | 179 | Adding "Trap" and "Treasure" entities 180 | ===================================== 181 | 182 | Now, we are going to place additional entities in the environment: 183 | 184 | * *Treasure* can be collected by the player and increases the player's score by 1.0. Once all treasures are collected, the game is won. 185 | * Moving onto a *trap* immediately ends the game. 186 | 187 | We define the new entity types by specifying the ``ObsSpace.entities`` dictionary in the ``obs_space`` method. 188 | Similarly, ``observe`` now returns a ``features`` dictionary with an entry specifying the current positions of both entities. 189 | The logic that defines how the entities are spawned and affect the game is added to ``reset`` and ``act``. 190 | 191 | .. code-block:: python 192 | 193 | import random 194 | from typing import Mapping, Tuple, Dict 195 | 196 | class TreasureHunt(Environment): 197 | def reset(self) -> Observation: 198 | self.x_pos = 0 199 | self.y_pos = 0 200 | self.game_over = False 201 | self.traps = [] 202 | self.treasure = [] 203 | for _ in range(5): 204 | self.traps.append(self._random_empty_pos()) 205 | for _ in range(5): 206 | self.treasure.append(self._random_empty_pos()) 207 | return self.observe() 208 | 209 | def obs_space(self) -> ObsSpace: 210 | return ObsSpace( 211 | global_features=["x_pos", "y_pos"], 212 | # An observation space can have several entities with different features. 213 | # On any given step, an observation may include any number of the defined entities. 
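# (For example, the list of "Treasure" entities below shrinks as treasures are collected.)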
214 | entities={ 215 | "Trap": Entity(features=["x_pos", "y_pos"]), 216 | "Treasure": Entity(features=["x_pos", "y_pos"]), 217 | } 218 | ) 219 | 220 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 221 | action = actions["move"] 222 | assert isinstance(action, GlobalCategoricalAction) 223 | if action.label == "up" and self.y_pos < 10: 224 | self.y_pos += 1 225 | elif action.label == "down" and self.y_pos > -10: 226 | self.y_pos -= 1 227 | elif action.label == "left" and self.x_pos > -10: 228 | self.x_pos -= 1 229 | elif action.label == "right" and self.x_pos < 10: 230 | self.x_pos += 1 231 | 232 | reward = 0.0 233 | if (self.x_pos, self.y_pos) in self.treasure: 234 | reward = 1.0 235 | self.treasure.remove((self.x_pos, self.y_pos)) 236 | if (self.x_pos, self.y_pos) in self.traps or len(self.treasure) == 0: 237 | self.game_over = True 238 | 239 | return self.observe(reward) 240 | 241 | def observe(self, reward: float = 0.0) -> Observation: 242 | return Observation( 243 | global_features=[self.x_pos, self.y_pos], 244 | features={ 245 | "Trap": self.traps, 246 | "Treasure": self.treasure, 247 | }, 248 | done=self.game_over, 249 | reward=reward, 250 | actions={"move": GlobalCategoricalActionMask()}, 251 | ) 252 | 253 | def _random_empty_pos(self) -> Tuple[int, int]: 254 | # Generate a random position on the grid that is not occupied by a trap, treasure, or player. 255 | while True: 256 | x = random.randint(-5, 5) 257 | y = random.randint(-5, 5) 258 | if (x, y) not in (self.traps + self.treasure + [(self.x_pos, self.y_pos)]): 259 | return x, y 260 | 261 | 262 | If you run the environment again, you will now see and be able to interact with all the entities: 263 | 264 | .. code-block:: text 265 | 266 | Environment: TreasureHunt 267 | Global features: x_pos, y_pos 268 | Entity Trap: x_pos, y_pos 269 | Entity Treasure: x_pos, y_pos 270 | Categorical move: up, down, left, right 271 | 272 | Step 0 273 | Reward: 0.0 274 | Total: 0.0 275 | Global features: x_pos=0, y_pos=0 276 | Entities 277 | 0 Trap(x_pos=-2, y_pos=5) 278 | 1 Trap(x_pos=-1, y_pos=-4) 279 | 2 Trap(x_pos=0, y_pos=2) 280 | 3 Trap(x_pos=-5, y_pos=-3) 281 | 4 Trap(x_pos=4, y_pos=3) 282 | 5 Treasure(x_pos=-3, y_pos=3) 283 | 6 Treasure(x_pos=3, y_pos=4) 284 | 7 Treasure(x_pos=5, y_pos=5) 285 | 8 Treasure(x_pos=-1, y_pos=-5) 286 | 9 Treasure(x_pos=5, y_pos=3) 287 | Choose move (0/up 1/down 2/left 3/right) 288 | 289 | This concludes the tutorial. 290 | If you want to learn how to train a neural network to play the game we just implemented, 291 | check out the `enn-trainer tutorial `_. -------------------------------------------------------------------------------- /entity_gym/env/environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | 8 | from .action import * 9 | from .common import ActionName, EntityID, EntityName, Features 10 | 11 | 12 | @dataclass 13 | class Entity: 14 | """Defines the set of features for an entity type.""" 15 | 16 | features: List[str] 17 | 18 | 19 | @dataclass 20 | class ObsSpace: 21 | """ 22 | Defines what features can be observed by an agent. 
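For example, ``ObsSpace(global_features=["step"], entities={"Snake": Entity(["x", "y"])})`` defines one global feature and one entity type with two per-entity features (a minimal illustrative sketch; the names are hypothetical).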
23 | """ 24 | 25 | global_features: List[str] = field(default_factory=list) 26 | """fixed size list of features that are observable on each timestep""" 27 | entities: Dict[EntityName, Entity] = field(default_factory=dict) 28 | """ 29 | Defines the types of entities that can be observed. 30 | On a given timestep, an ``Observation`` may contain multiple entities of each type. 31 | """ 32 | 33 | 34 | class Observation: 35 | """ 36 | Observation returned by the environment on one timestep. 37 | 38 | :param features: Maps each entity type to a list of features for the entities of that type. 39 | :param actions: Maps each action type to an ActionMask specifying which entities can perform 40 | the action. 41 | :param reward: Reward received on this timestep. 42 | :param done: Whether the episode has ended. 43 | :param ids: Maps each entity type to a list of entity ids for the entities of that type. 44 | :param visible: Optional mask for each entity type that prevents the policy but not the 45 | value function from observing certain entities. 46 | """ 47 | 48 | global_features: Union[npt.NDArray[np.float32], Sequence[float]] 49 | features: Mapping[EntityName, Features] 50 | actions: Mapping[ActionName, ActionMask] 51 | done: bool 52 | reward: float 53 | ids: Mapping[EntityName, Sequence[EntityID]] = field(default_factory=dict) 54 | visible: Mapping[EntityName, Union[npt.NDArray[np.bool_], Sequence[bool]]] = field( 55 | default_factory=dict 56 | ) 57 | metrics: Dict[str, float] = field(default_factory=dict) 58 | 59 | def __init__( 60 | self, 61 | *, 62 | done: bool, 63 | reward: float, 64 | visible: Optional[ 65 | Mapping[EntityName, Union[npt.NDArray[np.bool_], Sequence[bool]]] 66 | ] = None, 67 | entities: Optional[ 68 | Mapping[ 69 | EntityName, 70 | Union[ 71 | Features, 72 | Tuple[Features, Sequence[EntityID]], 73 | None, 74 | ], 75 | ] 76 | ] = None, 77 | features: Optional[Mapping[EntityName, Features]] = None, 78 | ids: Optional[Mapping[EntityName, Sequence[EntityID]]] = None, 79 | global_features: Union[npt.NDArray[np.float32], Sequence[float], None] = None, 80 | actions: Optional[Mapping[ActionName, ActionMask]] = None, 81 | metrics: Optional[Dict[str, float]] = None, 82 | ) -> None: 83 | self.global_features = global_features if global_features is not None else [] 84 | self.actions = actions or {} 85 | self.done = done 86 | self.reward = reward 87 | if features is not None: 88 | assert entities is None, "Cannot specify both features and entities" 89 | self.features = features 90 | self.ids = ids or {} 91 | else: 92 | self.features = { 93 | etype: entity[0] if isinstance(entity, tuple) else entity 94 | for etype, entity in (entities or {}).items() 95 | if entity is not None 96 | } 97 | self.ids = { 98 | etype: entity[1] 99 | for etype, entity in (entities or {}).items() 100 | if entity is not None and isinstance(entity, tuple) 101 | } 102 | self.metrics = metrics or {} 103 | self.visible = visible or {} 104 | self._id_to_index: Optional[Dict[EntityID, int]] = None 105 | self._index_to_id: Optional[List[EntityID]] = None 106 | 107 | @classmethod 108 | def empty(cls) -> "Observation": 109 | return Observation( 110 | actions={}, 111 | done=False, 112 | reward=0.0, 113 | ) 114 | 115 | def _actor_indices( 116 | self, atype: ActionName, obs_space: ObsSpace 117 | ) -> npt.NDArray[np.int64]: 118 | action = self.actions[atype] 119 | if isinstance(action, GlobalCategoricalActionMask): 120 | return np.array( 121 | [sum(len(v) for v in self.features.values())], dtype=np.int64 122 | ) 123 | elif 
action.actor_ids is not None: 124 | id_to_index = self.id_to_index(obs_space) 125 | return np.array( 126 | [id_to_index[id] for id in action.actor_ids], dtype=np.int64 127 | ) 128 | elif action.actor_types is not None: 129 | ids: List[int] = [] 130 | id_to_index = self.id_to_index(obs_space) 131 | for etype in action.actor_types: 132 | ids.extend(id_to_index[id] for id in self.ids[etype]) 133 | return np.array(ids, dtype=np.int64) 134 | else: 135 | return np.array( 136 | np.arange( 137 | sum(len(self.ids[etype]) for etype in obs_space.entities), 138 | dtype=np.int64, 139 | ), 140 | dtype=np.int64, 141 | ) 142 | 143 | def _actee_indices( 144 | self, atype: ActionName, obs_space: ObsSpace 145 | ) -> npt.NDArray[np.int64]: 146 | action = self.actions[atype] 147 | assert isinstance(action, SelectEntityActionMask) 148 | if action.actee_ids is not None: 149 | id_to_index = self.id_to_index(obs_space) 150 | return np.array( 151 | [id_to_index[id] for id in action.actee_ids], dtype=np.int64 152 | ) 153 | elif action.actee_types is not None: 154 | ids: List[int] = [] 155 | id_to_index = self.id_to_index(obs_space) 156 | for etype in action.actee_types: 157 | ids.extend(id_to_index[id] for id in self.ids[etype]) 158 | return np.array(ids, dtype=np.int64) 159 | else: 160 | return np.array( 161 | np.arange( 162 | sum(len(self.ids[etype]) for etype in obs_space.entities), 163 | dtype=np.int64, 164 | ), 165 | dtype=np.int64, 166 | ) 167 | 168 | def id_to_index(self, obs_space: ObsSpace) -> Dict[EntityID, int]: 169 | offset = 0 170 | if self._id_to_index is None: 171 | self._id_to_index = {} 172 | for etype in obs_space.entities.keys(): 173 | ids = self.ids.get(etype) 174 | if ids is None: 175 | continue 176 | for i, id in enumerate(ids): 177 | self._id_to_index[id] = i + offset 178 | offset += len(ids) 179 | return self._id_to_index 180 | 181 | def index_to_id(self, obs_space: ObsSpace) -> List[EntityID]: 182 | if self._index_to_id is None: 183 | self._index_to_id = [] 184 | for etype in obs_space.entities.keys(): 185 | ids = self.ids.get(etype) 186 | if ids is None: 187 | ids = [None] * self.num_entities(etype) 188 | self._index_to_id.extend(ids) 189 | return self._index_to_id 190 | 191 | def num_entities(self, entity: EntityName) -> int: 192 | feats = self.features[entity] 193 | if isinstance(feats, np.ndarray): 194 | return feats.shape[0] 195 | else: 196 | return len(feats) 197 | 198 | 199 | class Environment(ABC): 200 | """ 201 | Abstract base class for all environments. 202 | """ 203 | 204 | @abstractmethod 205 | def obs_space(self) -> ObsSpace: 206 | """ 207 | Defines the shape of observations returned by the environment. 208 | """ 209 | raise NotImplementedError 210 | 211 | @abstractmethod 212 | def action_space(self) -> Dict[str, ActionSpace]: 213 | """ 214 | Defines the types of actions that can be taken in the environment. 215 | """ 216 | raise NotImplementedError 217 | 218 | @abstractmethod 219 | def reset(self) -> Observation: 220 | """ 221 | Resets the environment and returns the initial observation. 222 | """ 223 | raise NotImplementedError 224 | 225 | @abstractmethod 226 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 227 | """ 228 | Performs the given action and returns the resulting observation. 229 | 230 | :param actions: Maps the name of each action type to the action to perform. 
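Example (a minimal sketch, assuming a global categorical "move" action): ``obs = env.act({"move": GlobalCategoricalAction(index=0, label="up")})``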
231 | """ 232 | raise NotImplementedError 233 | 234 | def reset_filter(self, obs_filter: ObsSpace) -> Observation: 235 | """ 236 | Resets the environment and returns the initial observation. 237 | Any entities or features that are not present in the filter are removed from the observation. 238 | """ 239 | return self._filter_obs(self.reset(), obs_filter) 240 | 241 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 242 | """ 243 | Renders the environment. 244 | 245 | :param kwargs: a dictionary of arguments to send to the rendering process 246 | """ 247 | raise NotImplementedError 248 | 249 | def act_filter( 250 | self, actions: Mapping[ActionName, Action], obs_filter: ObsSpace 251 | ) -> Observation: 252 | """ 253 | Performs the given action and returns the resulting observation. 254 | Any entities or features that are not present in the filter are removed from the observation. 255 | """ 256 | return self._filter_obs(self.act(actions), obs_filter) 257 | 258 | def close(self) -> None: 259 | """Closes the environment.""" 260 | 261 | def _filter_obs(self, obs: Observation, obs_filter: ObsSpace) -> Observation: 262 | selectors = self._compile_feature_filter(obs_filter) 263 | features: Dict[ 264 | EntityName, Union[npt.NDArray[np.float32], Sequence[Sequence[float]]] 265 | ] = {} 266 | for etype, feats in obs.features.items(): 267 | selector = selectors[etype] 268 | if isinstance(feats, np.ndarray): 269 | features[etype] = feats[:, selector].reshape( 270 | feats.shape[0], len(selector) 271 | ) 272 | else: 273 | features[etype] = [[entity[i] for i in selector] for entity in feats] 274 | return Observation( 275 | global_features=obs.global_features, 276 | features=features, 277 | ids=obs.ids, 278 | actions=obs.actions, 279 | done=obs.done, 280 | reward=obs.reward, 281 | metrics=obs.metrics, 282 | visible=obs.visible, 283 | ) 284 | 285 | def _compile_feature_filter(self, obs_space: ObsSpace) -> Dict[str, np.ndarray]: 286 | obs_space = self.obs_space() 287 | feature_selection = {} 288 | for entity_name, entity in obs_space.entities.items(): 289 | feature_selection[entity_name] = np.array( 290 | [entity.features.index(f) for f in entity.features], dtype=np.int32 291 | ) 292 | feature_selection["__global__"] = np.array( 293 | [obs_space.global_features.index(f) for f in obs_space.global_features], 294 | dtype=np.int32, 295 | ) 296 | return feature_selection 297 | -------------------------------------------------------------------------------- /docs/source/complex-action-spaces.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Complex Action Spaces 3 | ===================== 4 | 5 | This tutorial walks you through implementing a grid-world environment in which the player controls multiple entities at the same time. 6 | You will learn how to use the ``CategoricalActionSpace`` to allow multiple entities to perform an action, use action masks to limit the set of available action choices, and use the ``SelectEntityActionSpace`` to implement an action that allows entities to select other entities. 7 | 8 | An extended version of the environment implemented in this tutorial can be found in `entity_gym/examples/minesweeper.py `_. 9 | 10 | .. toctree:: 11 | 12 | Overview 13 | ======== 14 | 15 | The environment we will implement contains two types of objects, mines and robots. 16 | 17 | ..
image:: https://user-images.githubusercontent.com/12845088/151688370-4ab0dd31-2dd9-4d25-9a4e-531c24b99865.png 18 | 19 | The player controls all robots in the environment. 20 | On every step, each robot may move in one of four cardinal directions, or stay in place and defuse all adjacent mines. 21 | If a robot defuses a mine, the mine is removed from the environment. 22 | If a robot steps on a mine, the robot is removed from the environment. 23 | If there are no more robots, the player loses. 24 | The player wins the game when all mines are defused. 25 | 26 | Environment 27 | =========== 28 | 29 | We start off by defining the initial state, observation space, and action space of the environment. 30 | The observation space has two different types of entities, mines and robots, both of which have an x and y coordinate. 31 | The action space has a single categorical action with five possible choices, which will be used to move the robots. 32 | 33 | .. code-block:: python 34 | import random 35 | from typing import Dict, List, Mapping, Tuple 36 | from entity_gym.env import * 37 | 38 | class MineSweeper(Environment): 39 | def reset(self) -> Observation: 40 | self.width, self.height = 6, 6 41 | positions = random.sample( 42 | [(x, y) for x in range(self.width) for y in range(self.height)], 7, 43 | ) 44 | self.mines = positions[:5] 45 | self.robots = positions[5:] 46 | return self.observe() 47 | 48 | @classmethod 49 | def obs_space(cls) -> ObsSpace: 50 | return ObsSpace(entities={ 51 | "Mine": Entity(features=["x", "y"]), 52 | "Robot": Entity(features=["x", "y"]), 53 | }) 54 | 55 | @classmethod 56 | def action_space(cls) -> Dict[ActionName, ActionSpace]: 57 | return { 58 | "Move": CategoricalActionSpace( 59 | ["Up", "Down", "Left", "Right", "Defuse Mines"], 60 | ), 61 | } 62 | 63 | def observe(self) -> Observation: 64 | raise NotImplementedError 65 | 66 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 67 | raise NotImplementedError 68 | 69 | Observation 70 | =========== 71 | 72 | Next, we implement the ``observe`` method, which returns an `Observation `_ representing the current state of the environment. 73 | 74 | The ``entities`` dictionary contains the current state of the environment. 75 | For the "Mine" entities, we need to specify only the features for each entity. 76 | Because the "Robot" entities will be performing an action, we have to additionally supply a list of IDs for the "Robot" entities. 77 | The IDs will later be used to determine which "Robot" entity performed which action. 78 | 79 | On every step, we make the "Move" action available by specifying a ``CategoricalActionMask``. 80 | The ``actor_types`` parameter specifies the types of entities that can perform the action. 81 | In this case, we only allow "Robot" entities to perform the action (and not "Mine" entities). 82 | As an alternative to ``actor_types``, ``CategoricalActionMask`` can also be supplied with an ``actor_ids`` list with the IDs of the entities that can perform the action. 83 | 84 | The game is ``done`` once there are no more mines or robots, and we award a ``reward`` of 1.0 if all mines are defused. 85 | 86 | .. code-block:: python 87 | 88 | def observe(self) -> Observation: 89 | return Observation( 90 | actions={ 91 | "Move": CategoricalActionMask( 92 | # Allow all robots to move 93 | actor_types=["Robot"], 94 | ), 95 | }, 96 | entities={ 97 | "Robot": ( 98 | self.robots, 99 | # Unique identifiers for all "Robot" entities 100 | [("Robot", i) for i in range(len(self.robots))], 101 | ), 102 | # We don't need identifiers for mines since they are not 103 | # directly referenced by any actions.
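# (Once mines become targets of the "Fire Orbital Cannon" action introduced below, they will need ids as well.)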
104 | "Mine": self.mines, 105 | }, 106 | # The game is done once there are no more mines or robots 107 | done=len(self.mines) == 0 or len(self.robots) == 0, 108 | # Give reward of 1.0 for defusing all mines 109 | reward=1.0 if len(self.mines) == 0 else 0, 110 | ) 111 | 112 | Actions 113 | ======= 114 | 115 | Finally, we implement the `act` method that takes an action and returns the next observation. 116 | 117 | .. code-block:: python 118 | 119 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 120 | move = actions["Move"] 121 | assert isinstance(move, CategoricalAction) 122 | for (_, i), action in zip(move.actors, move.indices): 123 | # Action space is ["Up", "Down", "Left", "Right", "Defuse Mines"], 124 | x, y = self.robots[i] 125 | if choice == 0 and y < self.height - 1: 126 | self.robots[i] = (x, y + 1) 127 | elif choice == 1 and y > 0: 128 | self.robots[i] = (x, y - 1) 129 | elif choice == 2 and x > 0: 130 | self.robots[i] = (x - 1, y) 131 | elif choice == 3 and x < self.width - 1: 132 | self.robots[i] = (x + 1, y) 133 | elif choice == 4: 134 | # Remove all mines adjacent to this robot 135 | rx, ry = self.robots[i] 136 | self.mines = [ 137 | (x, y) 138 | for (x, y) in self.mines 139 | if abs(x - rx) + abs(y - ry) > 1 140 | ] 141 | 142 | # Remove all robots that stepped on a mine 143 | self.robots = [ 144 | (x, y) 145 | for (x, y) in self.robots 146 | if (x, y) not in self.mines 147 | ] 148 | 149 | return self.observe() 150 | 151 | Action Masks 152 | ============ 153 | 154 | Currently, robots may move in any direction, but any movement that would take a robot outside the grid will be ignored. 155 | We may want to restrict the robots choices so that they cannot move outside the grid. 156 | We can do this by setting the `mask` attribute of the [`ActionMask`](todo link to docs) object to a boolean array of shape (number_entities, number_actions) that specifies which actions are allowed. 157 | 158 | .. code-block:: python 159 | 160 | import random 161 | from entity_gym.env import * 162 | 163 | 164 | class MineSweeper(Environment): 165 | ... 166 | 167 | def valid_moves(self, x: int, y: int) -> List[bool]: 168 | return [ 169 | x < self.width - 1, 170 | x > 0, 171 | y < self.height - 1, 172 | y > 0, 173 | # Always allow staying in place and defusing mines 174 | True, 175 | ] 176 | 177 | def observe(self) -> Observation: 178 | return Observation( 179 | actions={ 180 | "Move": CategoricalActionMask( 181 | # Allow all robots to move 182 | actor_types=["Robot"], 183 | mask=[ 184 | self.valid_moves(x, y) 185 | for (x, y) in self.robots 186 | ], 187 | ), 188 | }, 189 | ... 190 | ) 191 | 192 | SelectEntityAction 193 | ================== 194 | 195 | Suppose we want to add a new *Orbital Cannon* entity to the game that can fire a laser at any mine or robot every 5 steps. 196 | Since the number of mines and robots is unknown, we cannot use a normal categorical action for our Orbital Cannon. 197 | Instead, we will use a `SelectEntityAction `_, which allows us to select one entity from a list of entities. 198 | 199 | 200 | .. code-block:: python 201 | 202 | from entity_gym.env import * 203 | 204 | class MineSweeper(Environment): 205 | ... 
206 | 207 | @classmethod 208 | def obs_space(cls) -> ObsSpace: 209 | return ObsSpace(entities={ 210 | "Mine": Entity(features=["x", "y"]), 211 | "Robot": Entity(features=["x", "y"]), 212 | # The Orbital Cannon entity 213 | "Orbital Cannon": Entity(["cooldown"]), 214 | }) 215 | 216 | @classmethod 217 | def action_space(cls) -> Dict[ActionName, ActionSpace]: 218 | return { 219 | "Move": CategoricalActionSpace( 220 | ["Up", "Down", "Left", "Right", "Defuse Mines"] 221 | ), 222 | # New action for firing laser 223 | "Fire Orbital Cannon": SelectEntityActionSpace(), 224 | } 225 | 226 | 227 | 228 | def reset(self) -> Observation: 229 | ... 230 | # Set orbital cannon cooldown to 5 231 | self.orbital_cannon_cooldown = 5 232 | return self.observe() 233 | 234 | def observe(self) -> Observation: 235 | return Observation( 236 | entities={ 237 | "Mine": ( 238 | self.mines, 239 | [("Mine", i) for i in range(len(self.mines))], 240 | ), 241 | "Robot": ( 242 | self.robots, 243 | [("Robot", i) for i in range(len(self.robots))], 244 | ), 245 | "Orbital Cannon": ( 246 | [(self.orbital_cannon_cooldown,)], 247 | [("Orbital Cannon", 0)], 248 | ) 249 | }, 250 | actions={ 251 | "Move": CategoricalActionMask( 252 | actor_types=["Robot"], 253 | ), 254 | "Fire Orbital Cannon": SelectEntityActionMask( 255 | # Only the Orbital Cannon can fire, but not if cooldown > 0 256 | actor_types=["Orbital Cannon"] if self.orbital_cannon_cooldown == 0 else [], 257 | # Both mines and robots can be fired at 258 | actee_types=["Mine", "Robot"], 259 | ), 260 | }, 261 | done=len(self.mines) == 0 or len(self.robots) == 0, 262 | reward=1.0 if len(self.mines) == 0 else 0, 263 | ) 264 | 265 | def act(self, actions: Mapping[ActionName, Action]) -> Observation: 266 | fire = actions["Fire Orbital Cannon"] 267 | assert isinstance(fire, SelectEntityAction) 268 | remove_robot = None 269 | for (entity_type, i) in fire.actees: 270 | if entity_type == "Mine": 271 | self.mines.remove(self.mines[i]) 272 | elif entity_type == "Robot": 273 | # Don't remove yet to keep indices valid 274 | remove_robot = i 275 | 276 | move = actions["Move"] 277 | ... 278 | 279 | if remove_robot is not None: 280 | self.robots.pop(remove_robot) 281 | # Remove all robots that stepped on a mine 282 | self.robots = [r for r in self.robots if r not in self.mines] 283 | 284 | return self.observe() 285 | -------------------------------------------------------------------------------- /entity_gym/runner.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Optional, Tuple 3 | 4 | import click 5 | import numpy as np 6 | 7 | from entity_gym.env import ( 8 | Action, 9 | CategoricalAction, 10 | CategoricalActionSpace, 11 | Environment, 12 | Observation, 13 | ObsSpace, 14 | SelectEntityAction, 15 | SelectEntityActionMask, 16 | SelectEntityActionSpace, 17 | ) 18 | from entity_gym.env.environment import ( 19 | Action, 20 | GlobalCategoricalAction, 21 | GlobalCategoricalActionMask, 22 | GlobalCategoricalActionSpace, 23 | Observation, 24 | ) 25 | from entity_gym.env.validator import ValidatingEnv 26 | 27 | 28 | class Agent(ABC): 29 | """Interface for an agent that receives observations and outputs actions.""" 30 | 31 | @abstractmethod 32 | def act(self, obs: Observation) -> Tuple[Dict[str, Action], float]: 33 | pass 34 | 35 | 36 | class CliRunner: 37 | """ 38 | Interactively run any entity gym environment in a CLI. 39 | 40 | Example: 41 | 42 | ..
code-block:: pycon 43 | 44 | >>> from entity_gym.runner import CliRunner 45 | >>> from entity_gym.examples import TreasureHunt 46 | >>> CliRunner(TreasureHunt()).run() 47 | """ 48 | 49 | def __init__(self, env: Environment, agent: Optional[Agent] = None) -> None: 50 | self.env = ValidatingEnv(env) 51 | self.agent = agent 52 | 53 | def run(self, restart: bool = False) -> None: 54 | print_env(self.env) 55 | print() 56 | 57 | obs_space = self.env.obs_space() 58 | actions = self.env.action_space() 59 | obs = self.env.reset_filter(obs_space) 60 | total_reward = obs.reward 61 | step = 0 62 | while True: 63 | if obs.done: 64 | if restart: 65 | obs = self.env.reset_filter(obs_space) 66 | total_reward = obs.reward 67 | step = 0 68 | else: 69 | break 70 | if self.agent is None: 71 | agent_action: Optional[Dict[str, Action]] = None 72 | agent_prediction: Optional[float] = None 73 | else: 74 | agent_action, agent_prediction = self.agent.act(obs) 75 | print_obs(step, obs, total_reward, obs_space, agent_prediction) 76 | action: Dict[str, Action] = {} 77 | received_action = False 78 | for action_name, action_mask in obs.actions.items(): 79 | action_def = actions[action_name] 80 | if isinstance(action_mask, GlobalCategoricalActionMask): 81 | assert isinstance(action_def, GlobalCategoricalActionSpace) 82 | if agent_action is None: 83 | choices = " ".join( 84 | f"{i}/{label}" 85 | for i, label in enumerate(action_def.index_to_label) 86 | ) 87 | else: 88 | probs = agent_action[action_name].probs 89 | assert probs is not None 90 | choice_id = agent_action[action_name].index # type: ignore 91 | choices = " ".join( 92 | click.style( 93 | f"{i}/{label} ", 94 | fg="yellow" if i == choice_id else None, 95 | bold=i == choice_id, 96 | ) 97 | + click.style(f"{100 * prob:.1f}%", fg="yellow") 98 | for i, (label, prob) in enumerate( 99 | zip(action_def.index_to_label, probs) 100 | ) 101 | ) 102 | 103 | click.echo( 104 | f"Choose " 105 | + click.style(f"{action_name}", fg="green") 106 | + f" ({choices})" 107 | ) 108 | while True: 109 | try: 110 | inp = input() 111 | if inp == "" and agent_action is not None: 112 | choice_id = agent_action[action_name].index # type: ignore 113 | else: 114 | try: 115 | choice_id = int(inp) 116 | except ValueError: 117 | print( 118 | f"Invalid choice '{inp}' (must be an integer)" 119 | ) 120 | continue 121 | if choice_id < 0 or choice_id >= len( 122 | action_def.index_to_label 123 | ): 124 | print( 125 | f"Invalid choice {inp} (must be in range [0, {len(action_def.index_to_label) - 1}])" 126 | ) 127 | continue 128 | received_action = True 129 | break 130 | except KeyboardInterrupt: 131 | print() 132 | print("Exiting") 133 | return 134 | action[action_name] = GlobalCategoricalAction( 135 | index=choice_id, 136 | label=action_def.index_to_label[choice_id], 137 | ) 138 | continue 139 | elif action_mask.actor_ids is not None: 140 | actor_ids = action_mask.actor_ids 141 | elif action_mask.actor_types is not None: 142 | actor_ids = [ 143 | id for atype in action_mask.actor_types for id in obs.ids[atype] 144 | ] 145 | else: 146 | actor_ids = obs.index_to_id(obs_space) 147 | 148 | print() 149 | 150 | # Initialize actions 151 | if isinstance(action_def, CategoricalActionSpace): 152 | if action_name not in action: 153 | action[action_name] = CategoricalAction( 154 | indices=np.zeros( 155 | (0, len(action_def.index_to_label)), dtype=np.int64 156 | ), 157 | index_to_label=action_def.index_to_label, 158 | actors=[], 159 | ) 160 | elif isinstance(action_def, SelectEntityActionSpace): 161 | if action_name 
not in action: 162 | action[action_name] = SelectEntityAction([], []) 163 | 164 | for actor_id in actor_ids: 165 | if isinstance(action_def, CategoricalActionSpace): 166 | # Prompt user for action 167 | if agent_action is None: 168 | choices = " ".join( 169 | f"{i}/{label}" 170 | for i, label in enumerate(action_def.index_to_label) 171 | ) 172 | else: 173 | aa = agent_action[action_name] 174 | assert isinstance(aa, CategoricalAction) 175 | actor_index = aa.actors.index(actor_id) 176 | probs = aa.probs[actor_index] # type: ignore 177 | assert probs is not None 178 | choice_id = aa.indices[actor_index] 179 | choices = " ".join( 180 | click.style( 181 | f"{i}/{label} ", 182 | fg="yellow" if i == choice_id else None, 183 | bold=i == choice_id, 184 | ) 185 | + click.style(f"{100 * prob:.1f}%", fg="yellow") 186 | for i, (label, prob) in enumerate( 187 | zip(action_def.index_to_label, probs) 188 | ) 189 | ) 190 | click.echo( 191 | f"Choose " 192 | + click.style(f"{action_name}", fg="green") 193 | + f" for actor {actor_id}" 194 | + f" ({choices})" 195 | ) 196 | 197 | try: 198 | inp = input() 199 | if inp == "" and agent_action is not None: 200 | aa = agent_action[action_name] 201 | assert isinstance(aa, CategoricalAction) 202 | choice_id = aa.indices[actor_index] 203 | else: 204 | try: 205 | choice_id = int(inp) 206 | except ValueError: 207 | print( 208 | f"Invalid choice '{inp}' (must be an integer)" 209 | ) 210 | continue 211 | received_action = True 212 | except KeyboardInterrupt: 213 | print() 214 | print("Exiting") 215 | return 216 | a = action[action_name] 217 | assert isinstance(a, CategoricalAction) 218 | a.indices = np.array(list(a.indices) + [choice_id]) 219 | a.actors = list(a.actors) + [actor_id] 220 | elif isinstance(action_def, SelectEntityActionSpace): 221 | assert isinstance(action_mask, SelectEntityActionMask) 222 | # Prompt user for entity 223 | click.echo( 224 | f"Choose " 225 | + click.style(f"{action_name}", fg="green") 226 | + f" for actor {actor_id}" 227 | ) 228 | if action_mask.actee_ids is not None: 229 | print( 230 | f"Selectable entities: {', '.join([str(id) for id in action_mask.actee_ids])}" 231 | ) 232 | elif action_mask.actee_types is not None: 233 | print( 234 | f"Selectable entity types: {', '.join([str(id) for id in action_mask.actee_types])}" 235 | ) 236 | else: 237 | print("Selectable entities: all") 238 | 239 | try: 240 | try: 241 | entity_id = int(inp := input()) 242 | except ValueError: 243 | print(f"Invalid choice '{inp}' (must be an integer)") 244 | continue 245 | except KeyboardInterrupt: 246 | print() 247 | print("Exiting") 248 | return 249 | received_action = True 250 | a = action[action_name] 251 | assert isinstance(a, SelectEntityAction) 252 | a.actors = list(a.actors) + [actor_id] 253 | a.actees = list(a.actees) + [entity_id] 254 | else: 255 | raise ValueError(f"Unknown action type {action_def}") 256 | if not received_action: 257 | try: 258 | input("Press ENTER to continue, CTRL-C to exit") 259 | except KeyboardInterrupt: 260 | print() 261 | print("Exiting") 262 | return 263 | obs = self.env.act_filter(action, obs_space) 264 | total_reward += obs.reward 265 | step += 1 266 | 267 | print_obs(step, obs, total_reward, obs_space) 268 | click.secho("Episode finished", fg="green") 269 | 270 | 271 | def print_env(env: ValidatingEnv) -> None: 272 | click.secho(f"Environment: {env.env.__class__.__name__}", fg="white", bold=True) 273 | obs = env.obs_space() 274 | if len(obs.global_features) > 0: 275 | click.echo( 276 | click.style("Global features: ", fg="cyan") + ",
".join(obs.global_features) 277 | ) 278 | for label, entity in obs.entities.items(): 279 | click.echo( 280 | click.style(f"Entity ", fg="cyan") 281 | + click.style(f"{label}", fg="green") 282 | + click.style(f": " if len(entity.features) > 0 else "", fg="cyan") 283 | + ", ".join(entity.features) 284 | ) 285 | acts = env.action_space() 286 | for label, action in acts.items(): 287 | if isinstance(action, CategoricalActionSpace) or isinstance( 288 | action, GlobalCategoricalActionSpace 289 | ): 290 | click.echo( 291 | click.style(f"Categorical", fg="cyan") 292 | + click.style(f" {label}", fg="green") 293 | + click.style(f": ", fg="cyan") 294 | + ", ".join(action.index_to_label) 295 | ) 296 | elif isinstance(action, SelectEntityActionSpace): 297 | click.echo( 298 | click.style(f"Select entity", fg="cyan") 299 | + click.style(f" {label}", fg="green") 300 | ) 301 | else: 302 | raise ValueError(f"Unknown action type {action}") 303 | 304 | 305 | def print_obs( 306 | step: int, 307 | obs: Observation, 308 | total_reward: float, 309 | obs_filter: ObsSpace, 310 | predicted_return: Optional[float] = None, 311 | ) -> None: 312 | click.secho(f"Step {step}", fg="white", bold=True) 313 | click.echo(click.style("Reward: ", fg="cyan") + f"{obs.reward}") 314 | click.echo(click.style("Total: ", fg="cyan") + f"{total_reward}") 315 | if predicted_return is not None: 316 | click.echo( 317 | click.style("Predicted return: ", fg="cyan") 318 | + click.style(f"{predicted_return:.3e}", fg="yellow") 319 | ) 320 | if len(obs_filter.global_features) > 0: 321 | click.echo( 322 | click.style("Global features: ", fg="cyan") 323 | + ", ".join( 324 | f"{label}={value}" 325 | for label, value in zip(obs_filter.global_features, obs.global_features) 326 | ) 327 | ) 328 | if len(obs_filter.entities) > 0: 329 | click.echo(click.style("Entities", fg="cyan")) 330 | entity_index = 0 331 | for entity_type, features in obs.features.items(): 332 | for entity in range(len(features)): 333 | if entity_type in obs.ids: 334 | id = f" (id={obs.ids[entity_type][entity]})" 335 | else: 336 | id = "" 337 | rendered = ( 338 | click.style(entity_type, fg="green") 339 | + "(" 340 | + ", ".join( 341 | map( 342 | lambda nv: nv[0] + "=" + str(nv[1]), 343 | zip( 344 | obs_filter.entities[entity_type].features, 345 | features[entity], 346 | ), 347 | ) 348 | ) 349 | + ")" 350 | ) 351 | print(f"{entity_index} {rendered}{id}") 352 | entity_index += 1 353 | 354 | 355 | __all__ = ["CliRunner", "Agent"] 356 | -------------------------------------------------------------------------------- /entity_gym/env/vec_env.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Mapping, Optional, Union, overload 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from ragged_buffer import RaggedBufferBool, RaggedBufferF32, RaggedBufferI64 9 | 10 | from entity_gym.env.environment import ( 11 | ActionName, 12 | ActionSpace, 13 | CategoricalActionSpace, 14 | EntityName, 15 | GlobalCategoricalActionSpace, 16 | Observation, 17 | ObsSpace, 18 | SelectEntityActionSpace, 19 | ) 20 | 21 | 22 | @dataclass 23 | class VecSelectEntityActionMask: 24 | actors: RaggedBufferI64 25 | actees: RaggedBufferI64 26 | 27 | @overload 28 | def __getitem__(self, i: int) -> RaggedBufferI64: 29 | ... 30 | 31 | @overload 32 | def __getitem__(self, i: npt.NDArray[np.int64]) -> "VecSelectEntityActionMask": 33 | ... 


@dataclass
class VecSelectEntityActionMask:
    actors: RaggedBufferI64
    actees: RaggedBufferI64

    @overload
    def __getitem__(self, i: int) -> RaggedBufferI64:
        ...

    @overload
    def __getitem__(self, i: npt.NDArray[np.int64]) -> "VecSelectEntityActionMask":
        ...

    def __getitem__(
        self, i: Union[int, npt.NDArray[np.int64]]
    ) -> Union["VecSelectEntityActionMask", RaggedBufferI64]:
        if isinstance(i, int):
            return self.actors[i]
        else:
            return VecSelectEntityActionMask(self.actors[i], self.actees[i])

    def extend(self, other: Any) -> None:
        assert isinstance(
            other, VecSelectEntityActionMask
        ), f"Expected VecSelectEntityActionMask, got {type(other)}"
        self.actors.extend(other.actors)
        self.actees.extend(other.actees)

    def clear(self) -> None:
        self.actors.clear()
        self.actees.clear()


@dataclass
class VecCategoricalActionMask:
    actors: RaggedBufferI64
    mask: Optional[RaggedBufferBool]

    def __getitem__(
        self, i: Union[int, npt.NDArray[np.int64]]
    ) -> "VecCategoricalActionMask":
        if self.mask is not None and self.mask.size0() > 0:
            return VecCategoricalActionMask(self.actors[i], self.mask[i])
        else:
            return VecCategoricalActionMask(self.actors[i], None)

    def extend(self, other: Any) -> None:
        assert isinstance(
            other, VecCategoricalActionMask
        ), f"Expected VecCategoricalActionMask, got {type(other)}"
        if self.mask is not None and other.mask is not None:
            self.mask.extend(other.mask)
        elif self.mask is None and other.mask is None:
            pass
        elif self.mask is not None:
            # `other` has no mask: backfill all-ones rows for its actors.
            self.mask.extend(
                RaggedBufferBool.from_flattened(
                    flattened=np.ones(
                        shape=(other.actors.items(), self.mask.size2()),
                        dtype=np.bool_,
                    ),
                    lengths=other.actors.size1(),
                )
            )
        elif other.mask is not None:
            # `self` has no mask: backfill all-ones rows before extending.
            self.mask = RaggedBufferBool.from_flattened(
                flattened=np.ones(
                    shape=(self.actors.items(), other.mask.size2()),
                    dtype=np.bool_,
                ),
                lengths=self.actors.size1(),
            )
            self.mask.extend(other.mask)
        else:
            raise Exception("Impossible!")
        self.actors.extend(other.actors)

    def clear(self) -> None:
        self.actors.clear()
        if self.mask is not None:
            self.mask.clear()


VecActionMask = Union[VecCategoricalActionMask, VecSelectEntityActionMask]


@dataclass
class Metric:
    count: int = 0
    sum: float = 0.0
    min: float = float("inf")
    max: float = float("-inf")

    def push(self, value: float) -> None:
        self.count += 1
        self.sum += value
        self.min = min(self.min, value)
        self.max = max(self.max, value)

    def __iadd__(self, m: "Metric") -> "Metric":
        self.count += m.count
        self.sum += m.sum
        self.min = min(self.min, m.min)
        self.max = max(self.max, m.max)
        return self

    def __add__(self, m: "Metric") -> "Metric":
        return Metric(
            count=self.count + m.count,
            sum=self.sum + m.sum,
            min=min(self.min, m.min),
            max=max(self.max, m.max),
        )

    @property
    def mean(self) -> float:
        if self.count == 0:
            return 0.0
        else:
            return self.sum / self.count
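
# Example: Metric aggregates a stream of values and merges with `+` / `+=`:
#
#     m = Metric()
#     m.push(1.0)
#     m.push(3.0)
#     m += Metric(count=1, sum=5.0, min=5.0, max=5.0)
#     assert m.mean == 3.0 and m.min == 1.0 and m.max == 5.0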
148 | """ 149 | 150 | features: Dict[EntityName, RaggedBufferF32] 151 | # Optional mask to hide specific entities from the policy but not the value function 152 | visible: Dict[EntityName, RaggedBufferBool] 153 | action_masks: Dict[ActionName, VecActionMask] 154 | reward: npt.NDArray[np.float32] 155 | done: npt.NDArray[np.bool_] 156 | metrics: Dict[str, Metric] 157 | 158 | def extend(self, b: "VecObs") -> None: 159 | num_envs = len(self.reward) 160 | # Extend visible (must happen before features in case of backfill) 161 | for etype in self.features.keys(): 162 | if etype in b.visible: 163 | if etype not in self.visible: 164 | self.visible[etype] = RaggedBufferBool.from_flattened( 165 | flattened=np.ones( 166 | shape=(self.features[etype].items(), 1), dtype=np.bool_ 167 | ), 168 | lengths=self.features[etype].size1(), 169 | ) 170 | self.visible[etype].extend(b.visible[etype]) 171 | elif etype in self.visible: 172 | self.visible[etype].extend( 173 | RaggedBufferBool.from_flattened( 174 | flattened=np.ones( 175 | shape=(b.features[etype].items(), 1), dtype=np.bool_ 176 | ), 177 | lengths=b.features[etype].size1(), 178 | ) 179 | ) 180 | for etype, feats in b.features.items(): 181 | if etype not in self.features: 182 | self.features[etype] = empty_ragged_f32( 183 | feats=feats.size2(), sequences=num_envs 184 | ) 185 | self.features[etype].extend(feats) 186 | for etype, feats in self.features.items(): 187 | if etype not in b.features: 188 | feats.extend(empty_ragged_f32(feats.size2(), len(b.reward))) 189 | for atype, amask in b.action_masks.items(): 190 | if atype not in self.action_masks: 191 | raise NotImplementedError() 192 | else: 193 | self.action_masks[atype].extend(amask) 194 | self.reward = np.concatenate((self.reward, b.reward)) 195 | self.done = np.concatenate((self.done, b.done)) 196 | num_envs = len(self.reward) 197 | for name, stats in b.metrics.items(): 198 | if name in self.metrics: 199 | self.metrics[name].count += stats.count 200 | self.metrics[name].sum += stats.sum 201 | self.metrics[name].min = min(self.metrics[name].min, stats.min) 202 | self.metrics[name].max = max(self.metrics[name].max, stats.max) 203 | else: 204 | self.metrics[name] = copy.copy(stats) 205 | 206 | 207 | class VecEnv(ABC): 208 | """ 209 | Interface for vectorized environments. The main goal of VecEnv is to allow 210 | for maximally efficient environment implementations. 211 | """ 212 | 213 | @abstractmethod 214 | def obs_space(self) -> ObsSpace: 215 | """ 216 | Returns a dictionary mapping the name of observable entities to their type. 217 | """ 218 | raise NotImplementedError 219 | 220 | @abstractmethod 221 | def action_space(self) -> Dict[ActionName, ActionSpace]: 222 | """ 223 | Returns a dictionary mapping the name of actions to their action space. 224 | """ 225 | raise NotImplementedError 226 | 227 | @abstractmethod 228 | def reset(self, obs_config: ObsSpace) -> VecObs: 229 | """ 230 | Resets all environments and returns the initial observations. 231 | """ 232 | raise NotImplementedError 233 | 234 | @abstractmethod 235 | def act( 236 | self, actions: Mapping[ActionName, RaggedBufferI64], obs_filter: ObsSpace 237 | ) -> VecObs: 238 | """ 239 | Performs the given actions on the underlying environments and returns the resulting observations. 240 | Any environment that reaches the end of its episode is reset and returns the initial observation of the next episode. 
241 | """ 242 | raise NotImplementedError 243 | 244 | @abstractmethod 245 | def render(self, **kwargs: Any) -> npt.NDArray[np.uint8]: 246 | raise NotImplementedError 247 | 248 | @abstractmethod 249 | def __len__(self) -> int: 250 | raise NotImplementedError 251 | 252 | def close(self) -> None: 253 | pass 254 | 255 | def has_global_entity(self) -> bool: 256 | return len(self.obs_space().global_features) > 0 or any( 257 | isinstance(space, GlobalCategoricalActionSpace) 258 | for space in self.action_space().values() 259 | ) 260 | 261 | 262 | def batch_obs( 263 | obs: List[Observation], obs_space: ObsSpace, action_space: Dict[str, ActionSpace] 264 | ) -> VecObs: 265 | """ 266 | Converts a list of observations into a batch of observations. 267 | """ 268 | features: Dict[EntityName, RaggedBufferF32] = {} 269 | visible: Dict[EntityName, RaggedBufferBool] = {} 270 | action_masks: Dict[ActionName, VecActionMask] = {} 271 | reward = [] 272 | done = [] 273 | metrics = {} 274 | 275 | # Initialize the entire batch with all entities and actions 276 | for entity_name, entity in obs_space.entities.items(): 277 | nfeat = len(entity.features) 278 | features[entity_name] = RaggedBufferF32(nfeat) 279 | global_entity = len(obs_space.global_features) > 0 280 | for action_name, space in action_space.items(): 281 | if isinstance(space, CategoricalActionSpace): 282 | action_masks[action_name] = VecCategoricalActionMask( 283 | RaggedBufferI64(1), 284 | None, 285 | ) 286 | elif isinstance(space, GlobalCategoricalActionSpace): 287 | action_masks[action_name] = VecCategoricalActionMask( 288 | RaggedBufferI64(1), 289 | None, 290 | ) 291 | global_entity = True 292 | elif isinstance(space, SelectEntityActionSpace): 293 | action_masks[action_name] = VecSelectEntityActionMask( 294 | RaggedBufferI64(1), RaggedBufferI64(1) 295 | ) 296 | else: 297 | raise NotImplementedError(f"Action space {space} not supported") 298 | if global_entity: 299 | nfeat = len(obs_space.global_features) 300 | features["__global__"] = RaggedBufferF32(nfeat) 301 | 302 | for i, o in enumerate(obs): 303 | # Merge entity features 304 | for entity_type, entity in obs_space.entities.items(): 305 | if entity_type not in features: 306 | features[entity_type] = RaggedBufferF32.from_flattened( 307 | np.zeros((0, len(entity.features)), dtype=np.float32), 308 | lengths=np.zeros(i, dtype=np.int64), 309 | ) 310 | if entity_type in o.features: 311 | ofeats = o.features[entity_type] 312 | if not isinstance(ofeats, np.ndarray): 313 | ofeats = np.array(ofeats, dtype=np.float32).reshape( 314 | len(ofeats), len(obs_space.entities[entity_type].features) 315 | ) 316 | features[entity_type].push(ofeats) 317 | else: 318 | features[entity_type].push( 319 | np.zeros((0, len(entity.features)), dtype=np.float32) 320 | ) 321 | if global_entity: 322 | gfeats = o.global_features 323 | if not isinstance(gfeats, np.ndarray): 324 | gfeats = np.array(gfeats, dtype=np.float32) 325 | features["__global__"].push( 326 | gfeats.reshape(1, len(obs_space.global_features)) 327 | ) 328 | 329 | # Merge visibilities 330 | for etype, vis in o.visible.items(): 331 | if etype not in visible: 332 | lengths = [] 333 | for j in range(i): 334 | if etype in obs[j].features: 335 | lengths.append(len(obs[j].features[etype])) 336 | else: 337 | lengths.append(0) 338 | visible[etype] = RaggedBufferBool.from_flattened( 339 | np.ones((sum(lengths), 1), dtype=np.bool_), 340 | lengths=np.array(lengths, dtype=np.int64), 341 | ) 342 | if not isinstance(vis, np.ndarray): 343 | vis = np.array(vis, dtype=np.bool_) 
        # Merge action masks
        for atype, space in action_space.items():
            if atype not in o.actions:
                if atype in action_masks:
                    if isinstance(space, CategoricalActionSpace):
                        vec_action = action_masks[atype]
                        assert isinstance(vec_action, VecCategoricalActionMask)
                        vec_action.actors.push(np.zeros((0, 1), dtype=np.int64))
                        if vec_action.mask is not None:
                            vec_action.mask.push(
                                np.zeros((0, len(space.index_to_label)), dtype=np.bool_)
                            )
                    elif isinstance(space, SelectEntityActionSpace):
                        vec_action = action_masks[atype]
                        assert isinstance(vec_action, VecSelectEntityActionMask)
                        vec_action.actors.push(np.zeros((0, 1), dtype=np.int64))
                        vec_action.actees.push(np.zeros((0, 1), dtype=np.int64))
                    else:
                        raise ValueError(
                            f"Unsupported action space type: {type(space)}"
                        )
                continue
            action = o.actions[atype]
            if atype not in action_masks:
                if isinstance(space, CategoricalActionSpace) or isinstance(
                    space, GlobalCategoricalActionSpace
                ):
                    action_masks[atype] = VecCategoricalActionMask(
                        empty_ragged_i64(1, i), None
                    )
                elif isinstance(space, SelectEntityActionSpace):
                    action_masks[atype] = VecSelectEntityActionMask(
                        empty_ragged_i64(1, i), empty_ragged_i64(1, i)
                    )
                else:
                    raise ValueError(f"Unknown action space type: {space}")
            if isinstance(space, CategoricalActionSpace) or isinstance(
                space, GlobalCategoricalActionSpace
            ):
                vec_action = action_masks[atype]
                assert isinstance(vec_action, VecCategoricalActionMask)
                actor_indices = o._actor_indices(atype, obs_space)
                vec_action.actors.push(actor_indices.reshape(-1, 1))
                if action.mask is not None:
                    if vec_action.mask is None:
                        vec_action.mask = RaggedBufferBool.from_flattened(
                            np.ones((0, len(space.index_to_label)), dtype=np.bool_),
                            np.zeros(i, dtype=np.int64),
                        )
                    amask = action.mask
                    if not isinstance(amask, np.ndarray):
                        amask = np.array(amask, dtype=np.bool_)
                    vec_action.mask.push(amask)
                elif vec_action.mask is not None:
                    vec_action.mask.push(
                        np.ones(
                            (len(actor_indices), len(space.index_to_label)),
                            dtype=np.bool_,
                        )
                    )
            elif isinstance(space, SelectEntityActionSpace):
                vec_action = action_masks[atype]
                assert isinstance(vec_action, VecSelectEntityActionMask)
                actors = o._actor_indices(atype, obs_space).reshape(-1, 1)
                vec_action.actors.push(actors)
                if len(actors) > 0:
                    vec_action.actees.push(
                        o._actee_indices(atype, obs_space).reshape(-1, 1)
                    )
                else:
                    vec_action.actees.push(np.zeros((0, 1), dtype=np.int64))
            else:
                raise NotImplementedError()

        reward.append(o.reward)
        done.append(o.done)
        for name, value in o.metrics.items():
            if name not in metrics:
                metrics[name] = Metric()
            metrics[name].push(value)

    return VecObs(
        features,
        visible,
        action_masks,
        np.array(reward, dtype=np.float32),
        np.array(done, dtype=np.bool_),
        metrics,
    )
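
# Example (sketch): batch_obs pads entity types that are missing from some
# observations with zero-length sequences so per-environment indices line up.
# With hypothetical observations obs_a (two "Mine" entities) and obs_b (none):
#
#     vec = batch_obs([obs_a, obs_b], obs_space, action_space)
#     vec.features["Mine"].size1()  # per-env entity counts, e.g. [2, 0]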


def empty_ragged_f32(feats: int, sequences: int) -> RaggedBufferF32:
    return RaggedBufferF32.from_flattened(
        np.zeros((0, feats), dtype=np.float32),
        lengths=np.array([0] * sequences, dtype=np.int64),
    )


def empty_ragged_i64(feats: int, sequences: int) -> RaggedBufferI64:
    return RaggedBufferI64.from_flattened(
        np.zeros((0, feats), dtype=np.int64),
        lengths=np.array([0] * sequences, dtype=np.int64),
    )
--------------------------------------------------------------------------------