├── pogema ├── wrappers │ ├── __init__.py │ ├── multi_time_limit.py │ ├── persistence.py │ └── metrics.py ├── integrations │ ├── __init__.py │ ├── sample_factory.py │ ├── make_pogema.py │ ├── pettingzoo.py │ └── pymarl.py ├── svg_animation │ ├── __init__.py │ ├── svg_objects.py │ ├── animation_wrapper.py │ └── animation_drawer.py ├── __init__.py ├── grid_registry.py ├── utils.py ├── a_star_policy.py ├── generator.py ├── grid_config.py ├── grid.py └── envs.py ├── local_build.sh ├── requirements.txt ├── LICENSE ├── .github └── workflows │ ├── CI.yml │ └── codeql-analysis.yml ├── setup.py ├── version_history.MD ├── tests ├── test_integrations.py ├── test_deterministic_policy.py ├── test_pogema_env.py └── test_grid.py └── README.md /pogema/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pogema/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pogema/svg_animation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /local_build.sh: -------------------------------------------------------------------------------- 1 | pip3 install --no-cache-dir . 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>1.23.5,<=1.26.4 2 | pydantic>=1.8.2,<=1.9.1 3 | pytest>=6.2.5,<=7.1.2 4 | pettingzoo==1.22.3 5 | tabulate>=0.8.7,<=0.8.10 6 | gymnasium==0.28.1 7 | -------------------------------------------------------------------------------- /pogema/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium import register 2 | from pogema.grid_config import GridConfig 3 | from pogema.integrations.make_pogema import pogema_v0 4 | from pogema.svg_animation.animation_wrapper import AnimationMonitor, AnimationConfig 5 | from pogema.a_star_policy import AStarAgent, BatchAStarAgent 6 | 7 | __version__ = '1.5.0a0' 8 | 9 | __all__ = [ 10 | 'GridConfig', 11 | 'pogema_v0', 12 | 'AStarAgent', 'BatchAStarAgent', 13 | "AnimationMonitor", "AnimationConfig", 14 | ] 15 | 16 | register( 17 | id="Pogema-v0", 18 | entry_point="pogema.integrations.make_pogema:make_single_agent_gym", 19 | ) 20 | -------------------------------------------------------------------------------- /pogema/wrappers/multi_time_limit.py: -------------------------------------------------------------------------------- 1 | from gymnasium.wrappers import TimeLimit 2 | 3 | 4 | class MultiTimeLimit(TimeLimit): 5 | def step(self, action): 6 | observation, reward, terminated, truncated, info = self.env.step(action) 7 | self._elapsed_steps += 1 8 | if self._elapsed_steps >= self._max_episode_steps: 9 | truncated = [True] * self.get_num_agents() 10 | return observation, reward, terminated, truncated, info 11 | 12 | def set_elapsed_steps(self, elapsed_steps): 13 | if not self.grid_config.persistent: 14 | raise ValueError("Cannot set elapsed steps for non-persistent environment!") 15 | assert elapsed_steps >= 0 16 | self._elapsed_steps = elapsed_steps 17 | 
-------------------------------------------------------------------------------- /pogema/integrations/sample_factory.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from gymnasium import Wrapper 4 | 5 | 6 | class IsMultiAgentWrapper(Wrapper): 7 | def __init__(self, env): 8 | super().__init__(env) 9 | 10 | self.is_multiagent = True 11 | 12 | @property 13 | def num_agents(self): 14 | return self.get_num_agents() 15 | 16 | 17 | class MetricsForwardingWrapper(Wrapper): 18 | def step(self, action): 19 | 20 | observations, rewards, terminated, truncated, infos = self.env.step(action) 21 | for info in infos: 22 | if 'metrics' in info: 23 | info.update(episode_extra_stats=deepcopy(info['metrics'])) 24 | return observations, rewards, terminated, truncated, infos 25 | 26 | 27 | class AutoResetWrapper(Wrapper): 28 | def step(self, action): 29 | observations, rewards, terminated, truncated, infos = self.env.step(action) 30 | if all(terminated) or all(truncated): 31 | observations, info = self.env.reset() 32 | return observations, rewards, terminated, truncated, infos 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022, Alexey Skrynnik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ main, dev] 9 | pull_request: 10 | branches: [ main, dev] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.8", "3.9"] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Test with pytest 39 | run: | 40 | PYTHONPATH=. pytest -s 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | 5 | from setuptools import setup, find_packages 6 | 7 | cur_dir = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(cur_dir, 'README.md'), 'rb') as f: 9 | lines = [x.decode('utf-8') for x in f.readlines()] 10 | lines = ''.join([re.sub('^<.*>\n$', '', x) for x in lines]) 11 | long_description = lines 12 | 13 | 14 | def read(*parts): 15 | with codecs.open(os.path.join(cur_dir, *parts), 'r') as fp: 16 | return fp.read() 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read(*file_paths) 21 | version_match = re.search( 22 | r"^__version__ = ['\"]([^'\"]*)['\"]", 23 | version_file, 24 | re.M, 25 | ) 26 | if version_match: 27 | return version_match.group(1) 28 | 29 | raise RuntimeError("Unable to find version string.") 30 | 31 | 32 | setup( 33 | name='pogema', 34 | author='Alexey Skrynnik', 35 | license='MIT', 36 | version=find_version("pogema", "__init__.py"), 37 | description='Partially Observable Grid Environment for Multiple Agents', 38 | long_description=long_description, 39 | long_description_content_type='text/markdown', 40 | url='https://github.com/AIRI-Institute/pogema', 41 | install_requires=[ 42 | "gymnasium==0.28.1", 43 | "numpy>1.23.5,<=1.26.4", 44 | "pydantic>=1.8.2,<=1.9.1", 45 | ], 46 | extras_require={ 47 | 48 | }, 49 | package_dir={'': './'}, 50 | packages=find_packages(where='./', include='pogema*'), 51 | include_package_data=True, 52 | python_requires='>=3.7', 53 | ) 54 | -------------------------------------------------------------------------------- /pogema/svg_animation/svg_objects.py: 
class SvgObject:
    """Base class for renderable SVG elements."""
    tag = None

    def __init__(self, **kwargs):
        self.attributes = kwargs
        self.animations = []

    def add_animation(self, animation):
        self.animations.append(animation)

    @staticmethod
    def render_attributes(attributes):
        # Underscores in Python kwargs become dashes in SVG attribute names;
        # attributes are emitted in sorted order for a stable output.
        return " ".join(f'{name.replace("_", "-")}="{value}"'
                        for name, value in sorted(attributes.items()))

    def render(self):
        """Render the element, inlining any attached animations."""
        inner = '\n'.join(a.render() for a in self.animations) if self.animations else None
        if inner:
            return f"<{self.tag} {self.render_attributes(self.attributes)}> {inner} "
        return f"<{self.tag} {self.render_attributes(self.attributes)} />"


class Rectangle(SvgObject):
    """Rectangle; the y axis is flipped so grid coordinates grow upward."""
    tag = 'rect'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attributes['y'] = -self.attributes['y'] - self.attributes['height']


class RectangleHref(SvgObject):
    """Rectangle drawn via a <use> reference to the shared '#obstacle' def."""
    tag = 'use'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attributes['y'] = -self.attributes['y'] - self.attributes['height']
        self.attributes['href'] = "#obstacle"
        # height comes from the referenced definition, not the <use> element
        del self.attributes['height']


class Circle(SvgObject):
    """Circle; cy is negated to flip the y axis."""
    tag = 'circle'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attributes['cy'] = -self.attributes['cy']


class Line(SvgObject):
    """Line segment; both y endpoints are negated to flip the y axis."""
    tag = 'line'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attributes['y1'] = -self.attributes['y1']
        self.attributes['y2'] = -self.attributes['y2']


class Animation(SvgObject):
    """<animate> element; rendered self-closed and never nests children."""
    tag = 'animate'

    def render(self):
        return f"<{self.tag} {self.render_attributes(self.attributes)}/>"
from typing import Union, Optional

from gymnasium import Wrapper

from pogema import GridConfig
from pogema.envs import _make_pogema
from pogema.integrations.pettingzoo import parallel_env
from pogema.integrations.pymarl import PyMarlPogema
from pogema.integrations.sample_factory import AutoResetWrapper, IsMultiAgentWrapper, MetricsForwardingWrapper


def _make_sample_factory_integration(grid_config):
    """Wrap the base env with the Sample Factory specific wrappers."""
    env = _make_pogema(grid_config)
    env = MetricsForwardingWrapper(env)
    env = IsMultiAgentWrapper(env)
    # auto_reset defaults to enabled for Sample Factory (None means "on")
    if grid_config.auto_reset is None or grid_config.auto_reset:
        env = AutoResetWrapper(env)
    return env


def _make_py_marl_integration(grid_config, *_, **__):
    """Wrap the env in the PyMARL adapter; extra args are ignored."""
    return PyMarlPogema(grid_config)


class SingleAgentWrapper(Wrapper):
    """Exposes a multi-agent pogema env as a single-agent Gymnasium env.

    The caller controls agent 0; all remaining agents take random actions.
    """

    def step(self, action):
        observations, rewards, terminated, truncated, infos = self.env.step(
            [action] + [self.env.action_space.sample() for _ in range(self.get_num_agents() - 1)])
        return observations[0], rewards[0], terminated[0], truncated[0], infos[0]

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None, ):
        # Fix: seed/options were previously dropped, so reset(seed=...) had
        # no effect; forward them so seeded resets are reproducible.
        observations, infos = self.env.reset(seed=seed, options=options)
        if return_info:
            return observations[0], infos[0]
        else:
            return observations[0]


def make_single_agent_gym(grid_config: Union[GridConfig, dict] = GridConfig()):
    """Entry point registered with Gymnasium as 'Pogema-v0'."""
    env = _make_pogema(grid_config)
    env = SingleAgentWrapper(env)
    return env


def make_pogema(grid_config: Union[GridConfig, dict] = GridConfig(), *args, **kwargs):
    """Create a pogema env wrapped for grid_config.integration.

    Supported integrations: None (raw env), 'SampleFactory', 'PyMARL',
    'PettingZoo', 'gymnasium'. Raises NotImplementedError for 'rllib'
    and KeyError for any unknown integration or for auto_reset requested
    outside of Sample Factory.
    """
    if isinstance(grid_config, dict):
        grid_config = GridConfig(**grid_config)

    if grid_config.integration != 'SampleFactory' and grid_config.auto_reset:
        raise KeyError(f"{grid_config.integration} does not support auto_reset")

    if grid_config.integration is None:
        return _make_pogema(grid_config)
    elif grid_config.integration == 'SampleFactory':
        return _make_sample_factory_integration(grid_config)
    elif grid_config.integration == 'PyMARL':
        return _make_py_marl_integration(grid_config, *args, **kwargs)
    elif grid_config.integration == 'rllib':
        raise NotImplementedError('Please use PettingZoo integration for rllib')
    elif grid_config.integration == 'PettingZoo':
        return parallel_env(grid_config)
    elif grid_config.integration == 'gymnasium':
        return make_single_agent_gym(grid_config)

    raise KeyError(grid_config.integration)


pogema_v0 = make_pogema
PogemaParallel(grid_config) 10 | 11 | 12 | class PogemaParallel: 13 | 14 | def state(self): 15 | return self.pogema.get_state() 16 | 17 | def __init__(self, grid_config: GridConfig, render_mode='ansi'): 18 | self.metadata = {'render_modes': ['ansi'], "name": "pogema"} 19 | self.render_mode = render_mode 20 | self.pogema = _make_pogema(grid_config) 21 | self.possible_agents = ["player_" + str(r) for r in range(self.pogema.get_num_agents())] 22 | self.agent_name_mapping = dict(zip(self.possible_agents, list(range(len(self.possible_agents))))) 23 | self.agents = None 24 | self.num_moves = None 25 | 26 | @functools.lru_cache(maxsize=None) 27 | def observation_space(self, agent): 28 | assert agent in self.possible_agents 29 | return self.pogema.observation_space 30 | 31 | @functools.lru_cache(maxsize=None) 32 | def action_space(self, agent): 33 | assert agent in self.possible_agents 34 | return self.pogema.action_space 35 | 36 | def render(self, mode="human"): 37 | assert mode == 'human' 38 | return self.pogema.render() 39 | 40 | def reset(self, seed=None, options=None): 41 | observations, info = self.pogema.reset(seed=seed, options=options) 42 | self.agents = self.possible_agents[:] 43 | self.num_moves = 0 44 | observations = {agent: observations[self.agent_name_mapping[agent]].astype(np.float32) for agent in self.agents} 45 | return observations 46 | 47 | def step(self, actions): 48 | anm = self.agent_name_mapping 49 | 50 | actions = [actions[agent] if agent in actions else 0 for agent in self.possible_agents] 51 | observations, rewards, terminated, truncated, infos = self.pogema.step(actions) 52 | d_observations = {agent: observations[anm[agent]].astype(np.float32) for agent in 53 | self.agents} 54 | d_rewards = {agent: rewards[anm[agent]] for agent in self.agents} 55 | d_terminated = {agent: terminated[anm[agent]] for agent in self.agents} 56 | d_truncated = {agent: truncated[anm[agent]] for agent in self.agents} 57 | d_infos = {agent: infos[anm[agent]] for agent in 
self.agents} 58 | 59 | for agent, idx in anm.items(): 60 | if (not self.pogema.grid.is_active[idx] or all(truncated) or all(terminated)) and agent in self.agents: 61 | self.agents.remove(agent) 62 | 63 | return d_observations, d_rewards, d_terminated, d_truncated, d_infos 64 | 65 | @property 66 | def unwrapped(self): 67 | return self 68 | 69 | def close(self): 70 | pass 71 | -------------------------------------------------------------------------------- /pogema/integrations/pymarl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pogema import GridConfig 4 | from pogema.envs import _make_pogema 5 | 6 | 7 | class PyMarlPogema: 8 | 9 | def __init__(self, grid_config, mh_distance=False): 10 | gc = grid_config 11 | self._grid_config: GridConfig = gc 12 | 13 | self.env = _make_pogema(grid_config) 14 | self._mh_distance = mh_distance 15 | self._observations, _ = self.env.reset() 16 | self.max_episode_steps = gc.max_episode_steps 17 | self.episode_limit = gc.max_episode_steps 18 | self.n_agents = self.env.get_num_agents() 19 | 20 | self.spec = None 21 | 22 | @property 23 | def unwrapped(self): 24 | return self 25 | 26 | def step(self, actions): 27 | self._observations, rewards, terminated, truncated, infos = self.env.step(actions) 28 | info = {} 29 | done = all(terminated) or all(truncated) 30 | if done: 31 | for key, value in infos[0]['metrics'].items(): 32 | info[key] = value 33 | 34 | return sum(rewards), done, info 35 | 36 | def get_obs(self): 37 | return np.array([self.get_obs_agent(agent_id) for agent_id in range(self.n_agents)]) 38 | 39 | def get_obs_agent(self, agent_id): 40 | return np.array(self._observations[agent_id]).flatten() 41 | 42 | def get_obs_size(self): 43 | return len(np.array(self._observations[0]).flatten()) 44 | 45 | def get_state(self): 46 | return self.env.get_state() 47 | 48 | def get_state_size(self): 49 | return len(self.get_state()) 50 | 51 | def get_avail_actions(self): 
import numpy as np

# Registry of named grids, keyed by grid name.
GRID_STR_REGISTRY = {}


def in_registry(name):
    """Return True when a grid with this name is registered."""
    return name in GRID_STR_REGISTRY


def get_grid(name):
    """Return the registered grid for *name*; raise KeyError when missing."""
    if not in_registry(name):
        raise KeyError(f"Grid with name {name} not found")
    return GRID_STR_REGISTRY[name]


class RegisteredGrid:
    """A named grid parsed from its string form and stored in the registry.

    Grid string symbols: '.' free cell, '#' obstacle, 'a'-'z' an agent
    start, 'A'-'Z' the matching agent's target (cells under letters are
    free).
    """

    FREE = 0
    OBSTACLE = 1

    def str_to_grid(self, grid_str):
        """Parse a grid string into (obstacles, agents, targets).

        *agents*/*targets* map a lowercase letter to its (row, col)
        position; *obstacles* is a list of rows of FREE/OBSTACLE values.
        """
        obstacles = []
        agents = {}
        targets = {}
        for row_idx, line in enumerate(grid_str.split()):
            row = []
            for symbol in line:
                if symbol == '.':
                    row.append(self.FREE)
                elif symbol == '#':
                    row.append(self.OBSTACLE)
                elif 'A' <= symbol <= 'Z':
                    # target cell: remember its position, the cell is free
                    targets[symbol.lower()] = len(obstacles), len(row)
                    row.append(self.FREE)
                elif 'a' <= symbol <= 'z':
                    agents[symbol.lower()] = len(obstacles), len(row)
                    row.append(self.FREE)
                else:
                    raise KeyError(f"Unsupported symbol '{symbol}' at line {row_idx}")
            if row:
                if obstacles:
                    # every row must have the same width
                    assert len(obstacles[-1]) == len(row), f"Wrong string size for row {row_idx};"
                obstacles.append(row)
        return obstacles, agents, targets

    def __init__(self, name: str, grid_str: str = None, agents_positions: list = None, agents_targets: list = None):
        # imported lazily so loading this module does not pull pogema.utils
        from pogema.utils import check_grid

        self.name = name
        self.grid_str = grid_str
        self.agents_positions = agents_positions
        self.agents_targets = agents_targets

        self.obstacles, agents, targets = self.str_to_grid(grid_str)
        self.obstacles = np.array(self.obstacles, dtype=np.int32)

        if agents_positions and agents:
            raise ValueError("Agents positions are already defined in the grid string!")
        if agents_targets and targets:
            raise ValueError("Agents targets are already defined in the grid string!")

        # letters sort alphabetically, which fixes the agent ordering
        if agents:
            self.agents_xy = [[x, y] for _, (x, y) in sorted(agents.items())]
        else:
            self.agents_xy = agents_positions

        if targets:
            self.targets_xy = [[x, y] for _, (x, y) in sorted(targets.items())]
        else:
            self.targets_xy = agents_targets

        if in_registry(name):
            raise ValueError(f"Grid with name {self.name} already registered!")
        check_grid(self.obstacles, self.agents_xy, self.targets_xy)

        register_grid(self)

    def get_obstacles(self):
        return self.obstacles

    def get_agents_xy(self):
        return self.agents_xy

    def get_targets_xy(self):
        return self.targets_xy

    def render(self):
        from pogema.utils import render_grid  # lazy import, see __init__
        render_grid(obstacles=self.get_obstacles(), positions_xy=self.get_agents_xy(),
                    targets_xy=self.get_targets_xy())
94 | def register_grid(rg: RegisteredGrid): 95 | if in_registry(rg.name): 96 | raise KeyError(f"Grid with name {rg.name} already registered") 97 | GRID_STR_REGISTRY[rg.name] = rg 98 | -------------------------------------------------------------------------------- /pogema/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from pydantic import BaseModel 4 | 5 | from typing_extensions import Literal 6 | 7 | 8 | class AgentsTargetsSizeError(Exception): 9 | pass 10 | 11 | 12 | def grid_to_str(grid): 13 | return '\n'.join(''.join('.' if cell == 0 else '#' for cell in row) for row in grid) 14 | 15 | 16 | def check_grid(obstacles, agents_xy, targets_xy): 17 | if bool(agents_xy) != bool(targets_xy): 18 | raise AgentsTargetsSizeError("Agents and targets must be defined together/undefined together!") 19 | 20 | if not agents_xy or not targets_xy: 21 | return 22 | 23 | if len(agents_xy) != len(targets_xy): 24 | raise IndexError("Can't create task. Please provide agents_xy and targets_xy of the same size.") 25 | 26 | # check overlapping of agents 27 | for i in range(len(agents_xy)): 28 | for j in range(i + 1, len(agents_xy)): 29 | if agents_xy[i] == agents_xy[j]: 30 | raise ValueError(f"Agents can't overlap! 
{agents_xy[i]} is in both {i} and {j} position.") 31 | 32 | for start_xy, finish_xy in zip(agents_xy, targets_xy): 33 | s_x, s_y = start_xy 34 | if obstacles[s_x, s_y]: 35 | raise KeyError(f'Cell is {s_x, s_y} occupied by obstacle.') 36 | f_x, f_y = finish_xy 37 | if obstacles[f_x, f_y]: 38 | raise KeyError(f'Cell is {f_x, f_y} occupied by obstacle.') 39 | 40 | # todo check connectivity of starts and finishes 41 | 42 | 43 | def render_grid(obstacles, positions_xy=None, targets_xy=None, is_active=None, mode='human'): 44 | if positions_xy is None: 45 | positions_xy = [] 46 | if targets_xy is None: 47 | targets_xy = [] 48 | if is_active is None: 49 | if positions_xy: 50 | is_active = [True] * len(positions_xy) 51 | else: 52 | is_active = [] 53 | from io import StringIO 54 | import string 55 | from gymnasium import utils as gym_utils 56 | from contextlib import closing 57 | 58 | outfile = StringIO() if mode == 'ansi' else sys.stdout 59 | chars = string.digits + string.ascii_letters + string.punctuation 60 | positions_map = {(x, y): id_ for id_, (x, y) in enumerate(positions_xy) if is_active[id_]} 61 | finishes_map = {(x, y): id_ for id_, (x, y) in enumerate(targets_xy) if is_active[id_]} 62 | for line_index, line in enumerate(obstacles): 63 | out = '' 64 | for cell_index, cell in enumerate(line): 65 | if cell == CommonSettings().FREE: 66 | agent_id = positions_map.get((line_index, cell_index), None) 67 | finish_id = finishes_map.get((line_index, cell_index), None) 68 | 69 | if agent_id is not None: 70 | out += str(gym_utils.colorize(' ' + chars[agent_id % len(chars)] + ' ', color='red', bold=True, 71 | highlight=False)) 72 | elif finish_id is not None: 73 | out += str( 74 | gym_utils.colorize('|' + chars[finish_id % len(chars)] + '|', 'white', highlight=False)) 75 | else: 76 | out += str(gym_utils.colorize(str(' . 
'), 'white', highlight=False)) 77 | else: 78 | out += str(gym_utils.colorize(str(' '), 'cyan', bold=False, highlight=True)) 79 | out += '\n' 80 | outfile.write(out) 81 | 82 | if mode != 'human': 83 | with closing(outfile): 84 | return outfile.getvalue() 85 | 86 | 87 | class CommonSettings(BaseModel): 88 | MOVES: list = [[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1], ] 89 | FREE: Literal[0] = 0 90 | OBSTACLE: Literal[1] = 1 91 | empty_outside: bool = True 92 | -------------------------------------------------------------------------------- /pogema/wrappers/persistence.py: -------------------------------------------------------------------------------- 1 | from gymnasium import Wrapper 2 | 3 | 4 | class AgentState: 5 | def __init__(self, x, y, tx, ty, step, active): 6 | self.x = x 7 | self.y = y 8 | self.tx = tx 9 | self.ty = ty 10 | self.step = step 11 | self.active = active 12 | 13 | def get_xy(self): 14 | return self.x, self.y 15 | 16 | def get_target_xy(self): 17 | return self.tx, self.ty 18 | 19 | def is_active(self): 20 | return self.active 21 | 22 | def get_step(self): 23 | return self.step 24 | 25 | def __eq__(self, other): 26 | o = other 27 | return self.x == o.x and self.y == o.y and self.tx == o.tx and self.ty == o.ty and self.active == o.active 28 | 29 | def __str__(self): 30 | return str([self.x, self.y, self.tx, self.ty, self.step, self.active]) 31 | 32 | 33 | class PersistentWrapper(Wrapper): 34 | def __init__(self, env, xy_offset=None): 35 | super().__init__(env) 36 | self._step = None 37 | self._agent_states = None 38 | self._xy_offset = xy_offset 39 | 40 | def step(self, action): 41 | result = self.env.step(action) 42 | self._step += 1 43 | for agent_idx in range(self.get_num_agents()): 44 | agent_state = self._get_agent_state(self.grid, agent_idx) 45 | if agent_state != self._agent_states[agent_idx][-1]: 46 | self._agent_states[agent_idx].append(agent_state) 47 | 48 | return result 49 | 50 | def step_back(self): 51 | if self._step <= 0: 52 | return 
False 53 | self._step -= 1 54 | self.set_elapsed_steps(self._step) 55 | for idx in reversed(range(self.get_num_agents())): 56 | 57 | if self._step < self._agent_states[idx][-1].step: 58 | self._agent_states[idx].pop() 59 | state = self._agent_states[idx][-1] 60 | 61 | if state.active: 62 | self.grid.show_agent(idx) 63 | else: 64 | self.grid.hide_agent(idx) 65 | self.grid.move_agent_to_cell(idx, state.x, state.y) 66 | self.grid.finishes_xy[idx] = state.tx, state.ty 67 | 68 | return True 69 | 70 | def _get_agent_state(self, grid, agent_idx): 71 | x, y = grid.positions_xy[agent_idx] 72 | tx, ty = grid.finishes_xy[agent_idx] 73 | active = grid.is_active[agent_idx] 74 | if self._xy_offset: 75 | x += self._xy_offset 76 | y += self._xy_offset 77 | tx += self._xy_offset 78 | ty += self._xy_offset 79 | return AgentState(x, y, tx, ty, self._step, active) 80 | 81 | def reset(self, **kwargs): 82 | result = self.env.reset(**kwargs) 83 | 84 | self._step = 0 85 | 86 | self._agent_states = [] 87 | for agent_idx in range(self.get_num_agents()): 88 | self._agent_states.append([self._get_agent_state(self.grid, agent_idx)]) 89 | 90 | return result 91 | 92 | @staticmethod 93 | def agent_state_to_full_list(agent_states, num_steps): 94 | result = [] 95 | current_state_id = 0 96 | # going over num_steps + 1, to handle last step 97 | for episode_step in range(num_steps + 1): 98 | if current_state_id < len(agent_states) - 1 and agent_states[current_state_id + 1].step == episode_step: 99 | current_state_id += 1 100 | result.append(agent_states[current_state_id]) 101 | return result 102 | 103 | @classmethod 104 | def decompress_history(cls, history): 105 | max_steps = max([agent_states[-1].step for agent_states in history]) 106 | result = [cls.agent_state_to_full_list(agent_states, max_steps) for agent_states in history] 107 | return result 108 | 109 | def get_full_history(self): 110 | return [self.agent_state_to_full_list(agent_states, self._step) for agent_states in self._agent_states] 111 
| 112 | def get_history(self): 113 | return self._agent_states 114 | -------------------------------------------------------------------------------- /version_history.MD: -------------------------------------------------------------------------------- 1 | The development history of POGEMA, starting from version 1.0.0. 2 | 3 | Version 1.5.0 (August, 2025) 4 | • Added support of custom targets_xy for lifelong MAPF. 5 | • Improved work with rectangular grids. Added width and height attributes to GridConfig. 6 | • Added method update_config to properly update all attributes of GridConfig. 7 | • Added more tests for new features - custom targets_xy and width/height attributes. 8 | 9 | Version 1.4.0 (April 5, 2025) 10 | • Extended limits for size of maps and number of agents. 11 | • Fixed ep_length value. 12 | • Updated some tests. 13 | 14 | Version 1.3.0 (June 13, 2024) 15 | 16 | • Updates for integration with newer version of gymnasium. 17 | • Refactored AgentsDensityWrapper for modularity and clarity. 18 | • Introduced RuntimeMetricWrapper for runtime monitoring. 19 | • Enhanced map generation methods and added new metrics like SOC_Makespan. 20 | • Animation improvements for better visualization. 21 | 22 | Version 1.2.2 (September 22, 2023) 23 | 24 | • Implemented soft collision handling for agent interactions. 25 | • Improved lifelong scenario seeding for consistent agent behavior. 26 | • Enhanced metric logging for better integration with PyMARL. 27 | 28 | Version 1.2.0 (August 30, 2023) 29 | 30 | • Fixed import issues with Literal and animation issues. 31 | • Improved visualizations, including grid lines and border toggles. 32 | 33 | Version 1.1.0 (March 30, 2023) 34 | 35 | • Updated dependencies for gymnasium and PettingZoo. 36 | • Added an option to remove animation borders for cleaner outputs. 37 | • Fixed animation bugs for stuck agents. 
38 | 39 | Version 1.0.0 (February 2023) 40 | 41 | • Launched core features, including A* policy implementations and CI/CD support. 42 | • Introduced basic visualization and fixed animation bugs. 43 | 44 | Post-Version Updates 45 | 46 | • Adjusted the number of agents in setups. 47 | • Updated package metadata for better compatibility. 48 | • Addressed legacy issues and improved benchmark generation. 49 | 50 | Version 1.1.6 (February 21, 2023) 51 | 52 | • Fixed static animation issues and added grid object rendering. 53 | 54 | Version 1.1.5 (December 28, 2022) 55 | 56 | • Fixed Python 3.7 compatibility issues and added map registries for better management. 57 | • Introduced an attrition metric. 58 | 59 | Version 1.1.4 (November 18, 2022) 60 | 61 | • Fixed flake8 warnings for improved code quality. 62 | 63 | Version 1.1.3 (October 28, 2022) 64 | 65 | • Corrected random seed initialization for PogemaLifeLong. 66 | • Optimized animation behavior. 67 | 68 | Version 1.1.2 (October 5, 2022) 69 | 70 | • Upgraded SVG animations for better compression. 71 | 72 | Version 1.1.1 (August 30, 2022) 73 | 74 | • Added map_name attributes for clearer references. 75 | • Implemented new observation types (MAPF, POMAPF) and enhanced metrics aggregation. 76 | 77 | Version 1.0.x and Earlier 78 | 79 | • Introduced cooperative reward wrappers and lifelong environment versions. 80 | • Dropped Python 3.6 support and refined animation handling. 81 | 82 | Version 1.0.3 (June 29, 2022) 83 | 84 | • Fixed rendering issues for inactive agents. 85 | 86 | Version 1.0.2 (June 27, 2022) 87 | 88 | • Enhanced customization for agent and target positions. 89 | 90 | Pre-1.0.2 Development (June 2022) 91 | 92 | • Improved tests, refactored code, and removed unnecessary dependencies. 93 | • Introduced the PogemaLifeLong class with target generation and metrics tailored for lifelong scenarios. 94 | • Introduced customizable map rules and agent/target positions. 
95 | • Simplified installation by removing unnecessary dependencies. 96 | 97 | 98 | Version 1.0.0 (March 31, 2022) 99 | 100 | • Added predefined configurations for grid environments and improved visualization. 101 | • Integrated PettingZoo support and enhanced usability with better examples. 102 | • Introduced grid_config class for environment configuration and improved state management. 103 | • Added methods for relative position observations and fixed PettingZoo compatibility. 104 | • Documentation improvements for better user guidance. 105 | 106 | -------------------------------------------------------------------------------- /pogema/a_star_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pogema import GridConfig 3 | 4 | from heapq import heappop, heappush 5 | 6 | INF = 1e7 7 | 8 | 9 | class GridMemory: 10 | def __init__(self, start_r=64): 11 | self._memory = np.zeros(shape=(start_r * 2 + 1, start_r * 2 + 1), dtype=bool) 12 | 13 | @staticmethod 14 | def _try_to_insert(x, y, source, target): 15 | r = source.shape[0] // 2 16 | try: 17 | target[x - r:x + r + 1, y - r:y + r + 1] = source 18 | return True 19 | except ValueError: 20 | return False 21 | 22 | def _increase_memory(self): 23 | m = self._memory 24 | r = self._memory.shape[0] 25 | self._memory = np.zeros(shape=(r * 2 + 1, r * 2 + 1)) 26 | assert self._try_to_insert(r, r, m, self._memory) 27 | 28 | def update(self, x, y, obstacles): 29 | while True: 30 | r = self._memory.shape[0] // 2 31 | if self._try_to_insert(r + x, r + y, obstacles, self._memory): 32 | break 33 | self._increase_memory() 34 | 35 | def is_obstacle(self, x, y): 36 | r = self._memory.shape[0] // 2 37 | if -r <= x <= r and -r <= y <= r: 38 | return self._memory[r + x, r + y] 39 | else: 40 | return False 41 | 42 | 43 | class Node: 44 | def __init__(self, coord: (int, int) = (INF, INF), g: int = 0, h: int = 0): 45 | self.i, self.j = coord 46 | self.g = g 47 | self.h = h 
48 | self.f = g + h 49 | 50 | def __lt__(self, other): 51 | if self.f != other.f: 52 | return self.f < other.f 53 | elif self.g != other.g: 54 | return self.g < other.g 55 | else: 56 | return self.i < other.i or self.j < other.j 57 | 58 | 59 | def h(node, target): 60 | nx, ny = node 61 | tx, ty = target 62 | return abs(nx - tx) + abs(ny - ty) 63 | 64 | 65 | def a_star(start, target, grid: GridMemory, max_steps=10000): 66 | open_ = list() 67 | closed = {start: None} 68 | 69 | heappush(open_, Node(start, 0, h(start, target))) 70 | 71 | for step in range(int(max_steps)): 72 | u = heappop(open_) 73 | 74 | for n in [(u.i - 1, u.j), (u.i + 1, u.j), (u.i, u.j - 1), (u.i, u.j + 1)]: 75 | if not grid.is_obstacle(*n) and n not in closed: 76 | heappush(open_, Node(n, u.g + 1, h(n, target))) 77 | closed[n] = (u.i, u.j) 78 | 79 | if step >= max_steps or (u.i, u.j) == target or len(open_) == 0: 80 | break 81 | 82 | next_node = target if target in closed else None 83 | path = [] 84 | while next_node is not None: 85 | path.append(next_node) 86 | next_node = closed[next_node] 87 | 88 | return list(reversed(path)) 89 | 90 | 91 | class AStarAgent: 92 | def __init__(self, seed=0): 93 | self._moves = GridConfig().MOVES 94 | self._reverse_actions = {tuple(self._moves[i]): i for i in range(len(self._moves))} 95 | 96 | self._gm = None 97 | self._saved_xy = None 98 | self.clear_state() 99 | self._rnd = np.random.default_rng(seed) 100 | 101 | def act(self, obs): 102 | xy, target_xy, obstacles, agents = obs['xy'], obs['target_xy'], obs['obstacles'], obs['agents'] 103 | 104 | 105 | if self._saved_xy is not None and h(self._saved_xy, xy) > 1: 106 | raise IndexError("Agent moved more than 1 step. 
Please, call clear_state method before new episode.") 107 | if self._saved_xy is not None and h(self._saved_xy, xy) == 0 and xy != target_xy: 108 | return self._rnd.integers(len(self._moves)) 109 | self._gm.update(*xy, obstacles) 110 | path = a_star(xy, target_xy, self._gm, ) 111 | if len(path) <= 1: 112 | action = 0 113 | else: 114 | (x, y), (tx, ty), *_ = path 115 | action = self._reverse_actions[tx - x, ty - y] 116 | 117 | self._saved_xy = xy 118 | return action 119 | 120 | def clear_state(self): 121 | self._saved_xy = None 122 | self._gm = GridMemory() 123 | 124 | 125 | class BatchAStarAgent: 126 | def __init__(self): 127 | self.astar_agents = {} 128 | 129 | def act(self, observations): 130 | actions = [] 131 | for idx, obs in enumerate(observations): 132 | if idx not in self.astar_agents: 133 | self.astar_agents[idx] = AStarAgent() 134 | actions.append(self.astar_agents[idx].act(obs)) 135 | return actions 136 | 137 | def reset_states(self): 138 | self.astar_agents = {} 139 | -------------------------------------------------------------------------------- /tests/test_integrations.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import numpy as np 4 | 5 | from pogema import GridConfig 6 | from pogema.integrations.make_pogema import pogema_v0 7 | 8 | 9 | def test_gym_creation(): 10 | import gymnasium 11 | 12 | env = gymnasium.make("Pogema-v0", grid_config=GridConfig(integration='gymnasium')) 13 | env.reset() 14 | 15 | 16 | def test_integrations(): 17 | for integration in ['SampleFactory', 'PyMARL', 'gymnasium', "PettingZoo", None]: 18 | env = pogema_v0(grid_config=GridConfig(integration=integration)) 19 | env.reset() 20 | 21 | 22 | def test_sample_factory_integration(): 23 | env = pogema_v0(GridConfig(seed=7, num_agents=4, size=12, integration='SampleFactory')) 24 | env.reset() 25 | 26 | assert env.num_agents == 4 27 | assert env.is_multiagent is True 28 | 29 | # testing auto-reset 
wrapper 30 | for _ in range(2): 31 | dones = [False] 32 | infos = None 33 | while True: 34 | _, _, terminated, truncated, infos = env.step(env.sample_actions()) 35 | if all(terminated) or all(truncated): 36 | break 37 | 38 | assert np.isclose(infos[0]['episode_extra_stats']['ISR'], 0.0) 39 | assert np.isclose(infos[0]['episode_extra_stats']['CSR'], 0.0) 40 | 41 | 42 | def test_pymarl_integration(): 43 | gc = GridConfig(seed=7, num_agents=4, obs_radius=3, max_episode_steps=16, integration='PyMARL') 44 | env = pogema_v0(gc) 45 | 46 | _state = [0.14285714285714285, 1.0, 1.0, 0.5714285714285714, 0.42857142857142855, 0.7142857142857143, 47 | 0.8571428571428571, 0.2857142857142857, 0.8571428571428571, 0.42857142857142855, 0.42857142857142855, 1.0, 48 | 0.5714285714285714, 0.7142857142857143, 0.14285714285714285, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0, 0.0, 49 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50 | 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 51 | 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 52 | 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 53 | 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 54 | 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 55 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 56 | 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 57 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 58 | 0.0, 0.0] 59 | assert np.isclose(_state, env.get_state()).all() 60 | 61 | assert env.episode_limit == 16 62 | assert 
env.get_env_info()['state_shape'] == 212 63 | assert env.get_env_info()['obs_shape'] == 147 64 | assert env.get_env_info()['n_agents'] == 4 65 | assert env.get_env_info()['episode_limit'] == 16 66 | 67 | num_agents, dimension = env.get_obs().shape 68 | assert num_agents == gc.num_agents 69 | assert dimension == reduce(lambda a, b: a * b, env.env.observation_space.shape) 70 | assert dimension == env.get_obs_size() 71 | assert env.get_state_size() == env.get_state().shape[0] 72 | 73 | done = False 74 | cnt = 0 75 | while not done: 76 | assert cnt < gc.max_episode_steps 77 | _, done, _ = env.step(env.sample_actions()) 78 | cnt += 1 79 | 80 | 81 | def test_single_agent_gym_integration(): 82 | gc = GridConfig(seed=7, num_agents=1, integration='gymnasium') 83 | env = pogema_v0(gc) 84 | 85 | obs, info = env.reset() 86 | 87 | assert obs.shape == env.observation_space.shape 88 | done = False 89 | 90 | cnt = 0 91 | while True: 92 | assert cnt < gc.max_episode_steps 93 | obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) 94 | assert obs.shape == env.observation_space.shape 95 | assert isinstance(reward, float) 96 | assert isinstance(done, bool) 97 | assert isinstance(info, dict) 98 | cnt += 1 99 | if terminated or truncated: 100 | break 101 | 102 | 103 | def test_petting_zoo(): 104 | from pettingzoo.test import api_test, parallel_api_test, render_test 105 | # from pettingzoo.test import render_test 106 | 107 | gc = GridConfig(num_agents=16, size=16, integration='PettingZoo') 108 | 109 | parallel_api_test(pogema_v0(gc), num_cycles=1000) 110 | 111 | try: 112 | from pettingzoo.utils import parallel_to_aec 113 | 114 | def env(grid_config: GridConfig = GridConfig(num_agents=20, size=16)): 115 | return parallel_to_aec(pogema_v0(grid_config)) 116 | 117 | api_test(env(gc), num_cycles=1000, verbose_progress=True) 118 | # todo fix this 119 | # render_test(lambda: pogema_v0(gc)) 120 | except ImportError: 121 | pass 122 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | [![Pogema logo](https://raw.githubusercontent.com/Tviskaron/pogema-pics/main/pogema-logo-v1.svg)](https://github.com/Cognitive-AI-Systems/pogema) 5 | 6 | **Partially-Observable Grid Environment for Multiple Agents** 7 | 8 | [![CodeFactor](https://www.codefactor.io/repository/github/tviskaron/pogema/badge)](https://www.codefactor.io/repository/github/tviskaron/pogema) 9 | [![Downloads](https://static.pepy.tech/badge/pogema)](https://pepy.tech/project/pogema) 10 | [![CI](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/CI.yml) 11 | [![CodeQL](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/codeql-analysis.yml) 12 | 13 |
14 | 15 | Partially Observable Multi-Agent Pathfinding (PO-MAPF) is a challenging problem that fundamentally differs from regular MAPF. In regular MAPF, a central controller constructs a joint plan for all agents before they start execution. However, PO-MAPF is intrinsically decentralized, and decision-making, such as planning, is interleaved with execution. At each time step, an agent receives a local observation of the environment and decides which action to take. The ultimate goal for the agents is to reach their goals while avoiding collisions with each other and the static obstacles. 16 | 17 | POGEMA stands for Partially-Observable Grid Environment for Multiple Agents. It is a grid-based environment that was specifically designed to be flexible, tunable, and scalable. It can be tailored to a variety of PO-MAPF settings. Currently, the somewhat standard setting is supported, in which agents can move between the cardinal-adjacent cells of the grid, and each action (move or wait) takes one time step. No information sharing occurs between the agents. POGEMA can generate random maps and start/goal locations for the agents. It also accepts custom maps as input. 
18 | 19 | ## Installation 20 | 21 | Just install from PyPI: 22 | 23 | ```pip install pogema``` 24 | 25 | ## Using Example 26 | 27 | ```python 28 | from pogema import pogema_v0, GridConfig 29 | 30 | env = pogema_v0(grid_config=GridConfig()) 31 | 32 | obs, info = env.reset() 33 | 34 | while True: 35 | # Using random policy to make actions 36 | obs, reward, terminated, truncated, info = env.step(env.sample_actions()) 37 | env.render() 38 | if all(terminated) or all(truncated): 39 | break 40 | ``` 41 | 42 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/19dSEGTQeM3oVJtVjpC162t1XApmv6APc?usp=sharing) 43 | 44 | 45 | ## Baselines and Evaluation Protocol 46 | The baseline implementations and evaluation pipeline are presented in the [POGEMA Benchmark](https://github.com/Cognitive-AI-Systems/pogema-benchmark) repository. 47 | 48 | ## Interfaces 49 | Pogema provides integrations with a range of MARL frameworks: PettingZoo, PyMARL and SampleFactory. 50 | 51 | ### PettingZoo 52 | 53 | ```python 54 | from pogema import pogema_v0, GridConfig 55 | 56 | # Create Pogema environment with PettingZoo interface 57 | env = pogema_v0(GridConfig(integration="PettingZoo")) 58 | ``` 59 | 60 | ### PyMARL 61 | 62 | ```python 63 | from pogema import pogema_v0, GridConfig 64 | 65 | env = pogema_v0(GridConfig(integration="PyMARL")) 66 | ``` 67 | 68 | ### SampleFactory 69 | 70 | ```python 71 | from pogema import pogema_v0, GridConfig 72 | 73 | env = pogema_v0(GridConfig(integration="SampleFactory")) 74 | ``` 75 | 76 | ### Gymnasium 77 | 78 | Pogema is fully capable of solving single-agent pathfinding tasks. 
79 | 80 | ```python 81 | from pogema import pogema_v0, GridConfig 82 | 83 | env = pogema_v0(GridConfig(integration="gymnasium")) 84 | ``` 85 | 86 | Example of training [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) DQN to solve single-agent pathfinding tasks: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1vPwTd0PnzpWrB-bCHqoLSVwU9G9Lgcmv?usp=sharing) 87 | 88 | 89 | 90 | 91 | ## Customization 92 | 93 | ### Random maps 94 | ```python 95 | from pogema import pogema_v0, GridConfig 96 | 97 | # Define random configuration 98 | grid_config = GridConfig(num_agents=4, # number of agents 99 | size=8, # size of the grid 100 | density=0.4, # obstacle density 101 | seed=1, # set to None for random 102 | # obstacles, agents and targets 103 | # positions at each reset 104 | max_episode_steps=128, # horizon 105 | obs_radius=3, # defines field of view 106 | ) 107 | 108 | env = pogema_v0(grid_config=grid_config) 109 | env.reset() 110 | env.render() 111 | 112 | ``` 113 | 114 | ### Custom maps 115 | ```python 116 | from pogema import pogema_v0, GridConfig 117 | 118 | grid = """ 119 | .....#..... 120 | .....#..... 121 | ........... 122 | .....#..... 123 | .....#..... 124 | #.####..... 125 | .....###.## 126 | .....#..... 127 | .....#..... 128 | ........... 129 | .....#..... 
130 | """ 131 | 132 | # Define new configuration with 8 randomly placed agents 133 | grid_config = GridConfig(map=grid, num_agents=8) 134 | 135 | # Create custom Pogema environment 136 | env = pogema_v0(grid_config=grid_config) 137 | ``` 138 | 139 | 140 | 141 | 142 | ## Citation 143 | If you use this repository in your research or wish to cite it, please make a reference to our paper: 144 | ``` 145 | @inproceedings{skrynnik2025pogema, 146 | title={POGEMA: A Benchmark Platform for Cooperative Multi-Agent Pathfinding}, 147 | author={Skrynnik, Alexey and Andreychuk, Anton and Borzilov, Anatolii and Chernyavskiy, Alexander and Yakovlev, Konstantin and Panov, Aleksandr}, 148 | booktitle={The Thirteenth International Conference on Learning Representations}, 149 | year={2025} 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /pogema/generator.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | 6 | from pogema import GridConfig 7 | 8 | 9 | def generate_obstacles(grid_config: GridConfig, rnd=None): 10 | if rnd is None: 11 | rnd = np.random.default_rng(grid_config.seed) 12 | return rnd.binomial(1, grid_config.density, (grid_config.height, grid_config.width)) 13 | 14 | 15 | def bfs(grid, moves, start_id, free_cell): 16 | q = [] 17 | current_id = start_id 18 | 19 | components = [0 for _ in range(start_id)] 20 | 21 | size_x = len(grid) 22 | size_y = len(grid[0]) 23 | 24 | for x in range(size_x): 25 | for y in range(size_y): 26 | if grid[x, y] != free_cell: 27 | continue 28 | 29 | grid[x, y] = current_id 30 | components.append(1) 31 | q.append((x, y)) 32 | 33 | while len(q): 34 | cx, cy = q.pop(0) 35 | 36 | for dx, dy in moves: 37 | nx, ny = cx + dx, cy + dy 38 | if 0 <= nx < size_x and 0 <= ny < size_y: 39 | if grid[nx, ny] == free_cell: 40 | grid[nx, ny] = current_id 41 | components[current_id] += 1 42 | 
q.append((nx, ny)) 43 | 44 | current_id += 1 45 | return components 46 | 47 | 48 | def placing_fast(order, components, grid, start_id, num_agents): 49 | link_to_next = [-1 for _ in range(len(order))] 50 | colors = [-1 for _ in range(len(components))] 51 | size = len(order) 52 | for index in range(size): 53 | reversed_index = len(order) - index - 1 54 | color = grid[order[reversed_index]] 55 | link_to_next[reversed_index] = colors[color] 56 | colors[color] = reversed_index 57 | 58 | positions_xy = [] 59 | finishes_xy = [] 60 | 61 | for index in range(len(order)): 62 | next_index = link_to_next[index] 63 | if next_index == -1: 64 | continue 65 | 66 | positions_xy.append(order[index]) 67 | finishes_xy.append(order[next_index]) 68 | 69 | link_to_next[next_index] = -1 70 | if len(finishes_xy) >= num_agents: 71 | break 72 | return positions_xy, finishes_xy 73 | 74 | 75 | def placing(order, components, grid, start_id, num_agents): 76 | requests = [[] for _ in range(len(components))] 77 | 78 | done_requests = 0 79 | positions_xy = [] 80 | finishes_xy = [(-1, -1) for _ in range(num_agents)] 81 | for x, y in order: 82 | if grid[x, y] < start_id: 83 | continue 84 | 85 | id_ = grid[x, y] 86 | grid[x, y] = 0 87 | 88 | if requests[id_]: 89 | tt = requests[id_].pop() 90 | finishes_xy[tt] = x, y 91 | done_requests += 1 92 | continue 93 | 94 | if len(positions_xy) >= num_agents: 95 | if done_requests >= num_agents: 96 | break 97 | continue 98 | 99 | if components[id_] >= 2: 100 | components[id_] -= 2 101 | requests[id_].append(len(positions_xy)) 102 | positions_xy.append((x, y)) 103 | 104 | return positions_xy, finishes_xy 105 | 106 | def generate_from_possible_positions(grid_config: GridConfig): 107 | if len(grid_config.possible_agents_xy) < grid_config.num_agents or len(grid_config.possible_targets_xy) < grid_config.num_agents: 108 | raise OverflowError(f"Can't create task. 
Not enough possible positions for {grid_config.num_agents} agents.") 109 | rng = np.random.default_rng(grid_config.seed) 110 | rng.shuffle(grid_config.possible_agents_xy) 111 | rng.shuffle(grid_config.possible_targets_xy) 112 | return grid_config.possible_agents_xy[:grid_config.num_agents], grid_config.possible_targets_xy[:grid_config.num_agents] 113 | 114 | 115 | def generate_positions_and_targets_fast(obstacles, grid_config): 116 | c = grid_config 117 | grid = obstacles.copy() 118 | 119 | start_id = max(c.FREE, c.OBSTACLE) + 1 120 | 121 | components = bfs(grid, tuple(c.MOVES), start_id, free_cell=c.FREE) 122 | height, width = obstacles.shape 123 | order = [(x, y) for x in range(height) for y in range(width) if grid[x, y] >= start_id] 124 | np.random.default_rng(c.seed).shuffle(order) 125 | 126 | return placing(order=order, components=components, grid=grid, start_id=start_id, num_agents=c.num_agents) 127 | 128 | def generate_from_possible_targets(rnd_generator, possible_positions, position): 129 | new_target = tuple(rnd_generator.choice(possible_positions, 1)[0]) 130 | while new_target == position: 131 | new_target = tuple(rnd_generator.choice(possible_positions, 1)[0]) 132 | return new_target 133 | 134 | def generate_new_target(rnd_generator, point_to_component, component_to_points, position): 135 | component_id = point_to_component[position] 136 | component = component_to_points[component_id] 137 | new_target = tuple(*rnd_generator.choice(component, 1)) 138 | while new_target == position: 139 | new_target = tuple(*rnd_generator.choice(component, 1)) 140 | return new_target 141 | 142 | 143 | def get_components(grid_config, obstacles, positions_xy, target_xy): 144 | c = grid_config 145 | grid = obstacles.copy() 146 | 147 | start_id = max(c.FREE, c.OBSTACLE) + 1 148 | bfs(grid, tuple(c.MOVES), start_id, free_cell=c.FREE) 149 | height, width = obstacles.shape 150 | 151 | comp_to_points = defaultdict(list) 152 | point_to_comp = {} 153 | for x in range(height): 154 | 
for y in range(width): 155 | comp_to_points[grid[x, y]].append((x, y)) 156 | point_to_comp[(x, y)] = grid[x, y] 157 | return comp_to_points, point_to_comp 158 | 159 | 160 | def time_it(func, num_iterations): 161 | start = time.monotonic() 162 | for index in range(num_iterations): 163 | grid_config = GridConfig(num_agents=64, size=64, seed=index) 164 | obstacles = generate_obstacles(grid_config) 165 | result = func(obstacles, grid_config, ) 166 | if index == 0 and num_iterations > 1: 167 | print(result) 168 | finish = time.monotonic() 169 | 170 | return finish - start 171 | 172 | 173 | def main(): 174 | num_iterations = 1000 175 | time_it(generate_positions_and_targets_fast, num_iterations=1) 176 | print('fast:', time_it(generate_positions_and_targets_fast, num_iterations=num_iterations)) 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /pogema/wrappers/metrics.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | from gymnasium import Wrapper 5 | 6 | 7 | class AbstractMetric(Wrapper): 8 | def _compute_stats(self, step, is_on_goal, finished): 9 | raise NotImplementedError 10 | 11 | def __init__(self, env): 12 | super().__init__(env) 13 | self._current_step = 0 14 | 15 | def step(self, action): 16 | obs, reward, terminated, truncated, infos = self.env.step(action) 17 | finished = all(truncated) or all(terminated) 18 | 19 | metric = self._compute_stats(self._current_step, self.was_on_goal, finished) 20 | self._current_step += 1 21 | if finished: 22 | self._current_step = 0 23 | 24 | if metric: 25 | if 'metrics' not in infos[0]: 26 | infos[0]['metrics'] = {} 27 | infos[0]['metrics'].update(**metric) 28 | 29 | return obs, reward, terminated, truncated, infos 30 | 31 | 32 | class LifeLongAverageThroughputMetric(AbstractMetric): 33 | 34 | def __init__(self, env): 35 | super().__init__(env) 36 | 
self._solved_instances = 0 37 | 38 | def _compute_stats(self, step, is_on_goal, finished): 39 | for agent_idx, on_goal in enumerate(is_on_goal): 40 | if on_goal: 41 | self._solved_instances += 1 42 | if finished: 43 | result = {'avg_throughput': self._solved_instances / self.grid_config.max_episode_steps} 44 | self._solved_instances = 0 45 | return result 46 | 47 | 48 | class NonDisappearCSRMetric(AbstractMetric): 49 | 50 | def _compute_stats(self, step, is_on_goal, finished): 51 | if finished: 52 | return {'CSR': float(all(is_on_goal))} 53 | 54 | 55 | class NonDisappearISRMetric(AbstractMetric): 56 | 57 | def _compute_stats(self, step, is_on_goal, finished): 58 | if finished: 59 | return {'ISR': float(sum(is_on_goal)) / self.get_num_agents()} 60 | 61 | 62 | class NonDisappearEpLengthMetric(AbstractMetric): 63 | 64 | def _compute_stats(self, step, is_on_goal, finished): 65 | if finished: 66 | return {'ep_length': step + 1} 67 | 68 | 69 | class EpLengthMetric(AbstractMetric): 70 | def __init__(self, env): 71 | super().__init__(env) 72 | self._solve_time = [None for _ in range(self.get_num_agents())] 73 | 74 | def _compute_stats(self, step, is_on_goal, finished): 75 | for idx, on_goal in enumerate(is_on_goal): 76 | if self._solve_time[idx] is None: 77 | if on_goal or finished: 78 | self._solve_time[idx] = step 79 | 80 | if finished: 81 | result = {'ep_length': sum(self._solve_time) / self.get_num_agents() + 1} 82 | self._solve_time = [None for _ in range(self.get_num_agents())] 83 | return result 84 | 85 | 86 | class CSRMetric(AbstractMetric): 87 | def __init__(self, env): 88 | super().__init__(env) 89 | self._solved_instances = 0 90 | 91 | def _compute_stats(self, step, is_on_goal, finished): 92 | self._solved_instances += sum(is_on_goal) 93 | if finished: 94 | results = {'CSR': float(self._solved_instances == self.get_num_agents())} 95 | self._solved_instances = 0 96 | return results 97 | 98 | 99 | class ISRMetric(AbstractMetric): 100 | def __init__(self, env): 101 
| super().__init__(env) 102 | self._solved_instances = 0 103 | 104 | def _compute_stats(self, step, is_on_goal, finished): 105 | self._solved_instances += sum(is_on_goal) 106 | if finished: 107 | results = {'ISR': self._solved_instances / self.get_num_agents()} 108 | self._solved_instances = 0 109 | return results 110 | 111 | 112 | class SumOfCostsAndMakespanMetric(AbstractMetric): 113 | def __init__(self, env): 114 | super().__init__(env) 115 | self._solve_time = [None for _ in range(self.get_num_agents())] 116 | 117 | def _compute_stats(self, step, is_on_goal, finished): 118 | for idx, on_goal in enumerate(is_on_goal): 119 | if self._solve_time[idx] is None and (on_goal or finished): 120 | self._solve_time[idx] = step 121 | if not on_goal and not finished: 122 | self._solve_time[idx] = None 123 | 124 | if finished: 125 | result = {'SoC': sum(self._solve_time) + self.get_num_agents(), 'makespan': max(self._solve_time) + 1} 126 | self._solve_time = [None for _ in range(self.get_num_agents())] 127 | return result 128 | 129 | 130 | class AgentsDensityWrapper(Wrapper): 131 | def __init__(self, env): 132 | super().__init__(env) 133 | self._avg_agents_density = None 134 | 135 | def count_agents(self, observations): 136 | avg_agents_density = [] 137 | for obs in observations: 138 | traversable_cells = np.size(obs['obstacles']) - np.count_nonzero(obs['obstacles']) 139 | avg_agents_density.append(np.count_nonzero(obs['agents']) / traversable_cells) 140 | self._avg_agents_density.append(np.mean(avg_agents_density)) 141 | 142 | def step(self, actions): 143 | observations, rewards, terminated, truncated, infos = self.env.step(actions) 144 | self.count_agents(observations) 145 | if all(terminated) or all(truncated): 146 | if 'metrics' not in infos[0]: 147 | infos[0]['metrics'] = {} 148 | infos[0]['metrics'].update(avg_agents_density=float(np.mean(self._avg_agents_density))) 149 | return observations, rewards, terminated, truncated, infos 150 | 151 | def reset(self, **kwargs): 
152 | self._avg_agents_density = [] 153 | observations, info = self.env.reset(**kwargs) 154 | self.count_agents(observations) 155 | return observations, info 156 | 157 | 158 | class RuntimeMetricWrapper(Wrapper): 159 | def __init__(self, env): 160 | super().__init__(env) 161 | self._start_time = None 162 | self._env_step_time = None 163 | 164 | def step(self, actions): 165 | env_step_start = time.monotonic() 166 | observations, rewards, terminated, truncated, infos = self.env.step(actions) 167 | env_step_end = time.monotonic() 168 | self._env_step_time += env_step_end - env_step_start 169 | if all(terminated) or all(truncated): 170 | final_time = time.monotonic() - self._start_time - self._env_step_time 171 | if 'metrics' not in infos[0]: 172 | infos[0]['metrics'] = {} 173 | infos[0]['metrics'].update(runtime=final_time) 174 | return observations, rewards, terminated, truncated, infos 175 | 176 | def reset(self, **kwargs): 177 | obs = self.env.reset(**kwargs) 178 | self._start_time = time.monotonic() 179 | self._env_step_time = 0.0 180 | return obs 181 | -------------------------------------------------------------------------------- /pogema/svg_animation/animation_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import cycle 3 | from gymnasium import logger, Wrapper 4 | 5 | from pogema import GridConfig 6 | from pogema.svg_animation.animation_drawer import AnimationConfig, SvgSettings, GridHolder, AnimationDrawer 7 | from pogema.wrappers.persistence import PersistentWrapper, AgentState 8 | 9 | 10 | class AnimationMonitor(Wrapper): 11 | """ 12 | Defines the animation, which saves the episode as SVG. 
13 | """ 14 | 15 | def __init__(self, env, animation_config=AnimationConfig()): 16 | self._working_radius = env.grid_config.obs_radius - 1 17 | env = PersistentWrapper(env, xy_offset=-self._working_radius) 18 | 19 | super().__init__(env) 20 | 21 | self.history = env.get_history() 22 | 23 | self.svg_settings: SvgSettings = SvgSettings() 24 | self.animation_config: AnimationConfig = animation_config 25 | 26 | self._episode_idx = 0 27 | 28 | def step(self, action): 29 | """ 30 | Saves information about the episode. 31 | :param action: current actions 32 | :return: obs, reward, done, info 33 | """ 34 | obs, reward, terminated, truncated, info = self.env.step(action) 35 | 36 | multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated) 37 | single_agent_terminated = isinstance(terminated, (bool, int)) and terminated 38 | multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated) 39 | single_agent_truncated = isinstance(truncated, (bool, int)) and truncated 40 | 41 | if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated: 42 | save_tau = self.animation_config.save_every_idx_episode 43 | if save_tau: 44 | if (self._episode_idx + 1) % save_tau or save_tau == 1: 45 | if not os.path.exists(self.animation_config.directory): 46 | logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", ) 47 | os.makedirs(self.animation_config.directory, exist_ok=True) 48 | 49 | path = os.path.join(self.animation_config.directory, 50 | self.pick_name(self.grid_config, self._episode_idx)) 51 | self.save_animation(path) 52 | 53 | return obs, reward, terminated, truncated, info 54 | 55 | @staticmethod 56 | def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5): 57 | """ 58 | Picks a name for the SVG file. 
59 | :param grid_config: configuration of the grid 60 | :param episode_idx: idx of the episode 61 | :param zfill_ep: zfill for the episode number 62 | :return: 63 | """ 64 | gc = grid_config 65 | name = 'pogema' 66 | if episode_idx is not None: 67 | name += f'-ep{str(episode_idx).zfill(zfill_ep)}' 68 | if gc: 69 | if gc.map_name: 70 | name += f'-{gc.map_name}' 71 | if gc.seed is not None: 72 | name += f'-seed{gc.seed}' 73 | else: 74 | name += '-render' 75 | return name + '.svg' 76 | 77 | def reset(self, **kwargs): 78 | """ 79 | Resets the environment and resets the current positions of agents and targets 80 | :param kwargs: 81 | :return: obs: observation 82 | """ 83 | obs = self.env.reset(**kwargs) 84 | 85 | self._episode_idx += 1 86 | self.history = self.env.get_history() 87 | 88 | return obs 89 | 90 | def save_animation(self, name='render.svg', animation_config: AnimationConfig = AnimationConfig()): 91 | """ 92 | Saves the animation. 93 | :param name: name of the file 94 | :param animation_config: animation configuration 95 | :return: None 96 | """ 97 | wr = self._working_radius 98 | if wr > 0: 99 | obstacles = self.env.get_obstacles(ignore_borders=False)[wr:-wr, wr:-wr] 100 | else: 101 | obstacles = self.env.get_obstacles(ignore_borders=False) 102 | history: list[list[AgentState]] = self.env.decompress_history(self.history) 103 | 104 | svg_settings = SvgSettings() 105 | colors_cycle = cycle(svg_settings.colors) 106 | agents_colors = {index: next(colors_cycle) for index in range(self.grid_config.num_agents)} 107 | 108 | for agent_idx in range(self.grid_config.num_agents): 109 | history[agent_idx].append(history[agent_idx][-1]) 110 | 111 | episode_length = len(history[0]) 112 | # Change episode length for egocentric environment 113 | if animation_config.egocentric_idx is not None and self.grid_config.on_target == 'finish': 114 | episode_length = history[animation_config.egocentric_idx][-1].step + 1 115 | for agent_idx in range(self.grid_config.num_agents): 116 | 
history[agent_idx] = history[agent_idx][:episode_length] 117 | 118 | grid_holder = GridHolder( 119 | width=len(obstacles), height=len(obstacles[0]), 120 | obstacles=obstacles, 121 | episode_length=episode_length, 122 | history=history, 123 | obs_radius=self.grid_config.obs_radius, 124 | on_target=self.grid_config.on_target, 125 | colors=agents_colors, 126 | config=animation_config, 127 | svg_settings=svg_settings 128 | ) 129 | 130 | animation = AnimationDrawer().create_animation(grid_holder) 131 | with open(name, "w") as f: 132 | f.write(animation.render()) 133 | 134 | 135 | def main(): 136 | from pogema import GridConfig, pogema_v0, AnimationMonitor, BatchAStarAgent, AnimationConfig 137 | 138 | for egocentric_idx in [0, 1]: 139 | for on_target in ['nothing', 'restart', 'finish']: 140 | grid = """ 141 | ....#.. 142 | ..#.... 143 | ....... 144 | ....... 145 | #.#.#.. 146 | #.#.#.. 147 | """ 148 | grid_config = GridConfig(size=32, num_agents=2, obs_radius=2, seed=8, on_target=on_target, 149 | max_episode_steps=16, 150 | density=0.1, map=grid, observation_type="POMAPF") 151 | env = pogema_v0(grid_config=grid_config) 152 | env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None)) 153 | 154 | obs, _ = env.reset() 155 | truncated = terminated = [False] 156 | 157 | agent = BatchAStarAgent() 158 | while not all(terminated) and not all(truncated): 159 | obs, _, terminated, truncated, _ = env.step(agent.act(obs)) 160 | 161 | anim_folder = 'renders' 162 | if not os.path.exists(anim_folder): 163 | os.makedirs(anim_folder) 164 | 165 | env.save_animation(f'{anim_folder}/anim-{on_target}.svg') 166 | env.save_animation(f'{anim_folder}/anim-{on_target}-ego-{egocentric_idx}.svg', 167 | AnimationConfig(egocentric_idx=egocentric_idx)) 168 | env.save_animation(f'{anim_folder}/anim-static.svg', AnimationConfig(static=True)) 169 | env.save_animation(f'{anim_folder}/anim-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True)) 170 | 
env.save_animation(f'{anim_folder}/anim-static-no-agents.svg', 171 | AnimationConfig(show_agents=False, static=True)) 172 | 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /tests/test_deterministic_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from heapq import heappop, heappush 3 | from pogema import GridConfig, pogema_v0, AnimationMonitor 4 | 5 | # from pogema.animation import AnimationMonitor 6 | 7 | INF = 1000000007 8 | 9 | 10 | class Node: 11 | def __init__(self, coord=(INF, INF), g: int = 0, h: int = 0): 12 | self.i, self.j = coord 13 | self.g = g 14 | self.h = h 15 | self.f = g + h 16 | 17 | def __lt__(self, other): 18 | return self.f < other.f or ((self.f == other.f) and (self.g < other.g)) 19 | 20 | 21 | class AStar: 22 | def __init__(self): 23 | self.start = (0, 0) 24 | self.goal = (0, 0) 25 | self.max_steps = 500 26 | self.OPEN = list() 27 | self.CLOSED = dict() 28 | self.obstacles = set() 29 | self.other_agents = set() 30 | self.best_node = Node(self.start, 0, self.h(self.start)) 31 | 32 | def h(self, node): 33 | return abs(node[0] - self.goal[0]) + abs(node[1] - self.goal[1]) 34 | 35 | def get_neighbours(self, u): 36 | neighbors = [] 37 | for d in [(-1, 0), (1, 0), (0, -1), (0, 1)]: 38 | if (u[0] + d[0], u[1] + d[1]) not in self.obstacles: 39 | neighbors.append((u[0] + d[0], u[1] + d[1])) 40 | return neighbors 41 | 42 | def compute_shortest_path(self): 43 | u = Node() 44 | steps = 0 45 | while len(self.OPEN) > 0 and steps < self.max_steps and (u.i, u.j) != self.goal: 46 | u = heappop(self.OPEN) 47 | if self.best_node.h > u.h: 48 | self.best_node = u 49 | steps += 1 50 | for n in self.get_neighbours((u.i, u.j)): 51 | if n not in self.CLOSED and n not in self.other_agents: 52 | heappush(self.OPEN, Node(n, u.g + 1, self.h(n))) 53 | self.CLOSED[n] = (u.i, u.j) 54 | 55 | def get_next_node(self): 
56 | next_node = self.start 57 | if self.goal in self.CLOSED: 58 | next_node = self.goal 59 | while self.CLOSED[next_node] != self.start: 60 | next_node = self.CLOSED[next_node] 61 | return next_node 62 | 63 | def update_obstacles(self, obs, other_agents): 64 | obstacles = np.nonzero(obs) 65 | self.obstacles.clear() 66 | for k in range(len(obstacles[0])): 67 | self.obstacles.add((obstacles[0][k], obstacles[1][k])) 68 | self.other_agents.clear() 69 | agents = np.nonzero(other_agents) 70 | for k in range(len(agents[0])): 71 | self.other_agents.add((agents[0][k], agents[1][k])) 72 | 73 | def reset(self): 74 | self.CLOSED = dict() 75 | self.OPEN = list() 76 | heappush(self.OPEN, Node(self.start, 0, self.h(self.start))) 77 | self.best_node = Node(self.start, 0, self.h(self.start)) 78 | 79 | def update_path(self, start, goal): 80 | self.start = start 81 | self.goal = goal 82 | self.reset() 83 | self.compute_shortest_path() 84 | 85 | 86 | class DeterministicPolicy: 87 | def __init__(self, random_seed=42, random_rate=0.2): 88 | self.agents = None 89 | self.actions = {tuple(GridConfig().MOVES[i]): i for i in 90 | range(len(GridConfig().MOVES))} # make a dictionary to translate coordinates of action into id 91 | self.obs_radius = GridConfig().obs_radius 92 | self._rnd = np.random.RandomState(random_seed) 93 | self._random_rate = random_rate 94 | 95 | def get_goal(self, obs): 96 | goal = np.nonzero(obs[2]) 97 | goal_i = goal[0][0] 98 | goal_j = goal[1][0] 99 | if obs[0][goal_i][goal_j]: 100 | goal_i -= 1 if goal_i == 0 else 0 101 | goal_i += 1 if goal_i == self.obs_radius * 2 else 0 102 | goal_j -= 1 if goal_j == 0 else 0 103 | goal_j += 1 if goal_j == self.obs_radius * 2 else 0 104 | return goal_i, goal_j 105 | 106 | def act(self, obs) -> list: 107 | if self.agents is None: 108 | self.obs_radius = len(obs[0][0]) // 2 109 | self.agents = [AStar() for _ in range(len(obs))] # create a planner for each of the agents 110 | actions = [] 111 | for k in range(len(obs)): 112 | start 
= (self.obs_radius, self.obs_radius)
113 |             goal = self.get_goal(obs[k])
114 |             if start == goal:  # don't waste time on the agents that have already reached their goals
115 |                 actions.append(0)  # just add useless action to save the order and length of the actions
116 |                 continue
117 |             self.agents[k].update_obstacles(obs[k][0], obs[k][1])
118 |             self.agents[k].update_path(start, goal)
119 |             next_node = self.agents[k].get_next_node()
120 |             actions.append(self.actions[(next_node[0] - start[0], next_node[1] - start[1])])
121 |         for idx, action in enumerate(actions):
122 |             if self._rnd.random() < self._random_rate:
123 |                 actions[idx] = self._rnd.randint(1, 4)  # NOTE(review): numpy RandomState.randint has an exclusive high bound, so this samples {1, 2, 3} and can never pick action 4 ('right'); the tests below pin metrics produced by this exact RNG stream, so it is left as-is — confirm whether excluding action 4 is intended
124 |         return actions
125 | 
126 | 
127 | def run_policy(gc: GridConfig, save_animation=False):
128 |     # Generator: runs one episode per iteration with a fresh DeterministicPolicy env rollout
129 |     # and yields the metrics dict reported by the environment for agent 0.
130 |     policy = DeterministicPolicy()
131 |     env = pogema_v0(grid_config=gc)
132 |     if save_animation:
133 |         env = AnimationMonitor(env)
134 | 
135 |     while True:
136 |         obs, info = env.reset()
137 |         while True:
138 |             obs, reward, terminated, truncated, info = env.step(policy.act(obs))
139 |             if all(terminated) or all(truncated):
140 |                 break
141 | 
142 |         yield info[0]['metrics']
143 | 
144 | 
145 | def test_life_long():
146 |     gc = GridConfig(num_agents=20, size=8, obs_radius=4, seed=42, max_episode_steps=64, on_target='restart')
147 |     results_generator = run_policy(gc, save_animation=False)
148 | 
149 |     metrics = results_generator.__next__()
150 |     assert np.isclose(metrics['avg_throughput'], 1.671875)
151 |     metrics = results_generator.__next__()
152 |     assert np.isclose(metrics['avg_throughput'], 1.609375)
153 | 
154 |     gc = GridConfig(num_agents=24, size=8, obs_radius=4, seed=43, max_episode_steps=64, on_target='restart')
155 |     results_generator = run_policy(gc, save_animation=False)
156 | 
157 |     metrics = results_generator.__next__()
158 |     assert np.isclose(metrics['avg_throughput'], 0.4375)
159 | 
160 | 
161 | def test_disappearing():
162 |     gc = GridConfig(num_agents=20, size=8, obs_radius=2, seed=42, density=0.2, max_episode_steps=32,
on_target='finish') 161 | results_generator = run_policy(gc, save_animation=False) 162 | 163 | metrics = results_generator.__next__() 164 | assert np.isclose(metrics['ep_length'], 22.95) 165 | assert np.isclose(metrics['ISR'], 0.5) 166 | assert np.isclose(metrics['CSR'], 0.0) 167 | 168 | metrics = results_generator.__next__() 169 | assert np.isclose(metrics['ep_length'], 15.55) 170 | assert np.isclose(metrics['ISR'], 0.9) 171 | assert np.isclose(metrics['CSR'], 0.0) 172 | 173 | 174 | def test_non_disappearing(): 175 | gc = GridConfig(num_agents=4, size=5, obs_radius=2, seed=3, density=0.2, max_episode_steps=32, on_target='nothing') 176 | results_generator = run_policy(gc, save_animation=False) 177 | 178 | metrics = results_generator.__next__() 179 | assert np.isclose(metrics['ep_length'], 21) 180 | assert np.isclose(metrics['CSR'], 1.0) 181 | assert np.isclose(metrics['ISR'], 1.0) 182 | 183 | metrics = results_generator.__next__() 184 | assert np.isclose(metrics['ep_length'], 14) 185 | assert np.isclose(metrics['CSR'], 1.0) 186 | assert np.isclose(metrics['ISR'], 1.0) 187 | 188 | gc = GridConfig(num_agents=7, size=5, obs_radius=2, seed=3, density=0.2, max_episode_steps=32, on_target='nothing') 189 | results_generator = run_policy(gc, save_animation=False) 190 | 191 | metrics = results_generator.__next__() 192 | assert np.isclose(metrics['ep_length'], 32) 193 | assert np.isclose(metrics['CSR'], 0.0) 194 | assert np.isclose(metrics['ISR'], 0.71428571428) 195 | -------------------------------------------------------------------------------- /tests/test_pogema_env.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | 4 | import numpy as np 5 | import pytest 6 | from tabulate import tabulate 7 | 8 | from pogema import pogema_v0, AnimationMonitor 9 | 10 | from pogema.envs import ActionsSampler 11 | from pogema.grid import GridConfig 12 | 13 | 14 | class ActionMapping: 15 | noop: int = 0 16 | up: int = 1 17 
| down: int = 2 18 | left: int = 3 19 | right: int = 4 20 | 21 | 22 | def test_moving(): 23 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42)) 24 | ac = ActionMapping() 25 | env.reset() 26 | 27 | env.step([ac.right, ac.noop]) 28 | env.step([ac.up, ac.noop]) 29 | env.step([ac.left, ac.noop]) 30 | env.step([ac.down, ac.noop]) 31 | env.step([ac.down, ac.noop]) 32 | env.step([ac.left, ac.noop]) 33 | env.step([ac.left, ac.noop]) 34 | env.step([ac.up, ac.noop]) 35 | env.step([ac.up, ac.noop]) 36 | env.step([ac.up, ac.noop]) 37 | 38 | env.step([ac.right, ac.noop]) 39 | env.step([ac.up, ac.noop]) 40 | env.step([ac.right, ac.noop]) 41 | env.step([ac.down, ac.noop]) 42 | obs, reward, terminated, truncated, infos = env.step([ac.right, ac.noop]) 43 | 44 | assert np.isclose([1.0, 0.0], reward).all() 45 | assert np.isclose([True, False], terminated).all() 46 | 47 | 48 | def test_types(): 49 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42)) 50 | obs, info = env.reset() 51 | assert obs[0].dtype == np.float32 52 | 53 | 54 | def run_episode(grid_config=None, env=None): 55 | if env is None: 56 | env = pogema_v0(grid_config) 57 | env.reset() 58 | 59 | obs, rewards, terminated, truncated, infos = env.reset(), [None], [False], [False], [None] 60 | 61 | results = [[obs, rewards, terminated, truncated, infos]] 62 | while True: 63 | results.append(env.step(env.sample_actions())) 64 | terminated, truncated = results[-1][2], results[-1][3] 65 | if all(terminated) or all(truncated): 66 | break 67 | return results 68 | 69 | 70 | def test_metrics(): 71 | *_, infos = run_episode(GridConfig(num_agents=2, seed=5, size=5, max_episode_steps=64))[-1] 72 | assert np.isclose(infos[0]['metrics']['CSR'], 0.0) 73 | assert np.isclose(infos[0]['metrics']['ISR'], 0.5) 74 | 75 | *_, infos = run_episode(GridConfig(num_agents=2, seed=5, size=5, max_episode_steps=512))[-1] 76 | assert np.isclose(infos[0]['metrics']['CSR'], 1.0) 77 | 
assert np.isclose(infos[0]['metrics']['ISR'], 1.0) 78 | 79 | *_, infos = run_episode(GridConfig(num_agents=5, seed=5, size=5, max_episode_steps=64))[-1] 80 | assert np.isclose(infos[0]['metrics']['CSR'], 0.0) 81 | assert np.isclose(infos[0]['metrics']['ISR'], 0.2) 82 | 83 | 84 | def test_standard_pogema(): 85 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish')) 86 | env.reset() 87 | run_episode(env=env) 88 | 89 | 90 | def test_pomapf_observation(): 91 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish', 92 | observation_type='POMAPF')) 93 | obs, info = env.reset() 94 | assert 'agents' in obs[0] 95 | assert 'obstacles' in obs[0] 96 | assert 'xy' in obs[0] 97 | assert 'target_xy' in obs[0] 98 | run_episode(env=env) 99 | 100 | 101 | def test_mapf_observation(): 102 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish', 103 | observation_type='MAPF')) 104 | obs, info = env.reset() 105 | assert 'global_obstacles' in obs[0] 106 | assert 'global_xy' in obs[0] 107 | assert 'global_target_xy' in obs[0] 108 | run_episode(env=env) 109 | 110 | 111 | def test_standard_pogema_animation(): 112 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish')) 113 | env = AnimationMonitor(env) 114 | env.reset() 115 | run_episode(env=env) 116 | 117 | 118 | def test_gym_pogema_animation(): 119 | import gymnasium 120 | env = gymnasium.make('Pogema-v0', 121 | grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, 122 | on_target='finish')) 123 | env = AnimationMonitor(env) 124 | env.reset() 125 | done = False 126 | while True: 127 | _, _, terminated, truncated, _ = env.step(env.action_space.sample()) 128 | if terminated or truncated: 129 | break 130 | 131 | 132 | def test_non_disappearing_pogema(): 133 | env = pogema_v0(GridConfig(num_agents=2, size=6, 
obs_radius=2, density=0.3, seed=42, on_target='nothing')) 134 | env.reset() 135 | run_episode(env=env) 136 | 137 | 138 | def test_non_disappearing_pogema_no_seed(): 139 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=None, on_target='nothing')) 140 | env.reset() 141 | run_episode(env=env) 142 | 143 | 144 | def test_non_disappearing_pogema_animation(): 145 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='nothing')) 146 | env = AnimationMonitor(env) 147 | env.reset() 148 | run_episode(env=env) 149 | 150 | 151 | def test_life_long_pogema(): 152 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='restart')) 153 | env.reset() 154 | run_episode(env=env) 155 | 156 | 157 | def test_life_long_pogema_empty_seed(): 158 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=None, on_target='restart')) 159 | env.reset() 160 | run_episode(env=env) 161 | 162 | 163 | def test_life_long_pogema_animation(): 164 | env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='restart')) 165 | env = AnimationMonitor(env) 166 | env.reset() 167 | run_episode(env=env) 168 | 169 | 170 | def test_custom_positions_and_num_agents(): 171 | grid = """ 172 | .... 173 | .... 174 | """ 175 | gc = GridConfig( 176 | map=grid, 177 | agents_xy=[[0, 0], [0, 1], [0, 2], [0, 3]], 178 | targets_xy=[[1, 0], [1, 1], [1, 2], [1, 3]], 179 | ) 180 | 181 | for num_agents in range(1, 5): 182 | gc.num_agents = num_agents 183 | env = pogema_v0(grid_config=gc) 184 | env.reset() 185 | assert num_agents == len(env.get_agents_xy()) 186 | assert num_agents == len(env.get_targets_xy()) 187 | 188 | 189 | def test_custom_positions_and_empty_num_agents(): 190 | grid = """ 191 | .... 192 | .... 
193 | """ 194 | gc = GridConfig( 195 | map=grid, 196 | agents_xy=[[0, 0], [0, 1], [0, 2], [0, 3]], 197 | targets_xy=[[1, 0], [1, 1], [1, 2], [1, 3]], 198 | ) 199 | env = pogema_v0(grid_config=gc) 200 | env.reset() 201 | assert len(gc.agents_xy) == len(env.get_agents_xy()) 202 | 203 | 204 | def test_persistent_env(num_steps=100): 205 | seed = 42 206 | 207 | env = pogema_v0( 208 | grid_config=GridConfig(on_target='finish', seed=seed, num_agents=8, density=0.132, size=8, obs_radius=2, 209 | persistent=True)) 210 | 211 | env.reset() 212 | action_sampler = ActionsSampler(env.action_space.n, seed=seed) 213 | 214 | first_run_observations = [] 215 | 216 | def state_repr(observations, rewards, terminates, truncates, infos): 217 | return np.concatenate([np.array(observations).flatten(), terminates, truncates, np.array(rewards), ]) 218 | 219 | for current_step in range(num_steps): 220 | actions = action_sampler.sample_actions(dim=env.get_num_agents()) 221 | obs, reward, terminated, truncated, info = env.step(actions) 222 | 223 | first_run_observations.append(state_repr(obs, reward, terminated, truncated, info)) 224 | if all(terminated) or all(truncated): 225 | break 226 | 227 | # resetting the environment to the initial state using backward steps 228 | for current_step in range(num_steps): 229 | if not env.step_back(): 230 | break 231 | 232 | action_sampler = ActionsSampler(env.action_space.n, seed=seed) 233 | 234 | second_run_observations = [] 235 | for current_step in range(num_steps): 236 | actions = action_sampler.sample_actions(dim=env.get_num_agents()) 237 | obs, reward, terminated, truncated, info = env.step(actions) 238 | second_run_observations.append(state_repr(obs, reward, terminated, truncated, info)) 239 | assert np.isclose(first_run_observations[current_step], second_run_observations[current_step]).all() 240 | if all(terminated) or all(truncated): 241 | break 242 | assert np.isclose(first_run_observations, second_run_observations).all() 243 | 244 | 245 | def 
test_steps_per_second_throughput(): 246 | table = [] 247 | for on_target in ['finish', 'nothing', 'restart']: 248 | for num_agents in [1, 32, 64]: 249 | for size in [32, 64]: 250 | gc = GridConfig(obs_radius=5, seed=42, max_episode_steps=1024, 251 | size=size, num_agents=num_agents, on_target=on_target) 252 | 253 | start_time = time.monotonic() 254 | run_episode(grid_config=gc) 255 | end_time = time.monotonic() 256 | steps_per_second = gc.max_episode_steps / (end_time - start_time) 257 | table.append([on_target, num_agents, size, steps_per_second * gc.num_agents]) 258 | print('\n' + tabulate(table, headers=['on_target', 'num_agents', 'size', 'SPS (individual)'], tablefmt='grid')) 259 | -------------------------------------------------------------------------------- /pogema/grid_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional, Union 3 | from pydantic import validator, root_validator 4 | 5 | from pogema.utils import CommonSettings 6 | 7 | from typing_extensions import Literal 8 | 9 | 10 | class GridConfig(CommonSettings, ): 11 | on_target: Literal['finish', 'nothing', 'restart'] = 'finish' 12 | seed: Optional[int] = None 13 | width: Optional[int] = None 14 | height: Optional[int] = None 15 | size: int = 8 16 | density: float = 0.3 17 | obs_radius: int = 5 18 | agents_xy: Optional[list] = None 19 | targets_xy: Optional[list] = None 20 | num_agents: Optional[int] = None 21 | possible_agents_xy: Optional[list] = None 22 | possible_targets_xy: Optional[list] = None 23 | collision_system: Literal['block_both', 'priority', 'soft'] = 'priority' 24 | persistent: bool = False 25 | observation_type: Literal['POMAPF', 'MAPF', 'default'] = 'default' 26 | map: Optional[Union[list, str]] = None 27 | 28 | map_name: Optional[str] = None 29 | 30 | integration: Literal['SampleFactory', 'PyMARL', 'rllib', 'gymnasium', 'PettingZoo'] = None 31 | max_episode_steps: int = 64 32 | auto_reset: 
Optional[bool] = None 33 | 34 | @root_validator 35 | def validate_dimensions_and_positions(cls, values): 36 | width_provided = values.get('width') is not None 37 | height_provided = values.get('height') is not None 38 | 39 | if width_provided and not height_provided: 40 | raise ValueError("Invalid dimension configuration. Please provide height.") 41 | elif not width_provided and height_provided: 42 | raise ValueError("Invalid dimension configuration. Please provide width.") 43 | 44 | if not width_provided and not height_provided: 45 | values['width'] = values.get('size', 8) 46 | values['height'] = values.get('size', 8) 47 | if 'size' not in values or values.get('size') != max(values.get('width'), values.get('height')): 48 | values['size'] = max(values.get('width'), values.get('height')) 49 | 50 | 51 | width = values.get('width') 52 | height = values.get('height') 53 | 54 | if width is not None and height is not None: 55 | agents_xy = values.get('agents_xy') 56 | if agents_xy is not None: 57 | cls.check_positions(agents_xy, width, height) 58 | 59 | targets_xy = values.get('targets_xy') 60 | if targets_xy is not None: 61 | first_element = targets_xy[0] 62 | if isinstance(first_element[0], (list, tuple)): 63 | for agent_goals in targets_xy: 64 | cls.check_positions(agent_goals, width, height) 65 | else: 66 | cls.check_positions(targets_xy, width, height) 67 | 68 | return values 69 | 70 | @validator('seed') 71 | def seed_initialization(cls, v): 72 | assert v is None or (0 <= v < sys.maxsize), "seed must be in [0, " + str(sys.maxsize) + ']' 73 | return v 74 | 75 | @staticmethod 76 | def _validate_dimension(v, field_name): 77 | if v is not None: 78 | if field_name == 'size': 79 | assert 2 <= v <= 4096, f"{field_name} must be in [2, 4096]" 80 | else: 81 | assert 1 <= v <= 4096, f"{field_name} must be in [1, 4096]" 82 | return v 83 | 84 | @validator('size') 85 | def size_restrictions(cls, v): 86 | return cls._validate_dimension(v, 'size') 87 | 88 | @validator('width') 89 | 
def width_restrictions(cls, v): 90 | return cls._validate_dimension(v, 'width') 91 | 92 | @validator('height') 93 | def height_restrictions(cls, v): 94 | return cls._validate_dimension(v, 'height') 95 | 96 | @validator('density') 97 | def density_restrictions(cls, v): 98 | assert 0.0 <= v <= 1, "density must be in [0, 1]" 99 | return v 100 | 101 | @validator('agents_xy') 102 | def agents_xy_validation(cls, v, values): 103 | if v is not None: 104 | if not isinstance(v, (list, tuple)): 105 | raise ValueError("agents_xy must be a list") 106 | for position in v: 107 | if not isinstance(position, (list, tuple)) or len(position) != 2: 108 | raise ValueError("Position must be a list/tuple of length 2") 109 | if not all(isinstance(coord, int) for coord in position): 110 | raise ValueError("Position coordinates must be integers") 111 | return v 112 | 113 | @validator('targets_xy') 114 | def targets_xy_validation(cls, v, values): 115 | if v is not None: 116 | if not v or not isinstance(v, (list, tuple)): 117 | raise ValueError("targets_xy must be a list") 118 | 119 | first_element = v[0] 120 | if not isinstance(first_element, (list, tuple)): 121 | raise ValueError("Invalid targets_xy format") 122 | 123 | if isinstance(first_element[0], (list, tuple)): 124 | for agent_goals in v: 125 | if not isinstance(agent_goals, (list, tuple)) or len(agent_goals) < 2: 126 | raise ValueError("Each agent must have at least two goals in the sequence") 127 | for position in agent_goals: 128 | if not isinstance(position, (list, tuple)) or len(position) != 2: 129 | raise ValueError("Position must be a list/tuple of length 2") 130 | if not all(isinstance(coord, int) for coord in position): 131 | raise ValueError("Position coordinates must be integers") 132 | else: 133 | on_target = values.get('on_target', 'finish') 134 | if on_target == 'restart': 135 | raise ValueError("on_target='restart' requires goal sequences, not single goals. 
Use format: targets_xy: [[[x1,y1],[x2,y2]], [[x3,y3],[x4,y4]]]") 136 | for position in v: 137 | if not isinstance(position, (list, tuple)) or len(position) != 2: 138 | raise ValueError("Position must be a list/tuple of length 2") 139 | if not all(isinstance(coord, int) for coord in position): 140 | raise ValueError("Position coordinates must be integers") 141 | return v 142 | 143 | @staticmethod 144 | def check_positions(v, width, height): 145 | for position in v: 146 | if not isinstance(position, (list, tuple)) or len(position) != 2: 147 | raise ValueError("Position must be a list/tuple of length 2") 148 | x, y = position 149 | if not isinstance(x, int) or not isinstance(y, int): 150 | raise ValueError("Position coordinates must be integers") 151 | if not (0 <= x < height and 0 <= y < width): 152 | raise IndexError(f"Position is out of bounds! {position} is not in [{0}, {height}] x [{0}, {width}]") 153 | 154 | 155 | @validator('num_agents', always=True) 156 | def num_agents_must_be_positive(cls, v, values): 157 | if v is None: 158 | if values['agents_xy']: 159 | v = len(values['agents_xy']) 160 | else: 161 | v = 1 162 | assert 1 <= v <= 10000000, "num_agents must be in [1, 10000000]" 163 | return v 164 | 165 | @validator('obs_radius') 166 | def obs_radius_must_be_positive(cls, v): 167 | assert 1 <= v <= 128, "obs_radius must be in [1, 128]" 168 | return v 169 | 170 | @validator('map', always=True) 171 | def map_validation(cls, v, values): 172 | if v is None: 173 | return None 174 | if isinstance(v, str): 175 | v, agents_xy, targets_xy, possible_agents_xy, possible_targets_xy = cls.str_map_to_list(v, values['FREE'], 176 | values['OBSTACLE']) 177 | if agents_xy and targets_xy and values.get('agents_xy') is not None and values.get( 178 | 'targets_xy') is not None: 179 | raise KeyError("""Can't create task. Please provide agents_xy and targets_xy only once. 
180 | Either with parameters or with a map.""") 181 | if (agents_xy or targets_xy) and (possible_agents_xy or possible_targets_xy): 182 | raise KeyError("""Can't create task. Mark either possible locations or precise ones.""") 183 | elif agents_xy and targets_xy: 184 | values['agents_xy'] = agents_xy 185 | values['targets_xy'] = targets_xy 186 | values['num_agents'] = len(agents_xy) 187 | elif (values.get('agents_xy') is None or values.get( 188 | 'targets_xy') is None) and possible_agents_xy and possible_targets_xy: 189 | values['possible_agents_xy'] = possible_agents_xy 190 | values['possible_targets_xy'] = possible_targets_xy 191 | 192 | height = len(v) 193 | width = 0 194 | area = 0 195 | for line in v: 196 | width = max(width, len(line)) 197 | area += len(line) 198 | 199 | values['size'] = max(width, height) 200 | values['width'] = width 201 | values['height'] = height 202 | values['density'] = sum([sum(line) for line in v]) / area 203 | 204 | return v 205 | 206 | @validator('possible_agents_xy') 207 | def possible_agents_xy_validation(cls, v): 208 | return v 209 | 210 | @validator('possible_targets_xy') 211 | def possible_targets_xy_validation(cls, v): 212 | return v 213 | 214 | @staticmethod 215 | def str_map_to_list(str_map, free, obstacle): 216 | obstacles = [] 217 | agents = {} 218 | targets = {} 219 | possible_agents_xy = [] 220 | possible_targets_xy = [] 221 | special_chars = {'@', '$', '!'} 222 | 223 | for row_idx, line in enumerate(str_map.split()): 224 | row = [] 225 | for col_idx, char in enumerate(line): 226 | position = (row_idx, col_idx) 227 | 228 | if char == '.': 229 | row.append(free) 230 | possible_agents_xy.append(position) 231 | possible_targets_xy.append(position) 232 | elif char == '#': 233 | row.append(obstacle) 234 | elif char in special_chars: 235 | row.append(free) 236 | if char == '@': 237 | possible_agents_xy.append(position) 238 | elif char == '$': 239 | possible_targets_xy.append(position) 240 | elif 'A' <= char <= 'Z': 241 | 
targets[char.lower()] = position 242 | row.append(free) 243 | possible_agents_xy.append(position) 244 | possible_targets_xy.append(position) 245 | elif 'a' <= char <= 'z': 246 | agents[char.lower()] = position 247 | row.append(free) 248 | possible_agents_xy.append(position) 249 | possible_targets_xy.append(position) 250 | else: 251 | raise KeyError(f"Unsupported symbol '{char}' at line {row_idx}") 252 | 253 | if row: 254 | assert len(obstacles[-1]) == len(row) if obstacles else True, f"Wrong string size for row {row_idx};" 255 | obstacles.append(row) 256 | 257 | agents_xy = [[x, y] for _, (x, y) in sorted(agents.items())] 258 | targets_xy = [[x, y] for _, (x, y) in sorted(targets.items())] 259 | 260 | assert len(targets_xy) == len(agents_xy), "Mismatch in number of agents and targets." 261 | 262 | if not any(char in special_chars for char in str_map): 263 | possible_agents_xy, possible_targets_xy = None, None 264 | 265 | return obstacles, agents_xy, targets_xy, possible_agents_xy, possible_targets_xy 266 | 267 | def update_config(self, **kwargs): 268 | current_values = self.dict() 269 | 270 | if 'size' in kwargs: 271 | current_values.pop('width', None) 272 | current_values.pop('height', None) 273 | elif 'width' in kwargs or 'height' in kwargs: 274 | current_values.pop('size', None) 275 | current_values.update(kwargs) 276 | new_instance = GridConfig(**current_values) 277 | 278 | for field_name, field_value in new_instance.__dict__.items(): 279 | setattr(self, field_name, field_value) 280 | -------------------------------------------------------------------------------- /pogema/grid.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import warnings 3 | 4 | import numpy as np 5 | 6 | from pogema.generator import generate_obstacles, generate_positions_and_targets_fast, \ 7 | get_components, generate_from_possible_positions 8 | from .grid_config import GridConfig 9 | from .grid_registry import 
class Grid:
    """Occupancy-grid world state: obstacle map plus per-agent positions and goals."""

    def __init__(self, grid_config: GridConfig, add_artificial_border: bool = True, num_retries=10):

        self.config = grid_config
        self.rnd = np.random.default_rng(grid_config.seed)

        # Obstacle source priority: registered map name > explicit map > random generation.
        if self.config.map is None:
            self.obstacles = generate_obstacles(self.config)
        else:
            self.obstacles = np.array([np.array(line) for line in self.config.map])
        if in_registry(self.config.map_name):
            self.obstacles = get_grid(self.config.map_name).get_obstacles()
        self.obstacles = self.obstacles.astype(np.int32)

        if grid_config.targets_xy and grid_config.agents_xy:
            self.starts_xy = grid_config.agents_xy

            # targets_xy may hold one goal per agent, or a sequence of goals per
            # agent (lifelong mode) — in the latter case take the first goal.
            if isinstance(grid_config.targets_xy[0][0], (list, tuple)):
                self.finishes_xy = [goal_sequence[0] for goal_sequence in grid_config.targets_xy]
            else:
                self.finishes_xy = grid_config.targets_xy

            if len(self.starts_xy) != len(self.finishes_xy):
                raise IndexError("Can't create task. Please provide agents_xy and targets_xy of the same size.")
            if grid_config.num_agents > len(self.starts_xy):
                raise IndexError(f"Not enough agents_xy and targets_xy to place {grid_config.num_agents} agents")

            self.starts_xy = self.starts_xy[:grid_config.num_agents]
            self.finishes_xy = self.finishes_xy[:grid_config.num_agents]

            # Custom maps may place obstacles under requested cells — free them with a warning.
            for (s_x, s_y), (f_x, f_y) in zip(self.starts_xy, self.finishes_xy):
                if self.config.map is not None and self.obstacles[s_x, s_y] == grid_config.OBSTACLE:
                    warnings.warn(f"There is an obstacle on a start point ({s_x}, {s_y}), replacing with free cell",
                                  Warning, stacklevel=2)
                    self.obstacles[s_x, s_y] = grid_config.FREE
                if self.config.map is not None and self.obstacles[f_x, f_y] == grid_config.OBSTACLE:
                    warnings.warn(f"There is an obstacle on a finish point ({f_x}, {f_y}), replacing with free cell",
                                  Warning, stacklevel=2)
                    self.obstacles[f_x, f_y] = grid_config.FREE
        elif grid_config.possible_agents_xy and grid_config.possible_targets_xy:
            self.starts_xy, self.finishes_xy = generate_from_possible_positions(self.config)
        else:
            self.starts_xy, self.finishes_xy = generate_positions_and_targets_fast(self.obstacles, self.config)

        # Random generation can fail to pair every start with a goal; retry with
        # fresh obstacles (regeneration is only possible for random maps).
        if len(self.starts_xy) != len(self.finishes_xy):
            for attempt in range(num_retries):
                if len(self.starts_xy) == len(self.finishes_xy):
                    warnings.warn(f'Created valid configuration only with {attempt} attempts.', Warning, stacklevel=2)
                    break
                if self.config.map is None:
                    self.obstacles = generate_obstacles(self.config)
                    self.starts_xy, self.finishes_xy = generate_positions_and_targets_fast(self.obstacles, self.config)

        if not self.starts_xy or not self.finishes_xy or len(self.starts_xy) != len(self.finishes_xy):
            raise OverflowError("Can't create task. Please check grid grid_config, especially density, num_agent and map.")

        if add_artificial_border:
            self.add_artificial_border()

        occupancy = np.zeros(self.obstacles.shape)
        for x, y in self.starts_xy:
            occupancy[x, y] = 1

        self.positions = occupancy          # dense mask of agent-occupied cells
        self.positions_xy = self.starts_xy  # current coordinates, indexed by agent id
        self._initial_xy = deepcopy(self.starts_xy)
        self.is_active = {agent_id: True for agent_id in range(self.config.num_agents)}

    def add_artificial_border(self):
        """Pad the map by obs_radius on every side and draw a solid obstacle frame.

        The filler outside the frame is empty or random (same density) depending
        on ``empty_outside``; start/finish coordinates shift by the radius.
        """
        gc = self.config
        r = gc.obs_radius
        padded_shape = np.array(self.obstacles.shape) + r * 2
        if gc.empty_outside:
            padded = np.zeros(padded_shape)
        else:
            padded = self.rnd.binomial(1, gc.density, padded_shape)

        height, width = padded.shape
        # Solid one-cell frame just inside the padding.
        padded[r - 1, r - 1:width - r + 1] = gc.OBSTACLE
        padded[r - 1:height - r + 1, r - 1] = gc.OBSTACLE
        padded[height - r, r - 1:width - r + 1] = gc.OBSTACLE
        padded[r - 1:height - r + 1, width - r] = gc.OBSTACLE
        padded[r:height - r, r:width - r] = self.obstacles

        self.obstacles = padded

        self.starts_xy = [(x + r, y + r) for x, y in self.starts_xy]
        self.finishes_xy = [(x + r, y + r) for x, y in self.finishes_xy]

    def get_obstacles(self, ignore_borders=False):
        """Copy of the obstacle matrix, optionally without the artificial border."""
        r = self.config.obs_radius
        if ignore_borders:
            return self.obstacles[r:-r, r:-r].copy()
        return self.obstacles.copy()

    @staticmethod
    def _cut_borders_xy(positions, obs_radius):
        """Shift coordinates back into map space by removing the border offset."""
        return [[x - obs_radius, y - obs_radius] for x, y in positions]

    @staticmethod
    def _filter_inactive(pos, active_flags):
        """Keep positions whose per-index flag is truthy."""
        return [p for idx, p in enumerate(pos) if active_flags[idx]]

    def get_grid_config(self):
        """Deep copy of the GridConfig this grid was built from."""
        return deepcopy(self.config)
GridConfig: 120 | # return self.env.grid_config 121 | 122 | def _prepare_positions(self, positions, only_active, ignore_borders): 123 | gc = self.config 124 | 125 | if only_active: 126 | positions = self._filter_inactive(positions, [idx for idx, active in self.is_active.items() if active]) 127 | 128 | if ignore_borders: 129 | positions = self._cut_borders_xy(positions, gc.obs_radius) 130 | 131 | return positions 132 | 133 | def get_agents_xy(self, only_active=False, ignore_borders=False): 134 | return self._prepare_positions(deepcopy(self.positions_xy), only_active, ignore_borders) 135 | 136 | @staticmethod 137 | def to_relative(coordinates, offset): 138 | result = deepcopy(coordinates) 139 | for idx, _ in enumerate(result): 140 | x, y = result[idx] 141 | dx, dy = offset[idx] 142 | result[idx] = x - dx, y - dy 143 | return result 144 | 145 | def get_agents_xy_relative(self): 146 | return self.to_relative(self.positions_xy, self._initial_xy) 147 | 148 | def get_targets_xy_relative(self): 149 | return self.to_relative(self.finishes_xy, self._initial_xy) 150 | 151 | def get_targets_xy(self, only_active=False, ignore_borders=False): 152 | return self._prepare_positions(deepcopy(self.finishes_xy), only_active, ignore_borders) 153 | 154 | def _normalize_coordinates(self, coordinates): 155 | gc = self.config 156 | 157 | x, y = coordinates 158 | 159 | x -= gc.obs_radius 160 | y -= gc.obs_radius 161 | 162 | x /= gc.height - 1 163 | y /= gc.width - 1 164 | 165 | return x, y 166 | 167 | def get_state(self, ignore_borders=False, as_dict=False): 168 | agents_xy = list(map(self._normalize_coordinates, self.get_agents_xy(ignore_borders))) 169 | targets_xy = list(map(self._normalize_coordinates, self.get_targets_xy(ignore_borders))) 170 | 171 | obstacles = self.get_obstacles(ignore_borders) 172 | 173 | if as_dict: 174 | return {"obstacles": obstacles, "agents_xy": agents_xy, "targets_xy": targets_xy} 175 | 176 | return np.concatenate(list(map(lambda x: np.array(x).flatten(), 
[agents_xy, targets_xy, obstacles]))) 177 | 178 | def get_observation_shape(self): 179 | full_radius = self.config.obs_radius * 2 + 1 180 | return 2, full_radius, full_radius 181 | 182 | def get_num_actions(self): 183 | return len(self.config.MOVES) 184 | 185 | def get_obstacles_for_agent(self, agent_id): 186 | x, y = self.positions_xy[agent_id] 187 | r = self.config.obs_radius 188 | return self.obstacles[x - r:x + r + 1, y - r:y + r + 1].astype(np.float32) 189 | 190 | def get_positions(self, agent_id): 191 | x, y = self.positions_xy[agent_id] 192 | r = self.config.obs_radius 193 | return self.positions[x - r:x + r + 1, y - r:y + r + 1].astype(np.float32) 194 | 195 | def get_target(self, agent_id): 196 | 197 | x, y = self.positions_xy[agent_id] 198 | fx, fy = self.finishes_xy[agent_id] 199 | if x == fx and y == fy: 200 | return 0.0, 0.0 201 | rx, ry = fx - x, fy - y 202 | dist = np.sqrt(rx ** 2 + ry ** 2) 203 | return rx / dist, ry / dist 204 | 205 | def get_square_target(self, agent_id): 206 | c = self.config 207 | full_size = self.config.obs_radius * 2 + 1 208 | result = np.zeros((full_size, full_size)) 209 | x, y = self.positions_xy[agent_id] 210 | fx, fy = self.finishes_xy[agent_id] 211 | dx, dy = x - fx, y - fy 212 | 213 | dx = min(dx, c.obs_radius) if dx >= 0 else max(dx, -c.obs_radius) 214 | dy = min(dy, c.obs_radius) if dy >= 0 else max(dy, -c.obs_radius) 215 | result[c.obs_radius - dx, c.obs_radius - dy] = 1 216 | return result.astype(np.float32) 217 | 218 | def render(self, mode='human'): 219 | render_grid(self.obstacles, self.positions_xy, self.finishes_xy, self.is_active, mode=mode) 220 | 221 | def move_agent_to_cell(self, agent_id, x, y): 222 | if self.positions[self.positions_xy[agent_id]] == self.config.FREE: 223 | raise KeyError("Agent {} is not in the map".format(agent_id)) 224 | self.positions[self.positions_xy[agent_id]] = self.config.FREE 225 | if self.obstacles[x, y] != self.config.FREE or self.positions[x, y] != self.config.FREE: 226 | raise 
ValueError(f"Can't force agent to blocked position {x} {y}") 227 | self.positions_xy[agent_id] = x, y 228 | self.positions[self.positions_xy[agent_id]] = self.config.OBSTACLE 229 | 230 | def has_obstacle(self, x, y): 231 | return self.obstacles[x, y] == self.config.OBSTACLE 232 | 233 | def move_without_checks(self, agent_id, action): 234 | x, y = self.positions_xy[agent_id] 235 | dx, dy = self.config.MOVES[action] 236 | self.positions[x, y] = self.config.FREE 237 | self.positions[x+dx, y+dy] = self.config.OBSTACLE 238 | self.positions_xy[agent_id] = (x+dx, y+dy) 239 | 240 | def move(self, agent_id, action): 241 | x, y = self.positions_xy[agent_id] 242 | dx, dy = self.config.MOVES[action] 243 | if self.obstacles[x + dx, y + dy] == self.config.FREE: 244 | if self.positions[x + dx, y + dy] == self.config.FREE: 245 | self.positions[x, y] = self.config.FREE 246 | x += dx 247 | y += dy 248 | self.positions[x, y] = self.config.OBSTACLE 249 | self.positions_xy[agent_id] = (x, y) 250 | 251 | def on_goal(self, agent_id): 252 | return self.positions_xy[agent_id] == self.finishes_xy[agent_id] 253 | 254 | def is_active(self, agent_id): 255 | return self.is_active[agent_id] 256 | 257 | def hide_agent(self, agent_id): 258 | if not self.is_active[agent_id]: 259 | return False 260 | self.is_active[agent_id] = False 261 | 262 | self.positions[self.positions_xy[agent_id]] = self.config.FREE 263 | 264 | return True 265 | 266 | def show_agent(self, agent_id): 267 | if self.is_active[agent_id]: 268 | return False 269 | 270 | self.is_active[agent_id] = True 271 | if self.positions[self.positions_xy[agent_id]] == self.config.OBSTACLE: 272 | raise KeyError("The cell is already occupied") 273 | self.positions[self.positions_xy[agent_id]] = self.config.OBSTACLE 274 | return True 275 | 276 | 277 | class GridLifeLong(Grid): 278 | def __init__(self, grid_config: GridConfig, add_artificial_border: bool = True, num_retries=10): 279 | 280 | super().__init__(grid_config, add_artificial_border, 
num_retries) 281 | 282 | self.component_to_points, self.point_to_component = get_components(grid_config, self.obstacles, 283 | self.positions_xy, self.finishes_xy) 284 | 285 | for i in range(len(self.positions_xy)): 286 | position, target = self.positions_xy[i], self.finishes_xy[i] 287 | if self.point_to_component[position] != self.point_to_component[target]: 288 | warnings.warn(f"The start point ({position[0]}, {position[1]}) and the goal" 289 | f" ({target[0]}, {target[1]}) are in different components. The goal is changed.", 290 | Warning, stacklevel=2) 291 | -------------------------------------------------------------------------------- /pogema/svg_animation/animation_drawer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing 3 | from dataclasses import dataclass 4 | 5 | from pogema import GridConfig 6 | from pogema.svg_animation.svg_objects import Line, RectangleHref, Animation, Circle, Rectangle 7 | 8 | 9 | @dataclass 10 | class AnimationConfig: 11 | directory: str = 'renders/' 12 | static: bool = False 13 | show_agents: bool = True 14 | egocentric_idx: typing.Optional[int] = None 15 | uid: typing.Optional[str] = None 16 | save_every_idx_episode: typing.Optional[int] = 1 17 | show_grid_lines: bool = True 18 | 19 | 20 | @dataclass 21 | class SvgSettings: 22 | r: int = 35 23 | stroke_width: int = 10 24 | scale_size: int = 100 25 | time_scale: float = 0.25 26 | draw_start: int = 100 27 | rx: int = 15 28 | 29 | obstacle_color: str = '#84A1AE' 30 | ego_color: str = '#c1433c' 31 | ego_other_color: str = '#6e81af' 32 | shaded_opacity: float = 0.2 33 | egocentric_shaded: bool = True 34 | stroke_dasharray: int = 25 35 | 36 | colors: tuple = ( 37 | '#c1433c', 38 | '#2e6f9e', 39 | '#6e81af', 40 | '#00b9c8', 41 | '#72D5C8', 42 | '#0ea08c', 43 | '#8F7B66', 44 | ) 45 | 46 | 47 | @dataclass 48 | class GridHolder: 49 | obstacles: typing.Any = None 50 | episode_length: int = None 51 | height: int = None 52 | 
width: int = None 53 | colors: dict = None 54 | history: list = None 55 | obs_radius: int = None 56 | grid_config: GridConfig = None 57 | on_target: str = None 58 | config: AnimationConfig = None 59 | svg_settings: SvgSettings = None 60 | 61 | 62 | class Drawing: 63 | 64 | def __init__(self, height, width, svg_settings): 65 | self.height = height 66 | self.width = width 67 | self.origin = (0, 0) 68 | self.elements = [] 69 | self.svg_settings = svg_settings 70 | 71 | def add_element(self, element): 72 | self.elements.append(element) 73 | 74 | def render(self): 75 | scale = max(self.height, self.width) / 512 76 | scaled_width = math.ceil(self.width / scale) 77 | scaled_height = math.ceil(self.height / scale) 78 | 79 | dx, dy = self.origin 80 | view_box = (dx, dy - self.height, self.width, self.height) 81 | 82 | svg_header = f''' 83 | ''' 85 | 86 | definitions = f''' 87 | 88 | 93 | ''' 94 | 95 | elements_svg = [svg_header, '', definitions, '\n'] 96 | elements_svg.extend(element.render() for element in self.elements) 97 | elements_svg.append('') 98 | return "\n".join(elements_svg) 99 | 100 | 101 | class AnimationDrawer: 102 | 103 | def __init__(self): 104 | pass 105 | 106 | def create_animation(self, grid_holder: GridHolder): 107 | gh = grid_holder 108 | render_width = gh.height * gh.svg_settings.scale_size + gh.svg_settings.scale_size 109 | render_height = gh.width * gh.svg_settings.scale_size + gh.svg_settings.scale_size 110 | drawing = Drawing(width=render_width, height=render_height, svg_settings=SvgSettings()) 111 | obstacles = self.create_obstacles(gh) 112 | 113 | agents = [] 114 | targets = [] 115 | 116 | if gh.config.show_agents: 117 | agents = self.create_agents(gh) 118 | targets = self.create_targets(gh) 119 | 120 | if not gh.config.static: 121 | self.animate_agents(agents, gh) 122 | self.animate_targets(targets, gh) 123 | if gh.config.show_grid_lines: 124 | grid_lines = self.create_grid_lines(gh, render_width, render_height) 125 | for line in grid_lines: 126 
| drawing.add_element(line) 127 | for obj in [*obstacles, *agents, *targets]: 128 | drawing.add_element(obj) 129 | 130 | if gh.config.egocentric_idx is not None: 131 | field_of_view = self.create_field_of_view(grid_holder=gh) 132 | if not gh.config.static: 133 | self.animate_obstacles(obstacles=obstacles, grid_holder=gh) 134 | self.animate_field_of_view(field_of_view, gh) 135 | drawing.add_element(field_of_view) 136 | 137 | return drawing 138 | 139 | @staticmethod 140 | def fix_point(x, y, length): 141 | return length - y - 1, x 142 | 143 | @staticmethod 144 | def check_in_radius(x1, y1, x2, y2, r) -> bool: 145 | return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r 146 | 147 | @staticmethod 148 | def create_grid_lines(grid_holder: GridHolder, render_width, render_height): 149 | gh = grid_holder 150 | offset = 0 151 | stroke_settings = {'class': 'line'} 152 | grid_lines = [] 153 | for i in range(-1, grid_holder.height + 1): 154 | x = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2 155 | grid_lines.append(Line(x1=x, y1=offset, x2=x, y2=render_height - offset, **stroke_settings)) 156 | 157 | for i in range(-1, grid_holder.width + 1): 158 | y = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2 159 | grid_lines.append(Line(x1=offset, y1=y, x2=render_width - offset, y2=y, **stroke_settings)) 160 | 161 | return grid_lines 162 | 163 | @staticmethod 164 | def create_field_of_view(grid_holder): 165 | gh: GridHolder = grid_holder 166 | ego_idx = gh.config.egocentric_idx 167 | x, y = gh.history[ego_idx][0].get_xy() 168 | cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size 169 | cy = gh.svg_settings.draw_start + (gh.width - x - 1) * gh.svg_settings.scale_size 170 | 171 | dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2 172 | result = Rectangle( 173 | x=cx - dr + gh.svg_settings.r, y=cy - dr + gh.svg_settings.r, 174 | width=2 * dr - 2 * gh.svg_settings.r, height=2 * dr - 2 * 
gh.svg_settings.r, 175 | stroke=gh.svg_settings.ego_color, stroke_width=gh.svg_settings.stroke_width, 176 | fill='none', rx=gh.svg_settings.rx, stroke_dasharray=gh.svg_settings.stroke_dasharray 177 | ) 178 | 179 | return result 180 | 181 | def animate_field_of_view(self, view, grid_holder): 182 | gh: GridHolder = grid_holder 183 | x_path = [] 184 | y_path = [] 185 | ego_idx = grid_holder.config.egocentric_idx 186 | for state in gh.history[ego_idx]: 187 | x, y = state.get_xy() 188 | dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2 189 | cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size 190 | cy = -gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size 191 | x_path.append(str(cx - dr + gh.svg_settings.r)) 192 | y_path.append(str(cy - dr + gh.svg_settings.r)) 193 | 194 | visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[ego_idx]] 195 | 196 | view.add_animation(self.compressed_anim('x', x_path, gh.svg_settings.time_scale)) 197 | view.add_animation(self.compressed_anim('y', y_path, gh.svg_settings.time_scale)) 198 | view.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale)) 199 | 200 | def animate_agents(self, agents, grid_holder): 201 | gh: GridHolder = grid_holder 202 | ego_idx = gh.config.egocentric_idx 203 | 204 | for agent_idx, agent in enumerate(agents): 205 | x_path = [] 206 | y_path = [] 207 | opacity = [] 208 | for idx, agent_state in enumerate(gh.history[agent_idx]): 209 | x, y = agent_state.get_xy() 210 | 211 | x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size)) 212 | y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size)) 213 | 214 | if ego_idx is not None: 215 | ego_x, ego_y = gh.history[ego_idx][idx].get_xy() 216 | if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): 217 | opacity.append('1.0') 218 | else: 219 | 
opacity.append(str(gh.svg_settings.shaded_opacity)) 220 | 221 | visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] 222 | 223 | agent.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale)) 224 | agent.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale)) 225 | agent.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale)) 226 | if opacity: 227 | agent.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale)) 228 | 229 | @classmethod 230 | def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'): 231 | tokens, times = cls.compress_tokens(tokens) 232 | cumulative = [0, ] 233 | for t in times: 234 | cumulative.append(cumulative[-1] + t) 235 | times = [str(round(value / cumulative[-1], 10)) for value in cumulative] 236 | tokens = [tokens[0]] + tokens 237 | 238 | times = times 239 | tokens = tokens 240 | return Animation( 241 | attributeName=attr_name, dur=f'{time_scale * (-1 + cumulative[-1])}s', 242 | values=";".join(tokens), repeatCount=rep_cnt, keyTimes=";".join(times) 243 | ) 244 | 245 | @staticmethod 246 | def wisely_add(token, cnt, tokens, times): 247 | if cnt > 1: 248 | tokens += [token, token] 249 | times += [1, cnt - 1] 250 | else: 251 | tokens.append(token) 252 | times.append(cnt) 253 | 254 | @classmethod 255 | def compress_tokens(cls, input_tokens: list): 256 | tokens = [] 257 | times = [] 258 | if input_tokens: 259 | cur_idx = 0 260 | cnt = 1 261 | for idx in range(1, len(input_tokens)): 262 | if input_tokens[idx] == input_tokens[cur_idx]: 263 | cnt += 1 264 | else: 265 | cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) 266 | cnt = 1 267 | cur_idx = idx 268 | cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times) 269 | return tokens, times 270 | 271 | def animate_targets(self, targets, grid_holder): 272 | gh: GridHolder = grid_holder 273 | ego_idx = gh.config.egocentric_idx 274 
| 275 | for agent_idx, target in enumerate(targets): 276 | target_idx = ego_idx if ego_idx is not None else agent_idx 277 | 278 | x_path = [] 279 | y_path = [] 280 | 281 | for step_idx, state in enumerate(gh.history[target_idx]): 282 | x, y = state.get_target_xy() 283 | x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size)) 284 | y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size)) 285 | 286 | visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]] 287 | 288 | if gh.on_target == 'restart' or gh.on_target == 'wait': 289 | target.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale)) 290 | target.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale)) 291 | target.add_animation(self.compressed_anim("visibility", visibility, gh.svg_settings.time_scale)) 292 | 293 | def create_obstacles(self, grid_holder): 294 | gh = grid_holder 295 | result = [] 296 | 297 | for i in range(gh.height): 298 | for j in range(gh.width): 299 | x, y = self.fix_point(i, j, gh.width) 300 | 301 | if gh.obstacles[x][y]: 302 | obs_settings = {} 303 | obs_settings.update( 304 | x=gh.svg_settings.draw_start + i * gh.svg_settings.scale_size - gh.svg_settings.r, 305 | y=gh.svg_settings.draw_start + j * gh.svg_settings.scale_size - gh.svg_settings.r, 306 | height=gh.svg_settings.r * 2, 307 | ) 308 | 309 | if gh.config.egocentric_idx is not None and gh.svg_settings.egocentric_shaded: 310 | initial_positions = [agent_states[0].get_xy() for agent_states in gh.history] 311 | ego_x, ego_y = initial_positions[gh.config.egocentric_idx] 312 | if not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): 313 | obs_settings.update(opacity=gh.svg_settings.shaded_opacity) 314 | 315 | result.append(RectangleHref(**obs_settings)) 316 | 317 | return result 318 | 319 | def animate_obstacles(self, obstacles, grid_holder): 320 | gh: GridHolder = 
grid_holder 321 | obstacle_idx = 0 322 | 323 | for i in range(gh.height): 324 | for j in range(gh.width): 325 | x, y = self.fix_point(i, j, gh.width) 326 | if not gh.obstacles[x][y]: 327 | continue 328 | opacity = [] 329 | seen = set() 330 | for step_idx, agent_state in enumerate(gh.history[gh.config.egocentric_idx]): 331 | ego_x, ego_y = agent_state.get_xy() 332 | if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius): 333 | seen.add((x, y)) 334 | if (x, y) in seen: 335 | opacity.append(str(1.0)) 336 | else: 337 | opacity.append(str(gh.svg_settings.shaded_opacity)) 338 | 339 | obstacle = obstacles[obstacle_idx] 340 | obstacle.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale)) 341 | 342 | obstacle_idx += 1 343 | 344 | def create_agents(self, grid_holder): 345 | initial_positions = [state[0].get_xy() for state in grid_holder.history if state[0].is_active()] 346 | agents = [] 347 | gh: GridHolder = grid_holder 348 | ego_idx = grid_holder.config.egocentric_idx 349 | 350 | for idx, (x, y) in enumerate(initial_positions): 351 | circle_settings = { 352 | 'cx': gh.svg_settings.draw_start + y * gh.svg_settings.scale_size, 353 | 'cy': gh.svg_settings.draw_start + (grid_holder.width - x - 1) * gh.svg_settings.scale_size, 354 | 'r': gh.svg_settings.r, 'fill': grid_holder.colors[idx], 'class': 'agent', 355 | } 356 | 357 | if ego_idx is not None: 358 | ego_x, ego_y = initial_positions[ego_idx] 359 | is_out_of_radius = not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius) 360 | circle_settings['fill'] = gh.svg_settings.ego_other_color 361 | if idx == ego_idx: 362 | circle_settings['fill'] = gh.svg_settings.ego_color 363 | elif is_out_of_radius and gh.svg_settings.egocentric_shaded: 364 | circle_settings['opacity'] = gh.svg_settings.shaded_opacity 365 | 366 | agents.append(Circle(**circle_settings)) 367 | 368 | return agents 369 | 370 | @staticmethod 371 | def create_targets(grid_holder): 372 | gh: GridHolder = 
def test_obstacle_creation():
    """Obstacle layout (incl. the obs_radius border) is fully determined by the seed."""
    config = GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.2)
    expected = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
    assert np.isclose(Grid(config).obstacles, expected).all()

    config = GridConfig(seed=3, obs_radius=1, size=4, num_agents=1, density=0.4)
    expected = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
                [1.0, 1.0, 0.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
    assert np.isclose(Grid(config).obstacles, expected).all()


def test_initial_positions():
    """Agent start cells are reproducible for a fixed seed."""
    config = GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.2)
    assert np.isclose(Grid(config).positions_xy, [(2, 4)]).all()

    config = GridConfig(seed=1, obs_radius=2, size=12, num_agents=10, density=0.2)
    expected = [(13, 10), (7, 4), (4, 3), (2, 11), (12, 6), (8, 11), (6, 8), (2, 12), (2, 10), (9, 11)]
    assert np.isclose(Grid(config).positions_xy, expected).all()


def test_goals():
    """Goal cells are reproducible for a fixed seed."""
    config = GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.4)
    assert np.isclose(Grid(config).finishes_xy, [(5, 2)]).all()

    config = GridConfig(seed=2, obs_radius=2, size=12, num_agents=10, density=0.2)
    expected = [(11, 10), (8, 11), (2, 13), (3, 5), (12, 6), (9, 12), (9, 6), (9, 2), (10, 2), (6, 11)]
    assert np.isclose(Grid(config).finishes_xy, expected).all()


def test_overflow():
    """Impossible placements (too many agents / fully blocked map) raise OverflowError."""
    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=100, density=0.0))

    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=1, density=1.0))


def test_overflow_warning():
    """Hard random instances eventually succeed with a retries warning."""
    with pytest.warns(Warning):
        for _ in range(1000):
            Grid(GridConfig(obs_radius=2, size=4, num_agents=6, density=0.3), num_retries=10000)


def test_edge_cases():
    """Degenerate configs fail validation; impossible tasks overflow."""
    with pytest.raises(ValidationError):
        GridConfig(seed=1, obs_radius=2, size=1, num_agents=1, density=0.4)

    with pytest.raises(ValidationError):
        GridConfig(seed=1, obs_radius=2, size=4, num_agents=0, density=0.4)

    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=1, density=1.0))

    with pytest.raises(ValidationError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=1, density=2.0))
def test_edge_cases_for_custom_map():
    """A 1x3 free map can host at most one valid agent/target pair."""
    tiny_map = [[0, 0, 0]]
    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=2, map=tiny_map))
    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=2, obs_radius=2, size=4, num_agents=4, map=tiny_map))


def test_custom_map():
    """Explicit matrix maps are embedded verbatim inside the artificial border."""
    custom_map = [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ]
    grid = Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=2, map=custom_map))
    expected = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
    assert np.isclose(grid.obstacles, expected).all()

    custom_map = [
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
    ]
    grid = Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=2, map=custom_map))
    expected = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
    assert np.isclose(grid.obstacles, expected).all()

    custom_map = [
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 1],
    ]
    grid = Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=2, map=custom_map))
    expected = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
    assert np.isclose(grid.obstacles, expected).all()


def test_overflow_for_custom_map():
    """More agents than free pairings on a fixed map must overflow."""
    custom_map = [
        [0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    with pytest.raises(OverflowError):
        Grid(GridConfig(obs_radius=2, size=4, num_agents=5, density=0.3, map=custom_map), num_retries=100)


def test_str_custom_map():
    """String maps infer agents, density and size; explicit size/density are ignored."""
    grid_map = """
.a...#.....
.....#.....
..C.....b..
.....#.....
.....#.....
#.####.....
.....###.##
.....#.....
.c...#.....
.B.......A.
.....#.....
"""
    grid = Grid(GridConfig(obs_radius=2, size=4, density=0.3, map=grid_map))
    assert grid.config.num_agents == 3
    assert np.isclose(0.1404958, grid.config.density)
    assert np.isclose(11, grid.config.size)

    grid_map = """.....#...."""
    grid = Grid(GridConfig(seed=2, num_agents=3, map=grid_map))
    assert grid.config.num_agents == 3
    assert np.isclose(0.1, grid.config.density)
    assert np.isclose(10, grid.config.size)


def test_custom_starts_and_finishes_random():
    """Explicit agents_xy/targets_xy survive the reset (modulo the border offset)."""
    agents_xy = [(x, x) for x in range(8)]
    targets_xy = [(x, x) for x in range(8, 16)]
    grid_config = GridConfig(seed=12, size=16, num_agents=8, agents_xy=agents_xy, targets_xy=targets_xy)
    env = pogema_v0(grid_config=grid_config)
    env.reset()
    r = grid_config.obs_radius
    assert [(x - r, y - r) for x, y in env.grid.positions_xy] == agents_xy and \
           [(x - r, y - r) for x, y in env.grid.finishes_xy] == targets_xy


def test_out_of_bounds_for_custom_positions():
    """Positions outside [0, size) are rejected at config time."""
    Grid(GridConfig(seed=12, size=17, agents_xy=[[0, 16]], targets_xy=[[16, 0]]))

    with pytest.raises(IndexError):
        GridConfig(seed=12, size=17, agents_xy=[[0, 17]], targets_xy=[[0, 0]])
    with pytest.raises(IndexError):
        GridConfig(seed=12, size=17, agents_xy=[[0, 0]], targets_xy=[[0, 17]])
    with pytest.raises(IndexError):
        GridConfig(seed=12, size=17, agents_xy=[[-1, 0]], targets_xy=[[0, 0]])
    with pytest.raises(IndexError):
        GridConfig(seed=12, size=17, agents_xy=[[0, 0]], targets_xy=[[0, -1]])


def test_duplicated_params():
    """Mixing explicit coordinates with a named-agent map is rejected."""
    with pytest.raises(KeyError):
        GridConfig(agents_xy=[[0, 0]], targets_xy=[[0, 0]], map="Aa")


def test_custom_grid_with_empty_agents_and_targets():
    """A plain free-cell map still allows random agent placement."""
    Grid(GridConfig(agents_xy=None, targets_xy=None, map="""....""", num_agents=1))
def test_custom_grid_with_specific_positions():
    # Map legend: '!' = forbidden cell, '@' = allowed agent start,
    # '$' = allowed target cell, '#' = obstacle.
    grid_map = """
    !!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!
    !@@!@@!##########!
    !@@!@@!$$$$$$$$$$!
    !!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!
    !@@!@@!##########!
    !@@!@@!$$$$$$$$$$!
    !!!!!!!!!!!!!!!!!!
    """
    # Exactly 24 '@' cells exist, so 24 agents fit and 25 must overflow.
    Grid(GridConfig(obs_radius=2, size=4, num_agents=24, map=grid_map))
    with pytest.raises(OverflowError):
        Grid(GridConfig(obs_radius=2, size=4, num_agents=25, map=grid_map))

    grid_map = """
    !!!!!!!!!!!
    !@@!@@!$$$$
    !@@!@@!####
    !@@!@@!$$$$
    !!!!!!!!!!!
    !@@!@@!$$$$
    !@@!@@!####
    !@@!@@!$$$$
    !!!!!!!!!!!
    """
    Grid(GridConfig(obs_radius=2, num_agents=16, map=grid_map))
    with pytest.raises(OverflowError):
        Grid(GridConfig(obs_radius=2, num_agents=17, map=grid_map))

    # Mixing placement markers with explicit agent letters must be rejected.
    grid_map = """
    !!!!!!!!!!!
    !@@!@@!.Ab.
    !@@!@@!####
    !@@!@@!.aB.

    """
    with pytest.raises(KeyError):
        Grid(GridConfig(obs_radius=2, map=grid_map))


def test_restricted_grid():
    grid = """
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
    !@@!@@!##########!##########!##########!@@!@@!
    !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    """
    env = pogema_v0(grid_config=GridConfig(map=grid, num_agents=24, seed=0, obs_radius=2))
    env.reset()

    with pytest.raises(OverflowError):
        env = pogema_v0(grid_config=GridConfig(map=grid, num_agents=25, seed=0, obs_radius=2))
        env.reset()


def test_rectangular_grid_basic():
    cfg = GridConfig(width=12, height=8)
    assert cfg.width == 12
    assert cfg.height == 8
    # `size` mirrors the larger dimension for backward compatibility.
    assert cfg.size == 12


def test_rectangular_grid_backward_compatibility():
    cfg = GridConfig(size=10)
    assert cfg.size == 10
    assert cfg.width == 10
    assert cfg.height == 10


def test_rectangular_grid_mixed_config():
    # Explicit width/height win over `size`.
    cfg = GridConfig(size=8, width=12, height=6)
    assert cfg.width == 12
    assert cfg.height == 6
    assert cfg.size == 12


def test_rectangular_grid_validation():
    # Width and height must be given together.
    with pytest.raises(ValueError):
        GridConfig(width=12)

    with pytest.raises(ValueError):
        GridConfig(height=8)

    GridConfig(width=12, height=8)
    GridConfig(size=10)
    GridConfig(size=10, width=12, height=8)


def test_rectangular_grid_position_validation():
    cfg = GridConfig(width=12, height=8, agents_xy=[[0, 11], [7, 0]], targets_xy=[[7, 11], [0, 0]])
    assert len(cfg.agents_xy) == 2
    assert len(cfg.targets_xy) == 2

    # Row index bounded by height, column index by width.
    with pytest.raises(IndexError):
        GridConfig(width=12, height=8, agents_xy=[[8, 0]], targets_xy=[[0, 0]])

    with pytest.raises(IndexError):
        GridConfig(width=12, height=8, agents_xy=[[0, 12]], targets_xy=[[0, 0]])


def test_rectangular_grid_creation():
    cfg = GridConfig(width=12, height=8, seed=1, num_agents=2)
    grid = Grid(cfg)

    assert grid.config.width == 12
    assert grid.config.height == 8
    assert grid.config.size == 12


def test_goal_sequences_validation():
    # Per-agent goal sequences: a list of lists of [x, y] pairs.
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]]
        ]
    )
    assert len(cfg.targets_xy) == 2
    assert len(cfg.targets_xy[0]) == 3
    assert len(cfg.targets_xy[1]) == 2

    # Plain single-goal form is still accepted.
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[[7, 7], [6, 6]]
    )
    assert len(cfg.targets_xy) == 2

    # Mixing sequence and single-goal entries is invalid.
    with pytest.raises(ValueError):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0], [1, 1]],
            targets_xy=[[[2, 2], [3, 3]], [4, 4]]
        )

    # Sequences of length one are invalid.
    with pytest.raises(ValueError):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0]],
            targets_xy=[[[2, 2]]]
        )

    # Coordinates must be integers.
    with pytest.raises(ValueError):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0]],
            targets_xy=[[[2.5, 2], [3, 3]]]
        )

    # Every goal in a sequence is bounds-checked.
    with pytest.raises(IndexError):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0]],
            targets_xy=[[[2, 2], [10, 10]]]
        )

    with pytest.raises(ValueError, match="on_target='restart' requires goal sequences"):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0], [1, 1]],
            targets_xy=[[2, 2], [3, 3]],
            on_target='restart'
        )


def test_grid_with_goal_sequences():
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]]
        ]
    )

    grid = Grid(cfg)

    # The grid starts each agent on the first goal of its sequence,
    # shifted by obs_radius (the border padding).
    first_goals = [[2, 2], [2, 4]]
    r = cfg.obs_radius
    padded = [(x + r, y + r) for x, y in first_goals]

    assert np.isclose(grid.finishes_xy, padded).all()


def test_pogema_lifelong_with_sequences():
    from pogema.envs import PogemaLifeLong
    import warnings

    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[1, 1], [1, 2]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]]
        ],
        on_target='restart'
    )

    env = PogemaLifeLong(grid_config=cfg)
    obs = env.reset()

    assert env.has_custom_sequences == True
    assert env.current_goal_indices == [0, 0]

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        # Agent 0 walks its 3-goal sequence; the third request wraps
        # around and must emit exactly one cycling warning.
        env._generate_new_target(0)
        assert env.current_goal_indices[0] == 1

        env._generate_new_target(0)
        assert env.current_goal_indices[0] == 2

        env._generate_new_target(0)
        assert env.current_goal_indices[0] == 0

        cycling = [warning for warning in w if "completed all 3 provided targets" in str(warning.message)]
        assert len(cycling) == 1

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        env._generate_new_target(1)
        env._generate_new_target(1)

        assert len(w) == 1
        assert "completed all 2 provided targets" in str(w[0].message)
        assert "cycling back to the beginning" in str(w[0].message)


def test_pogema_lifelong_reset():
    from pogema.envs import PogemaLifeLong

    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[1, 1], [1, 2]],
        targets_xy=[
            [[2, 2], [3, 3]],
            [[2, 4], [3, 5]]
        ],
        on_target='restart'
    )

    env = PogemaLifeLong(grid_config=cfg)
    env.reset()

    env._generate_new_target(0)
    env._generate_new_target(1)
    assert env.current_goal_indices == [1, 1]

    # reset() must rewind every agent to the start of its sequence.
    env.reset()
    assert env.current_goal_indices == [0, 0]


def test_pogema_lifelong_without_sequences():
    from pogema.envs import PogemaLifeLong

    cfg = GridConfig(
        width=8, height=8,
        num_agents=2,
        on_target='restart'
    )

    env = PogemaLifeLong(grid_config=cfg)
    obs = env.reset()

    assert env.has_custom_sequences == False

    # Without sequences a fresh (x, y) goal is generated on demand.
    target = env._generate_new_target(0)
    assert isinstance(target, tuple)
    assert len(target) == 2


def test_goal_sequences_position_format():
    with pytest.raises(ValueError, match="Position must be a list/tuple of length 2"):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0]],
            targets_xy=[[[2, 2, 3], [4, 4]]]
        )

    with pytest.raises(ValueError, match="Position coordinates must be integers"):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0]],
            targets_xy=[[[2.5, 2], [4, 4]]]
        )
# --------------------------------------------------------------------------------
# /pogema/envs.py
# --------------------------------------------------------------------------------
from typing import Optional
import warnings
import numpy as np
import gymnasium
from gymnasium.error import ResetNeeded

from pogema.grid import Grid, GridLifeLong
from pogema.grid_config import GridConfig
from pogema.wrappers.metrics import LifeLongAverageThroughputMetric, NonDisappearEpLengthMetric, \
    NonDisappearCSRMetric, NonDisappearISRMetric, EpLengthMetric, ISRMetric, CSRMetric, SumOfCostsAndMakespanMetric
from pogema.wrappers.multi_time_limit import MultiTimeLimit
from pogema.generator import generate_new_target, generate_from_possible_targets
from pogema.wrappers.persistence import PersistentWrapper

# NOTE(review): class ActionsSampler begins here in the original concatenation
# and continues in the following chunk.
class ActionsSampler:
    """
    Samples random actions for the given number of agents using the given seed.
    """

    def __init__(self, num_actions, seed=42):
        self._num_actions = num_actions
        self._rnd = None
        self.update_seed(seed)

    def update_seed(self, seed=None):
        # Re-create the generator so a given seed always yields the same stream.
        self._rnd = np.random.default_rng(seed)

    def sample_actions(self, dim=1):
        # Returns `dim` random action indices in [0, self._num_actions).
        return self._rnd.integers(self._num_actions, size=dim)


class PogemaBase(gymnasium.Env):
    """
    Abstract class of the Pogema environment.
    """
    metadata = {"render_modes": ["ansi"], }

    def step(self, action):
        raise NotImplementedError

    def reset(self, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None, ):
        raise NotImplementedError

    def __init__(self, grid_config: GridConfig = GridConfig()):
        # NOTE(review): the default `GridConfig()` is evaluated once at class
        # definition and shared by all default-constructed envs — harmless only
        # if GridConfig is never mutated in place; confirm.
        # The grid itself is created lazily by `_initialize_grid` during reset().
        # noinspection PyTypeChecker
        self.grid: Grid = None
        self.grid_config = grid_config

        self.action_space: gymnasium.spaces.Discrete = gymnasium.spaces.Discrete(len(self.grid_config.MOVES))
        self._multi_action_sampler = ActionsSampler(self.action_space.n, seed=self.grid_config.seed)

    def _get_agents_obs(self, agent_id=0):
        """
        Returns the observation of the agent with the given id.
        :param agent_id:
        :return: (3, d, d) array: obstacles, agent positions, square target.
        """
        return np.concatenate([
            self.grid.get_obstacles_for_agent(agent_id)[None],
            self.grid.get_positions(agent_id)[None],
            self.grid.get_square_target(agent_id)[None]
        ])

    def check_reset(self):
        """
        Checks if the reset is needed.
        :raises ResetNeeded: if the grid has not been created yet.
        """
        if self.grid is None:
            raise ResetNeeded("Please reset environment first!")

    def render(self, mode='human'):
        """
        Renders the environment using ascii graphics.
        :param mode:
        :return:
        """
        self.check_reset()
        return self.grid.render(mode=mode)

    def sample_actions(self):
        """
        Samples random actions for all agents in the environment.
        :return:
        """
        return self._multi_action_sampler.sample_actions(dim=self.grid_config.num_agents)

    def get_num_agents(self):
        """
        Returns the number of agents in the environment.
        :return:
        """
        return self.grid_config.num_agents


class Pogema(PogemaBase):
    """Classic Pogema: agents disappear once they reach their goal."""

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        self.was_on_goal = None
        full_size = self.grid_config.obs_radius * 2 + 1
        if self.grid_config.observation_type == 'default':
            self.observation_space = gymnasium.spaces.Box(-1.0, 1.0, shape=(3, full_size, full_size))
        elif self.grid_config.observation_type == 'POMAPF':
            self.observation_space: gymnasium.spaces.Dict = gymnasium.spaces.Dict(
                obstacles=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                agents=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                target_xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
            )
        elif self.grid_config.observation_type == 'MAPF':
            self.observation_space: gymnasium.spaces.Dict = gymnasium.spaces.Dict(
                obstacles=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                agents=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                target_xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                # global_obstacles=None, # todo define shapes of global state variables
                # global_xy=None,
                # global_target_xy=None,
            )
        else:
            # BUG FIX: was `self.grid.config.observation_type`, but self.grid is
            # still None here (set by _initialize_grid during reset), so the
            # intended ValueError was masked by an AttributeError.
            raise ValueError(f"Unknown observation type: {self.grid_config.observation_type}")

    def step(self, action: list):
        """One environment step; agents that reach their goal are deactivated."""
        assert len(action) == self.grid_config.num_agents
        rewards = []

        terminated = []

        self.move_agents(action)
        self.update_was_on_goal()

        for agent_idx in range(self.grid_config.num_agents):

            on_goal = self.grid.on_goal(agent_idx)
            if on_goal and self.grid.is_active[agent_idx]:
                rewards.append(1.0)
            else:
                rewards.append(0.0)
            terminated.append(on_goal)

        # Agents that reached their goal are hidden and deactivated.
        for agent_idx in range(self.grid_config.num_agents):
            if self.grid.on_goal(agent_idx):
                self.grid.hide_agent(agent_idx)
                self.grid.is_active[agent_idx] = False

        infos = self._get_infos()

        observations = self._obs()
        truncated = [False] * self.grid_config.num_agents
        return observations, rewards, terminated, truncated, infos

    def _initialize_grid(self):
        self.grid: Grid = Grid(grid_config=self.grid_config)

    def update_was_on_goal(self):
        # Flag per agent: reached its goal this step while still active.
        self.was_on_goal = [self.grid.on_goal(agent_idx) and self.grid.is_active[agent_idx]
                            for agent_idx in range(self.grid_config.num_agents)]

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None, ):
        self._initialize_grid()
        self.update_was_on_goal()

        # NOTE(review): the seed is applied after the grid is already built, so
        # it seems to affect only later regeneration — confirm intended.
        if seed is not None:
            self.grid.seed = seed

        if return_info:
            return self._obs(), self._get_infos()
        return self._obs()

    def _obs(self):
        if self.grid_config.observation_type == 'default':
            return [self._get_agents_obs(index) for index in range(self.grid_config.num_agents)]
        elif self.grid_config.observation_type == 'POMAPF':
            return self._pomapf_obs()

        elif self.grid_config.observation_type == 'MAPF':
            # MAPF = POMAPF observations extended with global state.
            results = self._pomapf_obs()
            global_obstacles = self.grid.get_obstacles()
            global_agents_xy = self.grid.get_agents_xy()
            global_targets_xy = self.grid.get_targets_xy()

            for agent_idx in range(self.grid_config.num_agents):
                result = results[agent_idx]
                result.update(global_obstacles=global_obstacles)
                result['global_xy'] = global_agents_xy[agent_idx]
                result['global_target_xy'] = global_targets_xy[agent_idx]

            return results
        else:
            raise ValueError(f"Unknown observation type: {self.grid.config.observation_type}")

    def _pomapf_obs(self):
        # Dict observation per agent: local obstacles/agents plus relative coords.
        results = []
        agents_xy_relative = self.grid.get_agents_xy_relative()
        targets_xy_relative = self.grid.get_targets_xy_relative()

        for agent_idx in range(self.grid_config.num_agents):
            result = {'obstacles': self.grid.get_obstacles_for_agent(agent_idx),
                      'agents': self.grid.get_positions(agent_idx),
                      'xy': agents_xy_relative[agent_idx],
                      'target_xy': targets_xy_relative[agent_idx]}

            results.append(result)
        return results

    def _get_infos(self):
        infos = [dict() for _ in range(self.grid_config.num_agents)]
        for agent_idx in range(self.grid_config.num_agents):
            infos[agent_idx]['is_active'] = self.grid.is_active[agent_idx]
        return infos

    def _revert_action(self, agent_idx, used_cells, cell, actions):
        # Cancels the agent's move (action 0 = stay) and recursively reverts the
        # occupant of the cell the agent falls back to, if any.
        actions[agent_idx] = 0
        used_cells[cell].remove(agent_idx)
        new_cell = self.grid.positions_xy[agent_idx]
        if new_cell in used_cells and len(used_cells[new_cell]) > 0:
            used_cells[new_cell].append(agent_idx)
            return self._revert_action(used_cells[new_cell][0], used_cells, new_cell, actions)
        else:
            used_cells.setdefault(new_cell, []).append(agent_idx)
            return actions, used_cells

    def move_agents(self, actions):
        """Applies one joint action under the configured collision system."""
        if self.grid.config.collision_system == 'priority':
            # Lower index = higher priority; moves are applied sequentially.
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    self.grid.move(agent_idx, actions[agent_idx])
        elif self.grid.config.collision_system == 'block_both':
            # Any cell targeted twice (or currently occupied) blocks all movers into it.
            used_cells = {}
            agents_xy = self.grid.get_agents_xy()
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid_config.MOVES[actions[agent_idx]]
                    used_cells[x + dx, y + dy] = 'blocked' if (x + dx, y + dy) in used_cells else 'visited'
                    used_cells[x, y] = 'blocked'
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    x, y = agents_xy[agent_idx]
                    dx, dy = self.grid_config.MOVES[actions[agent_idx]]
                    if used_cells.get((x + dx, y + dy), None) != 'blocked':
                        self.grid.move(agent_idx, actions[agent_idx])
        elif self.grid.config.collision_system == 'soft':
            # Edge conflicts (swaps) and vertex conflicts are cancelled; the
            # remaining moves are applied without further checks.
            used_cells = dict()
            used_edges = dict()
            agents_xy = self.grid.get_agents_xy()
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    used_cells.setdefault((x + dx, y + dy), []).append(agent_idx)
                    used_edges[x, y, x + dx, y + dy] = [agent_idx]
                    if dx != 0 or dy != 0:
                        used_edges.setdefault((x + dx, y + dy, x, y), []).append(agent_idx)
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    if len(used_edges[x, y, x + dx, y + dy]) > 1:
                        used_cells[x + dx, y + dy].remove(agent_idx)
                        used_cells.setdefault((x, y), []).append(agent_idx)
                        actions[agent_idx] = 0
            for agent_idx in reversed(range(len(agents_xy))):
                x, y = agents_xy[agent_idx]
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    if len(used_cells[x + dx, y + dy]) > 1 or self.grid.has_obstacle(x + dx, y + dy):
                        actions, used_cells = self._revert_action(agent_idx, used_cells, (x + dx, y + dy), actions)
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    self.grid.move_without_checks(agent_idx, actions[agent_idx])
        else:
            raise ValueError('Unknown collision system: {}'.format(self.grid.config.collision_system))

    def get_agents_xy_relative(self):
        return self.grid.get_agents_xy_relative()

    def get_targets_xy_relative(self):
        return self.grid.get_targets_xy_relative()

    def get_obstacles(self, ignore_borders=False):
        return self.grid.get_obstacles(ignore_borders=ignore_borders)

    def get_agents_xy(self, only_active=False, ignore_borders=False):
        return self.grid.get_agents_xy(only_active=only_active, ignore_borders=ignore_borders)

    def get_targets_xy(self, only_active=False, ignore_borders=False):
        return self.grid.get_targets_xy(only_active=only_active, ignore_borders=ignore_borders)

    def get_state(self, ignore_borders=False, as_dict=False):
        return self.grid.get_state(ignore_borders=ignore_borders, as_dict=as_dict)


class PogemaLifeLong(Pogema):
    """Lifelong Pogema: a new goal is assigned whenever an agent reaches one."""

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        # Per-agent index into the user-provided goal sequence (if any).
        self.current_goal_indices = [0] * grid_config.num_agents
        self.has_custom_sequences = grid_config.targets_xy is not None

    def _initialize_grid(self):
        self.grid: GridLifeLong = GridLifeLong(grid_config=self.grid_config)

        # One generator per agent so target streams are independent per agent
        # yet fully determined by the environment seed.
        main_rng = np.random.default_rng(self.grid_config.seed)
        seeds = main_rng.integers(np.iinfo(np.int32).max, size=self.grid_config.num_agents)
        self.random_generators = [np.random.default_rng(seed) for seed in seeds]

    def get_lifelong_targets_xy(self, ignore_borders=False):
        """
        Returns per-agent goal sequences long enough to cover the episode.
        Custom sequences are returned as provided (optionally border-padded);
        otherwise sequences are pre-generated with the same RNG scheme used at
        runtime so they match the goals the environment will actually assign.
        """
        if self.has_custom_sequences:
            if ignore_borders:
                return self.grid_config.targets_xy
            else:
                return [[[x + self.grid_config.obs_radius, y + self.grid_config.obs_radius] for x, y in sequence]
                        for sequence in self.grid_config.targets_xy]

        sequences = []

        # Temporary generators reproduce self.random_generators without
        # consuming their state.
        main_rng = np.random.default_rng(self.grid_config.seed)
        seeds = main_rng.integers(np.iinfo(np.int32).max, size=self.grid_config.num_agents)
        temp_generators = [np.random.default_rng(seed) for seed in seeds]

        for agent_idx in range(self.grid_config.num_agents):
            agent_sequence = []
            start_pos = self.get_agents_xy(ignore_borders=ignore_borders)[agent_idx]
            initial_target = self.get_targets_xy(ignore_borders=ignore_borders)[agent_idx]
            agent_sequence.append(initial_target)
            current_pos = initial_target
            # Manhattan-distance budget: keep generating goals until their
            # cumulative distance exceeds the episode length.
            total_distance = abs(start_pos[0] - initial_target[0]) + abs(start_pos[1] - initial_target[1])

            while total_distance < self.grid_config.max_episode_steps:
                if ignore_borders:
                    # The generators work in padded coordinates.
                    generator_pos = (current_pos[0] + self.grid_config.obs_radius,
                                     current_pos[1] + self.grid_config.obs_radius)
                else:
                    generator_pos = tuple(current_pos)

                if self.grid_config.possible_targets_xy is not None:
                    new_goal = generate_from_possible_targets(
                        temp_generators[agent_idx],
                        self.grid_config.possible_targets_xy,
                        generator_pos
                    )
                    if ignore_borders:
                        goal_coords = list(new_goal)
                    else:
                        goal_coords = [new_goal[0] + self.grid_config.obs_radius,
                                       new_goal[1] + self.grid_config.obs_radius]
                else:
                    new_goal = generate_new_target(
                        temp_generators[agent_idx],
                        self.grid.point_to_component,
                        self.grid.component_to_points,
                        generator_pos
                    )
                    if ignore_borders:
                        goal_coords = [new_goal[0] - self.grid_config.obs_radius,
                                       new_goal[1] - self.grid_config.obs_radius]
                    else:
                        goal_coords = list(new_goal)

                agent_sequence.append(goal_coords)
                total_distance += abs(current_pos[0] - goal_coords[0]) + abs(current_pos[1] - goal_coords[1])
                current_pos = goal_coords
            sequences.append(agent_sequence)
        return sequences

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None):
        # Rewind every agent to the start of its goal sequence.
        self.current_goal_indices = [0] * self.grid_config.num_agents
        return super().reset(seed=seed, return_info=return_info, options=options)

    def _generate_new_target(self, agent_idx):
        """
        Returns the next goal for the agent, in border-padded coordinates.
        Custom sequences are walked in order and cycle with a warning; otherwise
        the goal is drawn from possible_targets_xy or generated on the grid.
        """
        if self.has_custom_sequences:
            agent_targets = self.grid_config.targets_xy[agent_idx]
            current_idx = self.current_goal_indices[agent_idx]
            next_target = agent_targets[(current_idx + 1) % len(agent_targets)]

            self.current_goal_indices[agent_idx] = (current_idx + 1) % len(agent_targets)

            if self.current_goal_indices[agent_idx] == 0 and current_idx == len(agent_targets) - 1:
                warnings.warn(
                    f"Agent {agent_idx} has completed all {len(agent_targets)} provided targets and "
                    f"is cycling back to the beginning. Provide more targets for the "
                    f"{self.grid_config.max_episode_steps} episode length.",
                    UserWarning,
                    stacklevel=2
                )

            return (next_target[0] + self.grid_config.obs_radius,
                    next_target[1] + self.grid_config.obs_radius)
        elif self.grid_config.possible_targets_xy is not None:
            new_goal = generate_from_possible_targets(self.random_generators[agent_idx],
                                                      self.grid_config.possible_targets_xy,
                                                      self.grid.positions_xy[agent_idx])
            return (new_goal[0] + self.grid_config.obs_radius, new_goal[1] + self.grid_config.obs_radius)
        else:
            return generate_new_target(self.random_generators[agent_idx],
                                       self.grid.point_to_component,
                                       self.grid.component_to_points,
                                       self.grid.positions_xy[agent_idx])

    def step(self, action: list):
        """One step; agents never terminate — reaching a goal assigns a new one."""
        assert len(action) == self.grid_config.num_agents
        rewards = []

        infos = [dict() for _ in range(self.grid_config.num_agents)]

        self.move_agents(action)
        self.update_was_on_goal()

        for agent_idx in range(self.grid_config.num_agents):
            on_goal = self.grid.on_goal(agent_idx)
            if on_goal and self.grid.is_active[agent_idx]:
                rewards.append(1.0)
            else:
                rewards.append(0.0)

            if self.grid.on_goal(agent_idx):
                self.grid.finishes_xy[agent_idx] = self._generate_new_target(agent_idx)

        for agent_idx in range(self.grid_config.num_agents):
            infos[agent_idx]['is_active'] = self.grid.is_active[agent_idx]

        obs = self._obs()

        terminated = [False] * self.grid_config.num_agents
        truncated = [False] * self.grid_config.num_agents
        return obs, rewards, terminated, truncated, infos


class PogemaCoopFinish(Pogema):
    """Cooperative Pogema: the episode ends only when ALL agents stand on goals."""

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        self.num_agents = self.grid_config.num_agents
        self.is_multiagent = True

    def _initialize_grid(self):
        self.grid: Grid = Grid(grid_config=self.grid_config)

    def step(self, action: list):
        assert len(action) == self.grid_config.num_agents

        infos = [dict() for _ in range(self.grid_config.num_agents)]

        self.move_agents(action)
        self.update_was_on_goal()

        # Shared success criterion and shared reward.
        is_task_solved = all(self.was_on_goal)
        for agent_idx in range(self.grid_config.num_agents):
            infos[agent_idx]['is_active'] = self.grid.is_active[agent_idx]

        obs = self._obs()

        terminated = [is_task_solved] * self.grid_config.num_agents
        truncated = [False] * self.grid_config.num_agents
        rewards = [1.0 if is_task_solved else 0.0 for _ in range(self.grid_config.num_agents)]
        return obs, rewards, terminated, truncated, infos


def _make_pogema(grid_config):
    """
    Builds the environment for the given config: picks the env class from
    `on_target`, applies the time limit, then either the persistence wrapper
    or the metric wrappers matching the chosen mode.
    """
    if grid_config.on_target == 'restart':
        env = PogemaLifeLong(grid_config=grid_config)
    elif grid_config.on_target == 'nothing':
        env = PogemaCoopFinish(grid_config=grid_config)
    elif grid_config.on_target == 'finish':
        env = Pogema(grid_config=grid_config)
    else:
        raise KeyError(f'Unknown on_target option: {grid_config.on_target}')

    env = MultiTimeLimit(env, grid_config.max_episode_steps)
    if env.grid_config.persistent:
        env = PersistentWrapper(env)
    else:
        # adding metrics wrappers
        if grid_config.on_target == 'restart':
            env = LifeLongAverageThroughputMetric(env)
        elif grid_config.on_target == 'nothing':
            env = NonDisappearISRMetric(env)
            env = NonDisappearCSRMetric(env)
            env = NonDisappearEpLengthMetric(env)
            env = SumOfCostsAndMakespanMetric(env)
        elif grid_config.on_target == 'finish':
            env = ISRMetric(env)
            env = CSRMetric(env)
            env = EpLengthMetric(env)
        else:
            raise KeyError(f'Unknown on_target option: {grid_config.on_target}')

    return env