├── pogema
├── wrappers
│ ├── __init__.py
│ ├── multi_time_limit.py
│ ├── persistence.py
│ └── metrics.py
├── integrations
│ ├── __init__.py
│ ├── sample_factory.py
│ ├── make_pogema.py
│ ├── pettingzoo.py
│ └── pymarl.py
├── svg_animation
│ ├── __init__.py
│ ├── svg_objects.py
│ ├── animation_wrapper.py
│ └── animation_drawer.py
├── __init__.py
├── grid_registry.py
├── utils.py
├── a_star_policy.py
├── generator.py
├── grid_config.py
├── grid.py
└── envs.py
├── local_build.sh
├── requirements.txt
├── LICENSE
├── .github
└── workflows
│ ├── CI.yml
│ └── codeql-analysis.yml
├── setup.py
├── version_history.MD
├── tests
├── test_integrations.py
├── test_deterministic_policy.py
├── test_pogema_env.py
└── test_grid.py
└── README.md
/pogema/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pogema/integrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pogema/svg_animation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/local_build.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Build and install the package locally from the current directory (no wheel cache).
pip3 install --no-cache-dir .
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>1.23.5,<=1.26.4
2 | pydantic>=1.8.2,<=1.9.1
3 | pytest>=6.2.5,<=7.1.2
4 | pettingzoo==1.22.3
5 | tabulate>=0.8.7,<=0.8.10
6 | gymnasium==0.28.1
7 |
--------------------------------------------------------------------------------
/pogema/__init__.py:
--------------------------------------------------------------------------------
# Public package API and Gymnasium registration for POGEMA.
from gymnasium import register
from pogema.grid_config import GridConfig
from pogema.integrations.make_pogema import pogema_v0
from pogema.svg_animation.animation_wrapper import AnimationMonitor, AnimationConfig
from pogema.a_star_policy import AStarAgent, BatchAStarAgent

__version__ = '1.5.0a0'

# Names exported via `from pogema import *`.
__all__ = [
    'GridConfig',
    'pogema_v0',
    'AStarAgent', 'BatchAStarAgent',
    "AnimationMonitor", "AnimationConfig",
]

# Expose the single-agent wrapper through gymnasium.make("Pogema-v0").
register(
    id="Pogema-v0",
    entry_point="pogema.integrations.make_pogema:make_single_agent_gym",
)
20 |
--------------------------------------------------------------------------------
/pogema/wrappers/multi_time_limit.py:
--------------------------------------------------------------------------------
1 | from gymnasium.wrappers import TimeLimit
2 |
3 |
class MultiTimeLimit(TimeLimit):
    """Time limit for multi-agent environments: truncation is reported per agent."""

    def step(self, action):
        """Step the wrapped env; once the step budget is exhausted, truncate every agent."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            # One truncation flag per agent, as multi-agent consumers expect.
            truncated = [True for _ in range(self.get_num_agents())]
        return observation, reward, terminated, truncated, info

    def set_elapsed_steps(self, elapsed_steps):
        """Overwrite the internal step counter (persistent environments only)."""
        if not self.grid_config.persistent:
            raise ValueError("Cannot set elapsed steps for non-persistent environment!")
        assert elapsed_steps >= 0
        self._elapsed_steps = elapsed_steps
17 |
--------------------------------------------------------------------------------
/pogema/integrations/sample_factory.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | from gymnasium import Wrapper
4 |
5 |
class IsMultiAgentWrapper(Wrapper):
    """Marks the environment as multi-agent for Sample Factory."""

    def __init__(self, env):
        super().__init__(env)
        # Sample Factory inspects this flag to select its multi-agent code path.
        self.is_multiagent = True

    @property
    def num_agents(self):
        """Number of agents under the attribute name Sample Factory expects."""
        return self.get_num_agents()
15 |
16 |
class MetricsForwardingWrapper(Wrapper):
    """Mirrors pogema 'metrics' into 'episode_extra_stats' for Sample Factory logging."""

    def step(self, action):
        observations, rewards, terminated, truncated, infos = self.env.step(action)
        for agent_info in infos:
            if 'metrics' in agent_info:
                # Deep-copied so downstream mutation cannot corrupt the originals.
                agent_info.update(episode_extra_stats=deepcopy(agent_info['metrics']))
        return observations, rewards, terminated, truncated, infos
25 |
26 |
class AutoResetWrapper(Wrapper):
    """Automatically resets the environment once every agent is done."""

    def step(self, action):
        observations, rewards, terminated, truncated, infos = self.env.step(action)
        episode_over = all(terminated) or all(truncated)
        if episode_over:
            # Return fresh observations while keeping the final step's rewards/infos.
            observations, _ = self.env.reset()
        return observations, rewards, terminated, truncated, infos
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022, Alexey Skrynnik
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches: [ main, dev]
9 | pull_request:
10 | branches: [ main, dev]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.8", "3.9"]
20 |
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install flake8 pytest
31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
32 | - name: Lint with flake8
33 | run: |
34 | # stop the build if there are Python syntax errors or undefined names
35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38 | - name: Test with pytest
39 | run: |
40 | PYTHONPATH=. pytest -s
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import codecs
import os
import re

from setuptools import setup, find_packages

# Build the PyPI long description from README.md, dropping any line that is a
# single raw HTML tag (badge/image wrappers PyPI cannot render).
cur_dir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(cur_dir, 'README.md'), 'rb') as f:
    lines = [x.decode('utf-8') for x in f.readlines()]
    lines = ''.join([re.sub('^<.*>\n$', '', x) for x in lines])
    long_description = lines
12 |
13 |
def read(*parts):
    """Return the text content of the file at ``cur_dir`` joined with *parts*."""
    path = os.path.join(cur_dir, *parts)
    with codecs.open(path, 'r') as fp:
        return fp.read()
17 |
18 |
def find_version(*file_paths):
    """Extract the ``__version__`` string from the given source file.

    Raises RuntimeError when no version assignment is found.
    """
    contents = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if not match:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
30 |
31 |
# Package metadata and build configuration.
setup(
    name='pogema',
    author='Alexey Skrynnik',
    license='MIT',
    version=find_version("pogema", "__init__.py"),
    description='Partially Observable Grid Environment for Multiple Agents',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/AIRI-Institute/pogema',
    install_requires=[
        "gymnasium==0.28.1",
        "numpy>1.23.5,<=1.26.4",
        "pydantic>=1.8.2,<=1.9.1",
    ],
    extras_require={

    },
    package_dir={'': './'},
    # Bug fix: `include` must be a sequence of patterns. A bare string is
    # splatted into single-character patterns by setuptools, and its '*'
    # character then matches every package instead of only pogema's.
    packages=find_packages(where='./', include=['pogema*']),
    include_package_data=True,
    python_requires='>=3.7',
)
54 |
--------------------------------------------------------------------------------
/pogema/svg_animation/svg_objects.py:
--------------------------------------------------------------------------------
class SvgObject:
    """Base class for renderable SVG elements.

    Subclasses set ``tag``; constructor keyword arguments become SVG
    attributes, with underscores converted to hyphens on render.
    """
    tag = None

    def __init__(self, **kwargs):
        self.attributes = kwargs
        self.animations = []

    def add_animation(self, animation):
        """Attach a child animation element rendered inside this tag."""
        self.animations.append(animation)

    @staticmethod
    def render_attributes(attributes):
        """Render attributes sorted by name, e.g. ``stroke_width=2`` -> ``stroke-width="2"``."""
        result = " ".join([f'{x.replace("_", "-")}="{y}"' for x, y in sorted(attributes.items())])
        return result

    def render(self):
        """Return the SVG markup; an open/close tag pair is used only when animations exist."""
        animations = '\n'.join([a.render() for a in self.animations]) if self.animations else None
        if animations:
            # Bug fix: the closing tag was rendered as "tag>" instead of "</tag>",
            # producing invalid SVG whenever animations were attached.
            return f"<{self.tag} {self.render_attributes(self.attributes)}> {animations} </{self.tag}>"
        return f"<{self.tag} {self.render_attributes(self.attributes)} />"
21 |
22 |
class Rectangle(SvgObject):
    """SVG ``rect`` element; y is mirrored and shifted by the height so the
    rectangle covers the same cells after the axis flip."""
    tag = 'rect'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # SVG's y axis points down; mirror so grid coordinates grow upward.
        flipped_y = -self.attributes['y'] - self.attributes['height']
        self.attributes['y'] = flipped_y
32 |
33 |
class RectangleHref(SvgObject):
    """Obstacle tile emitted as a ``use`` reference to the shared ``#obstacle`` shape."""
    tag = 'use'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flip the y axis first (the formula needs 'height'), then drop
        # 'height' so it is not emitted as an attribute of the <use> tag.
        self.attributes['y'] = -self.attributes['y'] - self.attributes['height']
        self.attributes['href'] = "#obstacle"
        del self.attributes['height']
45 |
46 |
class Circle(SvgObject):
    """SVG ``circle`` element; the centre's y coordinate is mirrored."""
    tag = 'circle'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # SVG's y axis points down; mirror so grid coordinates grow upward.
        self.attributes['cy'] = -self.attributes['cy']
56 |
57 |
class Line(SvgObject):
    """SVG ``line`` element with both endpoints mirrored on the y axis."""
    tag = 'line'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key in ('y1', 'y2'):
            self.attributes[key] = -self.attributes[key]
68 |
69 |
class Animation(SvgObject):
    """SVG ``animate`` element, always emitted as a self-closing tag."""
    tag = 'animate'

    def render(self):
        attrs = self.render_attributes(self.attributes)
        return f"<{self.tag} {attrs}/>"
78 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ main ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ main ]
20 | schedule:
21 | - cron: '21 4 * * 1'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'python' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support
38 |
39 | steps:
40 | - name: Checkout repository
41 | uses: actions/checkout@v2
42 |
43 | # Initializes the CodeQL tools for scanning.
44 | - name: Initialize CodeQL
45 | uses: github/codeql-action/init@v1
46 | with:
47 | languages: ${{ matrix.language }}
48 | # If you wish to specify custom queries, you can do so here or in a config file.
49 | # By default, queries listed here will override any specified in a config file.
50 | # Prefix the list here with "+" to use these queries and those in the config file.
51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
52 |
53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
54 | # If this step fails, then you should remove it and run the build manually (see below)
55 | - name: Autobuild
56 | uses: github/codeql-action/autobuild@v1
57 |
58 | # ℹ️ Command-line programs to run using the OS shell.
59 | # 📚 https://git.io/JvXDl
60 |
61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
62 | # and modify them (or add more) to build your code if your project
63 | # uses a compiled language
64 |
65 | #- run: |
66 | # make bootstrap
67 | # make release
68 |
69 | - name: Perform CodeQL Analysis
70 | uses: github/codeql-action/analyze@v1
71 |
--------------------------------------------------------------------------------
/pogema/integrations/make_pogema.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 |
3 | from gymnasium import Wrapper
4 |
5 | from pogema import GridConfig
6 | from pogema.envs import _make_pogema
7 | from pogema.integrations.pettingzoo import parallel_env
8 | from pogema.integrations.pymarl import PyMarlPogema
9 | from pogema.integrations.sample_factory import AutoResetWrapper, IsMultiAgentWrapper, MetricsForwardingWrapper
10 |
11 |
def _make_sample_factory_integration(grid_config):
    """Build a pogema env wrapped for Sample Factory (metrics, multi-agent flag, auto-reset)."""
    env = MetricsForwardingWrapper(_make_pogema(grid_config))
    env = IsMultiAgentWrapper(env)
    # auto_reset defaults to on for Sample Factory unless explicitly disabled.
    if grid_config.auto_reset or grid_config.auto_reset is None:
        env = AutoResetWrapper(env)
    return env
19 |
20 |
def _make_py_marl_integration(grid_config, *_, **__):
    """Build the PyMARL-compatible environment; extra positional/keyword arguments are ignored."""
    return PyMarlPogema(grid_config)
23 |
24 |
class SingleAgentWrapper(Wrapper):
    """Exposes a multi-agent pogema env through the single-agent gym API.

    The caller controls agent 0; every other agent executes a random action
    sampled from the action space.
    """

    def step(self, action):
        actions = [action] + [self.env.action_space.sample() for _ in range(self.get_num_agents() - 1)]
        observations, rewards, terminated, truncated, infos = self.env.step(actions)
        return observations[0], rewards[0], terminated[0], truncated[0], infos[0]

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None, ):
        # Bug fix: seed/options were accepted but silently dropped, making
        # seeded resets non-reproducible. Forward them to the wrapped env,
        # whose reset accepts the same keyword arguments.
        observations, infos = self.env.reset(seed=seed, options=options)
        if return_info:
            return observations[0], infos[0]
        return observations[0]
38 |
39 |
def make_single_agent_gym(grid_config: Union[GridConfig, dict] = GridConfig()):
    """Create a pogema environment usable through the single-agent gym API."""
    return SingleAgentWrapper(_make_pogema(grid_config))
45 |
46 |
def make_pogema(grid_config: Optional[Union[GridConfig, dict]] = None, *args, **kwargs):
    """Create a pogema environment for the integration selected by the config.

    Args:
        grid_config: a GridConfig, a dict of GridConfig fields, or None for
            the default configuration. (Fix: the previous signature shared one
            mutable ``GridConfig()`` instance across all calls as a default
            argument; a fresh default is now created per call.)
        *args, **kwargs: forwarded to integrations that take extra arguments
            (currently PyMARL).

    Raises:
        KeyError: for an unknown integration, or when auto_reset is requested
            for an integration other than 'SampleFactory'.
        NotImplementedError: for the 'rllib' integration.
    """
    if grid_config is None:
        grid_config = GridConfig()
    elif isinstance(grid_config, dict):
        grid_config = GridConfig(**grid_config)

    if grid_config.integration != 'SampleFactory' and grid_config.auto_reset:
        raise KeyError(f"{grid_config.integration} does not support auto_reset")

    if grid_config.integration is None:
        return _make_pogema(grid_config)
    if grid_config.integration == 'SampleFactory':
        return _make_sample_factory_integration(grid_config)
    if grid_config.integration == 'PyMARL':
        return _make_py_marl_integration(grid_config, *args, **kwargs)
    if grid_config.integration == 'rllib':
        raise NotImplementedError('Please use PettingZoo integration for rllib')
    if grid_config.integration == 'PettingZoo':
        return parallel_env(grid_config)
    if grid_config.integration == 'gymnasium':
        return make_single_agent_gym(grid_config)

    raise KeyError(grid_config.integration)


# Versioned public alias.
pogema_v0 = make_pogema
71 |
--------------------------------------------------------------------------------
/pogema/integrations/pettingzoo.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | from pogema import GridConfig
5 | from pogema.envs import _make_pogema
6 |
7 |
def parallel_env(grid_config: GridConfig = GridConfig()):
    """Create a PettingZoo parallel-API style wrapper around pogema."""
    return PogemaParallel(grid_config)
10 |
11 |
class PogemaParallel:
    """PettingZoo parallel-API style adapter around a pogema environment.

    Agents are named "player_0" ... "player_{n-1}"; ``agent_name_mapping``
    translates those names back to pogema's integer agent indices.
    """

    def state(self):
        """Return the global state of the underlying pogema environment."""
        return self.pogema.get_state()

    def __init__(self, grid_config: GridConfig, render_mode='ansi'):
        self.metadata = {'render_modes': ['ansi'], "name": "pogema"}
        self.render_mode = render_mode
        self.pogema = _make_pogema(grid_config)
        # Stable names for all agents; `self.agents` tracks the alive subset.
        self.possible_agents = ["player_" + str(r) for r in range(self.pogema.get_num_agents())]
        self.agent_name_mapping = dict(zip(self.possible_agents, list(range(len(self.possible_agents)))))
        # Populated on reset().
        self.agents = None
        self.num_moves = None

    # NOTE(review): lru_cache on an instance method keeps `self` alive via the
    # cache for the cache's lifetime; acceptable for long-lived envs.
    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        assert agent in self.possible_agents
        # All agents share the same observation space.
        return self.pogema.observation_space

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        assert agent in self.possible_agents
        # All agents share the same action space.
        return self.pogema.action_space

    def render(self, mode="human"):
        # NOTE(review): only 'human' is accepted here although metadata
        # advertises 'ansi' as the render mode — confirm which is intended.
        assert mode == 'human'
        return self.pogema.render()

    def reset(self, seed=None, options=None):
        """Reset the episode; returns observations keyed by agent name (float32)."""
        observations, info = self.pogema.reset(seed=seed, options=options)
        self.agents = self.possible_agents[:]
        self.num_moves = 0
        observations = {agent: observations[self.agent_name_mapping[agent]].astype(np.float32) for agent in self.agents}
        return observations

    def step(self, actions):
        """Step all agents; agents missing from `actions` act with action 0.

        Returns per-agent dicts of observations, rewards, terminated,
        truncated and infos, and removes finished agents from ``self.agents``.
        """
        anm = self.agent_name_mapping

        # Build the full action list indexed by agent id; absent agents get 0.
        actions = [actions[agent] if agent in actions else 0 for agent in self.possible_agents]
        observations, rewards, terminated, truncated, infos = self.pogema.step(actions)
        d_observations = {agent: observations[anm[agent]].astype(np.float32) for agent in
                          self.agents}
        d_rewards = {agent: rewards[anm[agent]] for agent in self.agents}
        d_terminated = {agent: terminated[anm[agent]] for agent in self.agents}
        d_truncated = {agent: truncated[anm[agent]] for agent in self.agents}
        d_infos = {agent: infos[anm[agent]] for agent in self.agents}

        # Drop agents that became inactive, or everyone once the episode ends.
        for agent, idx in anm.items():
            if (not self.pogema.grid.is_active[idx] or all(truncated) or all(terminated)) and agent in self.agents:
                self.agents.remove(agent)

        return d_observations, d_rewards, d_terminated, d_truncated, d_infos

    @property
    def unwrapped(self):
        return self

    def close(self):
        pass
71 |
--------------------------------------------------------------------------------
/pogema/integrations/pymarl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pogema import GridConfig
4 | from pogema.envs import _make_pogema
5 |
6 |
class PyMarlPogema:
    """Adapter exposing a pogema environment through the PyMARL interface."""

    def __init__(self, grid_config, mh_distance=False):
        self._grid_config: GridConfig = grid_config
        self.env = _make_pogema(grid_config)
        self._mh_distance = mh_distance
        self._observations, _ = self.env.reset()
        self.max_episode_steps = grid_config.max_episode_steps
        self.episode_limit = grid_config.max_episode_steps
        self.n_agents = self.env.get_num_agents()
        self.spec = None

    @property
    def unwrapped(self):
        return self

    def step(self, actions):
        """Advance one step; returns (team_reward, done, info) per PyMARL convention."""
        self._observations, rewards, terminated, truncated, infos = self.env.step(actions)
        done = all(terminated) or all(truncated)
        # Metrics are reported once, when the episode finishes.
        info = dict(infos[0]['metrics']) if done else {}
        return sum(rewards), done, info

    def get_obs(self):
        """Flattened observation for every agent."""
        return np.array([self.get_obs_agent(idx) for idx in range(self.n_agents)])

    def get_obs_agent(self, agent_id):
        return np.array(self._observations[agent_id]).flatten()

    def get_obs_size(self):
        return np.array(self._observations[0]).flatten().shape[0]

    def get_state(self):
        return self.env.get_state()

    def get_state_size(self):
        return len(self.get_state())

    def get_avail_actions(self):
        return [self.get_avail_agent_actions(idx) for idx in range(self.env.get_num_agents())]

    # noinspection PyUnusedLocal
    @staticmethod
    def get_avail_agent_actions(agent_id):
        # All five actions are always available to every agent.
        return list(range(5))

    @staticmethod
    def get_total_actions():
        return 5

    def reset(self):
        """Reset the episode and return the flattened joint observation."""
        self._grid_config = self.env.grid_config
        self._observations, _ = self.env.reset()
        return np.array(self._observations).flatten()

    def save_replay(self):
        return

    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)

    def get_env_info(self):
        """Environment description dict in the shape PyMARL expects."""
        return {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
        }

    @staticmethod
    def get_stats():
        return {}

    def close(self):
        return

    def sample_actions(self):
        return self.env.sample_actions()
95 |
--------------------------------------------------------------------------------
/pogema/grid_registry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pogema.utils import check_grid, render_grid
4 |
5 | GRID_STR_REGISTRY = {}
6 |
7 |
def in_registry(name):
    """Return True when a grid with the given name has been registered."""
    return name in GRID_STR_REGISTRY
10 |
11 |
def get_grid(name):
    """Look up a registered grid by name; raises KeyError when absent."""
    if not in_registry(name):
        raise KeyError(f"Grid with name {name} not found")
    return GRID_STR_REGISTRY[name]
17 |
18 |
class RegisteredGrid:
    """A named grid parsed from a map string and stored in the global registry.

    Map symbols: '.' free cell, '#' obstacle, 'a'-'z' agent start positions,
    'A'-'Z' agent targets (matched to agents by letter, case-insensitively).
    """
    FREE = 0
    OBSTACLE = 1

    def str_to_grid(self, grid_str):
        """Parse a map string into (obstacles, agents, targets).

        Returns the obstacle matrix as a list of rows of FREE/OBSTACLE values,
        plus dicts mapping lowercase letters to (row, col) coordinates.
        """
        obstacles = []
        agents = {}
        targets = {}
        for idx, line in enumerate(grid_str.split()):
            row = []
            for char in line:
                if char == '.':
                    cell = self.FREE
                elif char == '#':
                    cell = self.OBSTACLE
                elif 'A' <= char <= 'Z':
                    targets[char.lower()] = len(obstacles), len(row)
                    cell = self.FREE
                elif 'a' <= char <= 'z':
                    agents[char.lower()] = len(obstacles), len(row)
                    cell = self.FREE
                else:
                    raise KeyError(f"Unsupported symbol '{char}' at line {idx}")
                row.append(cell)
            if row:
                if obstacles:
                    # Every row must have the same width.
                    assert len(obstacles[-1]) == len(row), f"Wrong string size for row {idx};"
                obstacles.append(row)
        return obstacles, agents, targets

    def __init__(self, name: str, grid_str: str = None, agents_positions: list = None, agents_targets: list = None):
        self.name = name
        self.grid_str = grid_str
        self.agents_positions = agents_positions
        self.agents_targets = agents_targets

        self.obstacles, agents, targets = self.str_to_grid(grid_str)
        self.obstacles = np.array(self.obstacles, dtype=np.int32)

        # Positions may come from the map string OR the arguments — not both.
        if agents_positions and agents:
            raise ValueError("Agents positions are already defined in the grid string!")
        if agents_targets and targets:
            raise ValueError("Agents targets are already defined in the grid string!")

        # Letters from the map take precedence; sort for a deterministic order.
        self.agents_xy = [list(xy) for _, xy in sorted(agents.items())] if agents else agents_positions
        self.targets_xy = [list(xy) for _, xy in sorted(targets.items())] if targets else agents_targets

        if in_registry(name):
            raise ValueError(f"Grid with name {self.name} already registered!")
        check_grid(self.obstacles, self.agents_xy, self.targets_xy)

        register_grid(self)

    def get_obstacles(self):
        return self.obstacles

    def get_agents_xy(self):
        return self.agents_xy

    def get_targets_xy(self):
        return self.targets_xy

    def render(self):
        """Print the grid with agents and targets to stdout."""
        render_grid(obstacles=self.get_obstacles(), positions_xy=self.get_agents_xy(), targets_xy=self.get_targets_xy())
92 |
93 |
def register_grid(rg: RegisteredGrid):
    """Add a grid to the global registry; duplicate names are rejected."""
    if in_registry(rg.name):
        raise KeyError(f"Grid with name {rg.name} already registered")
    GRID_STR_REGISTRY[rg.name] = rg
98 |
--------------------------------------------------------------------------------
/pogema/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from pydantic import BaseModel
4 |
5 | from typing_extensions import Literal
6 |
7 |
class AgentsTargetsSizeError(Exception):
    """Raised when agents and targets are not both provided or both omitted."""
10 |
11 |
def grid_to_str(grid):
    """Render a 2D grid of cell values as text: '.' for free (0), '#' otherwise."""
    rows = []
    for row in grid:
        rows.append(''.join('.' if cell == 0 else '#' for cell in row))
    return '\n'.join(rows)
14 |
15 |
def check_grid(obstacles, agents_xy, targets_xy):
    """Validate agent/target placement against the obstacle map.

    Raises:
        AgentsTargetsSizeError: when only one of agents/targets is given.
        IndexError: when the two lists differ in length.
        ValueError: when two agents share a start cell.
        KeyError: when a start or target cell lies on an obstacle.
    """
    if bool(agents_xy) != bool(targets_xy):
        raise AgentsTargetsSizeError("Agents and targets must be defined together/undefined together!")

    # Nothing to validate when positions will be generated procedurally.
    if not agents_xy or not targets_xy:
        return

    if len(agents_xy) != len(targets_xy):
        raise IndexError("Can't create task. Please provide agents_xy and targets_xy of the same size.")

    # No two agents may start in the same cell.
    for i, first in enumerate(agents_xy):
        for j in range(i + 1, len(agents_xy)):
            if first == agents_xy[j]:
                raise ValueError(f"Agents can't overlap! {first} is in both {i} and {j} position.")

    # Starts and targets must both lie on free cells.
    for start_xy, finish_xy in zip(agents_xy, targets_xy):
        for x, y in (start_xy, finish_xy):
            if obstacles[x, y]:
                raise KeyError(f'Cell is {x, y} occupied by obstacle.')

    # todo check connectivity of starts and finishes
41 |
42 |
def render_grid(obstacles, positions_xy=None, targets_xy=None, is_active=None, mode='human'):
    """Render the grid as colored text.

    Agents are drawn in red using characters from digits/letters/punctuation
    (cycled by agent id), targets as ``|c|`` in white, free cells as dots and
    obstacles as highlighted blanks.

    Args:
        obstacles: 2D iterable of FREE/OBSTACLE cell values.
        positions_xy: list of (row, col) agent positions, or None.
        targets_xy: list of (row, col) target positions, or None.
        is_active: per-agent visibility flags; defaults to all-active.
        mode: 'human' writes to stdout and returns None; any other mode
            returns the rendered text as a string.
    """
    if positions_xy is None:
        positions_xy = []
    if targets_xy is None:
        targets_xy = []
    if is_active is None:
        if positions_xy:
            is_active = [True] * len(positions_xy)
        else:
            is_active = []
    from io import StringIO
    import string
    from gymnasium import utils as gym_utils
    from contextlib import closing

    # NOTE(review): any mode other than 'ansi' writes to sys.stdout, and a
    # mode that is neither 'human' nor 'ansi' would close sys.stdout below —
    # confirm only 'human' and 'ansi' are ever passed.
    outfile = StringIO() if mode == 'ansi' else sys.stdout
    chars = string.digits + string.ascii_letters + string.punctuation
    # Only active agents (and their targets) are drawn.
    positions_map = {(x, y): id_ for id_, (x, y) in enumerate(positions_xy) if is_active[id_]}
    finishes_map = {(x, y): id_ for id_, (x, y) in enumerate(targets_xy) if is_active[id_]}
    for line_index, line in enumerate(obstacles):
        out = ''
        for cell_index, cell in enumerate(line):
            if cell == CommonSettings().FREE:
                agent_id = positions_map.get((line_index, cell_index), None)
                finish_id = finishes_map.get((line_index, cell_index), None)

                if agent_id is not None:
                    out += str(gym_utils.colorize(' ' + chars[agent_id % len(chars)] + ' ', color='red', bold=True,
                                                  highlight=False))
                elif finish_id is not None:
                    out += str(
                        gym_utils.colorize('|' + chars[finish_id % len(chars)] + '|', 'white', highlight=False))
                else:
                    out += str(gym_utils.colorize(str(' . '), 'white', highlight=False))
            else:
                out += str(gym_utils.colorize(str('   '), 'cyan', bold=False, highlight=True))
        out += '\n'
        outfile.write(out)

    if mode != 'human':
        with closing(outfile):
            return outfile.getvalue()
85 |
86 |
class CommonSettings(BaseModel):
    """Shared constants for the grid representation and movement."""
    # Movement offsets (row, col); index 0 keeps the agent in place, followed
    # by the four axis-aligned moves — presumably indexed by action id
    # (confirm against the env's action handling).
    MOVES: list = [[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1], ]
    # Cell value for a traversable cell.
    FREE: Literal[0] = 0
    # Cell value for a blocked cell.
    OBSTACLE: Literal[1] = 1
    # Whether the area outside the map borders is treated as free space —
    # TODO confirm; only declared here, consumers are outside this file.
    empty_outside: bool = True
92 |
--------------------------------------------------------------------------------
/pogema/wrappers/persistence.py:
--------------------------------------------------------------------------------
1 | from gymnasium import Wrapper
2 |
3 |
class AgentState:
    """Snapshot of one agent: position, target, recording step and active flag."""

    def __init__(self, x, y, tx, ty, step, active):
        self.x = x
        self.y = y
        self.tx = tx
        self.ty = ty
        self.step = step
        self.active = active

    def get_xy(self):
        return self.x, self.y

    def get_target_xy(self):
        return self.tx, self.ty

    def is_active(self):
        return self.active

    def get_step(self):
        return self.step

    def __eq__(self, other):
        # `step` is deliberately excluded: snapshots that differ only by time
        # compare equal, so histories can store state *changes* only.
        return (self.x, self.y, self.tx, self.ty, self.active) == \
               (other.x, other.y, other.tx, other.ty, other.active)

    def __str__(self):
        return str([self.x, self.y, self.tx, self.ty, self.step, self.active])
31 |
32 |
33 | class PersistentWrapper(Wrapper):
34 | def __init__(self, env, xy_offset=None):
35 | super().__init__(env)
36 | self._step = None
37 | self._agent_states = None
38 | self._xy_offset = xy_offset
39 |
40 | def step(self, action):
41 | result = self.env.step(action)
42 | self._step += 1
43 | for agent_idx in range(self.get_num_agents()):
44 | agent_state = self._get_agent_state(self.grid, agent_idx)
45 | if agent_state != self._agent_states[agent_idx][-1]:
46 | self._agent_states[agent_idx].append(agent_state)
47 |
48 | return result
49 |
50 | def step_back(self):
51 | if self._step <= 0:
52 | return False
53 | self._step -= 1
54 | self.set_elapsed_steps(self._step)
55 | for idx in reversed(range(self.get_num_agents())):
56 |
57 | if self._step < self._agent_states[idx][-1].step:
58 | self._agent_states[idx].pop()
59 | state = self._agent_states[idx][-1]
60 |
61 | if state.active:
62 | self.grid.show_agent(idx)
63 | else:
64 | self.grid.hide_agent(idx)
65 | self.grid.move_agent_to_cell(idx, state.x, state.y)
66 | self.grid.finishes_xy[idx] = state.tx, state.ty
67 |
68 | return True
69 |
70 | def _get_agent_state(self, grid, agent_idx):
71 | x, y = grid.positions_xy[agent_idx]
72 | tx, ty = grid.finishes_xy[agent_idx]
73 | active = grid.is_active[agent_idx]
74 | if self._xy_offset:
75 | x += self._xy_offset
76 | y += self._xy_offset
77 | tx += self._xy_offset
78 | ty += self._xy_offset
79 | return AgentState(x, y, tx, ty, self._step, active)
80 |
81 | def reset(self, **kwargs):
82 | result = self.env.reset(**kwargs)
83 |
84 | self._step = 0
85 |
86 | self._agent_states = []
87 | for agent_idx in range(self.get_num_agents()):
88 | self._agent_states.append([self._get_agent_state(self.grid, agent_idx)])
89 |
90 | return result
91 |
92 | @staticmethod
93 | def agent_state_to_full_list(agent_states, num_steps):
94 | result = []
95 | current_state_id = 0
96 | # going over num_steps + 1, to handle last step
97 | for episode_step in range(num_steps + 1):
98 | if current_state_id < len(agent_states) - 1 and agent_states[current_state_id + 1].step == episode_step:
99 | current_state_id += 1
100 | result.append(agent_states[current_state_id])
101 | return result
102 |
103 | @classmethod
104 | def decompress_history(cls, history):
105 | max_steps = max([agent_states[-1].step for agent_states in history])
106 | result = [cls.agent_state_to_full_list(agent_states, max_steps) for agent_states in history]
107 | return result
108 |
109 | def get_full_history(self):
110 | return [self.agent_state_to_full_list(agent_states, self._step) for agent_states in self._agent_states]
111 |
    def get_history(self):
        # Compressed per-agent history: one list of AgentState per agent,
        # where a state is appended only when it differs from the previous one.
        return self._agent_states
114 |
--------------------------------------------------------------------------------
/version_history.MD:
--------------------------------------------------------------------------------
1 | The development history of POGEMA, starting from version 1.0.0.
2 |
3 | Version 1.5.0 (August, 2025)
4 | • Added support of custom targets_xy for lifelong MAPF.
5 | • Improved work with rectangular grids. Added width and height attributes to GridConfig.
6 | • Added method update_config to properly update all attributes of GridConfig.
7 | • Added more tests for new features - custom targets_xy and width/height attributes.
8 |
9 | Version 1.4.0 (April 5, 2025)
10 | • Extended limits for size of maps and number of agents.
11 | • Fixed ep_length value.
12 | • Updated some tests.
13 |
14 | Version 1.3.0 (June 13, 2024)
15 |
16 | • Updates for integration with newer version of gymnasium.
17 | • Refactored AgentsDensityWrapper for modularity and clarity.
18 | • Introduced RuntimeMetricWrapper for runtime monitoring.
19 | • Enhanced map generation methods and added new metrics like SOC_Makespan.
20 | • Animation improvements for better visualization.
21 |
22 | Version 1.2.2 (September 22, 2023)
23 |
24 | • Implemented soft collision handling for agent interactions.
25 | • Improved lifelong scenario seeding for consistent agent behavior.
26 | • Enhanced metric logging for better integration with PyMARL.
27 |
28 | Version 1.2.0 (August 30, 2023)
29 |
30 | • Fixed import issues with Literal and animation issues.
31 | • Improved visualizations, including grid lines and border toggles.
32 |
33 | Version 1.1.0 (March 30, 2023)
34 |
35 | • Updated dependencies for gymnasium and PettingZoo.
36 | • Added an option to remove animation borders for cleaner outputs.
37 | • Fixed animation bugs for stuck agents.
38 |
39 | Version 1.0.0 (February 2023)
40 |
• Launched core features, including A* policy implementations and CI/CD support.
42 | • Introduced basic visualization and fixed animation bugs.
43 |
44 | Post-Version Updates
45 |
46 | • Adjusted the number of agents in setups.
47 | • Updated package metadata for better compatibility.
48 | • Addressed legacy issues and improved benchmark generation.
49 |
50 | Version 1.1.6 (February 21, 2023)
51 |
52 | • Fixed static animation issues and added grid object rendering.
53 |
54 | Version 1.1.5 (December 28, 2022)
55 |
56 | • Fixed Python 3.7 compatibility issues and added map registries for better management.
57 | • Introduced an attrition metric.
58 |
59 | Version 1.1.4 (November 18, 2022)
60 |
61 | • Fixed flake8 warnings for improved code quality.
62 |
63 | Version 1.1.3 (October 28, 2022)
64 |
65 | • Corrected random seed initialization for PogemaLifeLong.
66 | • Optimized animation behavior.
67 |
68 | Version 1.1.2 (October 5, 2022)
69 |
70 | • Upgraded SVG animations for better compression.
71 |
72 | Version 1.1.1 (August 30, 2022)
73 |
74 | • Added map_name attributes for clearer references.
75 | • Implemented new observation types (MAPF, POMAPF) and enhanced metrics aggregation.
76 |
77 | Version 1.0.x and Earlier
78 |
79 | • Introduced cooperative reward wrappers and lifelong environment versions.
80 | • Dropped Python 3.6 support and refined animation handling.
81 |
82 | Version 1.0.3 (June 29, 2022)
83 |
84 | • Fixed rendering issues for inactive agents.
85 |
86 | Version 1.0.2 (June 27, 2022)
87 |
88 | • Enhanced customization for agent and target positions.
89 |
90 | Pre-1.0.2 Development (June 2022)
91 |
92 | • Improved tests, refactored code, and removed unnecessary dependencies.
93 | • Introduced the PogemaLifeLong class with target generation and metrics tailored for lifelong scenarios.
94 | • Introduced customizable map rules and agent/target positions.
95 | • Simplified installation by removing unnecessary dependencies.
96 |
97 |
98 | Version 1.0.0 (March 31, 2022)
99 |
100 | • Added predefined configurations for grid environments and improved visualization.
101 | • Integrated PettingZoo support and enhanced usability with better examples.
102 | • Introduced grid_config class for environment configuration and improved state management.
103 | • Added methods for relative position observations and fixed PettingZoo compatibility.
104 | • Documentation improvements for better user guidance.
105 |
106 |
--------------------------------------------------------------------------------
/pogema/a_star_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pogema import GridConfig
3 |
4 | from heapq import heappop, heappush
5 |
6 | INF = 1e7
7 |
8 |
class GridMemory:
    """Unbounded boolean obstacle map centred on world coordinate (0, 0).

    Stored as a square array of radius ``start_r``; the array grows
    (doubling the radius) whenever an update does not fit.
    """

    def __init__(self, start_r=64):
        self._memory = np.zeros(shape=(start_r * 2 + 1, start_r * 2 + 1), dtype=bool)

    @staticmethod
    def _try_to_insert(x, y, source, target):
        """Copy the square patch ``source`` into ``target`` centred at (x, y).

        Returns False when the patch does not fit inside ``target``.
        """
        r = source.shape[0] // 2
        try:
            target[x - r:x + r + 1, y - r:y + r + 1] = source
            return True
        except ValueError:
            return False

    def _increase_memory(self):
        # Double the radius, keeping the old content centred.
        m = self._memory
        r = self._memory.shape[0]
        # Bug fix: keep the boolean dtype. Previously the grown array defaulted
        # to float64, silently changing the type of values returned by
        # is_obstacle after the first growth.
        self._memory = np.zeros(shape=(r * 2 + 1, r * 2 + 1), dtype=bool)
        assert self._try_to_insert(r, r, m, self._memory)

    def update(self, x, y, obstacles):
        """Write the local ``obstacles`` patch centred at world coords (x, y),
        growing the memory until the patch fits."""
        while True:
            r = self._memory.shape[0] // 2
            if self._try_to_insert(r + x, r + y, obstacles, self._memory):
                break
            self._increase_memory()

    def is_obstacle(self, x, y):
        """Return True if cell (x, y) is a known obstacle; unknown cells are free."""
        r = self._memory.shape[0] // 2
        if -r <= x <= r and -r <= y <= r:
            return self._memory[r + x, r + y]
        else:
            return False
41 |
42 |
class Node:
    """A* search node: position (i, j), cost-from-start g, heuristic h, f = g + h."""

    def __init__(self, coord: (int, int) = (INF, INF), g: int = 0, h: int = 0):
        self.i, self.j = coord
        self.g = g
        self.h = h
        self.f = g + h

    def __lt__(self, other):
        """Order by f, break ties by g, then lexicographically by coordinates.

        Bug fix: the coordinate tie-break used to be
        ``self.i < other.i or self.j < other.j``, which is not antisymmetric
        (both a < b and b < a can hold); the tuple comparison is a total order.
        """
        if self.f != other.f:
            return self.f < other.f
        if self.g != other.g:
            return self.g < other.g
        return (self.i, self.j) < (other.i, other.j)
57 |
58 |
def h(node, target):
    """Manhattan distance between ``node`` and ``target`` — the admissible
    heuristic for a 4-connected grid."""
    (node_x, node_y), (target_x, target_y) = node, target
    return abs(node_x - target_x) + abs(node_y - target_y)
63 |
64 |
def a_star(start, target, grid: GridMemory, max_steps=10000):
    """A* search on a 4-connected grid.

    :param start: (x, y) start cell
    :param target: (x, y) goal cell
    :param grid: obstacle map providing ``is_obstacle(x, y)``
    :param max_steps: cap on the number of node expansions
    :return: list of cells from start to target (inclusive); empty list when
        the target was not reached within the budget
    """
    open_ = list()
    # Maps each discovered cell to its parent cell (None for the start).
    closed = {start: None}

    heappush(open_, Node(start, 0, h(start, target)))

    for _ in range(int(max_steps)):
        u = heappop(open_)

        for n in [(u.i - 1, u.j), (u.i + 1, u.j), (u.i, u.j - 1), (u.i, u.j + 1)]:
            if not grid.is_obstacle(*n) and n not in closed:
                heappush(open_, Node(n, u.g + 1, h(n, target)))
                closed[n] = (u.i, u.j)

        # Stop once the goal has been expanded or the frontier is exhausted.
        # (The previous ``step >= max_steps`` check was unreachable inside
        # ``range(max_steps)`` and has been removed.)
        if (u.i, u.j) == target or len(open_) == 0:
            break

    # Walk parent links back from the target, if it was reached.
    next_node = target if target in closed else None
    path = []
    while next_node is not None:
        path.append(next_node)
        next_node = closed[next_node]

    return list(reversed(path))
89 |
90 |
class AStarAgent:
    """Single-agent A* policy.

    Accumulates observed obstacles in a persistent GridMemory, replans a
    shortest known path on every step and follows its first move. Call
    :meth:`clear_state` between episodes.
    """

    def __init__(self, seed=0):
        # MOVES[i] is the (dx, dy) offset of action i; the reverse map turns
        # an observed offset back into an action index.
        self._moves = GridConfig().MOVES
        self._reverse_actions = {tuple(self._moves[i]): i for i in range(len(self._moves))}

        self._gm = None  # GridMemory accumulating obstacles seen so far
        self._saved_xy = None  # agent position at the previous act() call
        self.clear_state()
        self._rnd = np.random.default_rng(seed)

    def act(self, obs):
        """Return the next action for observation ``obs``.

        ``obs`` must provide 'xy', 'target_xy', 'obstacles' and 'agents'.
        Raises IndexError when the position jumped by more than one cell,
        which indicates clear_state() was not called before a new episode.
        """
        xy, target_xy, obstacles, agents = obs['xy'], obs['target_xy'], obs['obstacles'], obs['agents']


        if self._saved_xy is not None and h(self._saved_xy, xy) > 1:
            raise IndexError("Agent moved more than 1 step. Please, call clear_state method before new episode.")
        # Stuck (last action had no effect) and not at the goal: take a random
        # action to try to break a deadlock with other agents.
        if self._saved_xy is not None and h(self._saved_xy, xy) == 0 and xy != target_xy:
            return self._rnd.integers(len(self._moves))
        self._gm.update(*xy, obstacles)
        path = a_star(xy, target_xy, self._gm, )
        if len(path) <= 1:
            action = 0  # wait: already at the target, or no known path
        else:
            (x, y), (tx, ty), *_ = path
            action = self._reverse_actions[tx - x, ty - y]

        self._saved_xy = xy
        return action

    def clear_state(self):
        """Forget the previous position and the accumulated obstacle memory."""
        self._saved_xy = None
        self._gm = GridMemory()
123 |
124 |
class BatchAStarAgent:
    """Convenience wrapper running an independent AStarAgent per agent index."""

    def __init__(self):
        self.astar_agents = {}

    def act(self, observations):
        """Return one action per observation, creating planners lazily by index."""
        actions = []
        for agent_idx, agent_obs in enumerate(observations):
            agent = self.astar_agents.get(agent_idx)
            if agent is None:
                agent = AStarAgent()
                self.astar_agents[agent_idx] = agent
            actions.append(agent.act(agent_obs))
        return actions

    def reset_states(self):
        """Drop all per-agent planners (and their obstacle memories)."""
        self.astar_agents = {}
139 |
--------------------------------------------------------------------------------
/tests/test_integrations.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import numpy as np
4 |
5 | from pogema import GridConfig
6 | from pogema.integrations.make_pogema import pogema_v0
7 |
8 |
def test_gym_creation():
    """Pogema-v0 must be registered with gymnasium and reset without errors."""
    import gymnasium

    env = gymnasium.make("Pogema-v0", grid_config=GridConfig(integration='gymnasium'))
    env.reset()
14 |
15 |
def test_integrations():
    """Smoke-test environment creation and reset for every supported integration."""
    for integration in ['SampleFactory', 'PyMARL', 'gymnasium', "PettingZoo", None]:
        env = pogema_v0(grid_config=GridConfig(integration=integration))
        env.reset()
20 |
21 |
def test_sample_factory_integration():
    """Check the SampleFactory API surface and the auto-reset wrapper."""
    env = pogema_v0(GridConfig(seed=7, num_agents=4, size=12, integration='SampleFactory'))
    env.reset()

    assert env.num_agents == 4
    assert env.is_multiagent is True

    # Run two episodes back to back; the auto-reset wrapper must restart the
    # environment after the first one finishes.
    # (Removed the unused ``dones`` local from the original.)
    for _ in range(2):
        infos = None
        while True:
            _, _, terminated, truncated, infos = env.step(env.sample_actions())
            if all(terminated) or all(truncated):
                break

    # Random actions are not expected to solve this seed within the horizon.
    assert np.isclose(infos[0]['episode_extra_stats']['ISR'], 0.0)
    assert np.isclose(infos[0]['episode_extra_stats']['CSR'], 0.0)
40 |
41 |
def test_pymarl_integration():
    """Check the PyMARL API: global state, env info, obs shapes and stepping."""
    gc = GridConfig(seed=7, num_agents=4, obs_radius=3, max_episode_steps=16, integration='PyMARL')
    env = pogema_v0(gc)

    # Reference global state for this seed (regression check).
    _state = [0.14285714285714285, 1.0, 1.0, 0.5714285714285714, 0.42857142857142855, 0.7142857142857143,
              0.8571428571428571, 0.2857142857142857, 0.8571428571428571, 0.42857142857142855, 0.42857142857142855, 1.0,
              0.5714285714285714, 0.7142857142857143, 0.14285714285714285, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0, 0.0,
              0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
              0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0,
              1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
              0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
              0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0,
              0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0,
              0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
              0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
              0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
              0.0, 0.0]
    assert np.isclose(_state, env.get_state()).all()

    assert env.episode_limit == 16
    assert env.get_env_info()['state_shape'] == 212
    assert env.get_env_info()['obs_shape'] == 147
    assert env.get_env_info()['n_agents'] == 4
    assert env.get_env_info()['episode_limit'] == 16

    # Per-agent observation is flattened; its size must agree with the
    # underlying observation space.
    num_agents, dimension = env.get_obs().shape
    assert num_agents == gc.num_agents
    assert dimension == reduce(lambda a, b: a * b, env.env.observation_space.shape)
    assert dimension == env.get_obs_size()
    assert env.get_state_size() == env.get_state().shape[0]

    # The episode must terminate within the configured horizon.
    done = False
    cnt = 0
    while not done:
        assert cnt < gc.max_episode_steps
        _, done, _ = env.step(env.sample_actions())
        cnt += 1
79 |
80 |
def test_single_agent_gym_integration():
    """Single-agent mode must follow the gymnasium API (scalar reward, 5-tuple step)."""
    gc = GridConfig(seed=7, num_agents=1, integration='gymnasium')
    env = pogema_v0(gc)

    obs, info = env.reset()

    assert obs.shape == env.observation_space.shape

    # Removed the vestigial ``done`` local: it was initialized once, never
    # updated, so ``assert isinstance(done, bool)`` only ever checked the
    # stale initial value.
    cnt = 0
    while True:
        assert cnt < gc.max_episode_steps
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        assert obs.shape == env.observation_space.shape
        assert isinstance(reward, float)
        assert isinstance(info, dict)
        cnt += 1
        if terminated or truncated:
            break
101 |
102 |
def test_petting_zoo():
    """Run the official PettingZoo parallel and AEC API compliance tests."""
    from pettingzoo.test import api_test, parallel_api_test, render_test
    # from pettingzoo.test import render_test

    gc = GridConfig(num_agents=16, size=16, integration='PettingZoo')

    parallel_api_test(pogema_v0(gc), num_cycles=1000)

    # The AEC test requires parallel_to_aec, which may be missing in some
    # PettingZoo versions — skip silently in that case.
    try:
        from pettingzoo.utils import parallel_to_aec

        def env(grid_config: GridConfig = GridConfig(num_agents=20, size=16)):
            return parallel_to_aec(pogema_v0(grid_config))

        api_test(env(gc), num_cycles=1000, verbose_progress=True)
        # todo fix this
        # render_test(lambda: pogema_v0(gc))
    except ImportError:
        pass
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | [](https://github.com/Cognitive-AI-Systems/pogema)
5 |
6 | **Partially-Observable Grid Environment for Multiple Agents**
7 |
8 | [](https://www.codefactor.io/repository/github/tviskaron/pogema)
9 | [](https://pepy.tech/project/pogema)
10 | [](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/CI.yml)
11 | [](https://github.com/Cognitive-AI-Systems/pogema/actions/workflows/codeql-analysis.yml)
12 |
13 |
14 |
15 | Partially Observable Multi-Agent Pathfinding (PO-MAPF) is a challenging problem that fundamentally differs from regular MAPF. In regular MAPF, a central controller constructs a joint plan for all agents before they start execution. However, PO-MAPF is intrinsically decentralized, and decision-making, such as planning, is interleaved with execution. At each time step, an agent receives a local observation of the environment and decides which action to take. The ultimate goal for the agents is to reach their goals while avoiding collisions with each other and the static obstacles.
16 |
17 | POGEMA stands for Partially-Observable Grid Environment for Multiple Agents. It is a grid-based environment that was specifically designed to be flexible, tunable, and scalable. It can be tailored to a variety of PO-MAPF settings. Currently, the somewhat standard setting is supported, in which agents can move between the cardinal-adjacent cells of the grid, and each action (move or wait) takes one time step. No information sharing occurs between the agents. POGEMA can generate random maps and start/goal locations for the agents. It also accepts custom maps as input.
18 |
19 | ## Installation
20 |
21 | Just install from PyPI:
22 |
23 | ```pip install pogema```
24 |
25 | ## Using Example
26 |
27 | ```python
28 | from pogema import pogema_v0, GridConfig
29 |
30 | env = pogema_v0(grid_config=GridConfig())
31 |
32 | obs, info = env.reset()
33 |
34 | while True:
35 | # Using random policy to make actions
36 | obs, reward, terminated, truncated, info = env.step(env.sample_actions())
37 | env.render()
38 | if all(terminated) or all(truncated):
39 | break
40 | ```
41 |
42 | [](https://colab.research.google.com/drive/19dSEGTQeM3oVJtVjpC162t1XApmv6APc?usp=sharing)
43 |
44 |
45 | ## Baselines and Evaluation Protocol
46 | The baseline implementations and evaluation pipeline are presented in [POGEMA Benchmark](https://github.com/Cognitive-AI-Systems/pogema-benchmark) repository.
47 |
48 | ## Interfaces
49 | Pogema provides integrations with a range of MARL frameworks: PettingZoo, PyMARL and SampleFactory.
50 |
51 | ### PettingZoo
52 |
53 | ```python
54 | from pogema import pogema_v0, GridConfig
55 |
56 | # Create Pogema environment with PettingZoo interface
57 | env = pogema_v0(GridConfig(integration="PettingZoo"))
58 | ```
59 |
60 | ### PyMARL
61 |
62 | ```python
63 | from pogema import pogema_v0, GridConfig
64 |
65 | env = pogema_v0(GridConfig(integration="PyMARL"))
66 | ```
67 |
68 | ### SampleFactory
69 |
70 | ```python
71 | from pogema import pogema_v0, GridConfig
72 |
73 | env = pogema_v0(GridConfig(integration="SampleFactory"))
74 | ```
75 |
76 | ### Gymnasium
77 |
Pogema is fully capable of handling single-agent pathfinding tasks.
79 |
80 | ```python
81 | from pogema import pogema_v0, GridConfig
82 |
83 | env = pogema_v0(GridConfig(integration="gymnasium"))
84 | ```
85 |
86 | Example of training [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) DQN to solve single-agent pathfinding tasks: [](https://colab.research.google.com/drive/1vPwTd0PnzpWrB-bCHqoLSVwU9G9Lgcmv?usp=sharing)
87 |
88 |
89 |
90 |
91 | ## Customization
92 |
93 | ### Random maps
94 | ```python
95 | from pogema import pogema_v0, GridConfig
96 |
97 | # Define random configuration
98 | grid_config = GridConfig(num_agents=4, # number of agents
99 | size=8, # size of the grid
100 | density=0.4, # obstacle density
101 | seed=1, # set to None for random
102 | # obstacles, agents and targets
103 | # positions at each reset
104 | max_episode_steps=128, # horizon
105 | obs_radius=3, # defines field of view
106 | )
107 |
108 | env = pogema_v0(grid_config=grid_config)
109 | env.reset()
110 | env.render()
111 |
112 | ```
113 |
114 | ### Custom maps
115 | ```python
116 | from pogema import pogema_v0, GridConfig
117 |
118 | grid = """
119 | .....#.....
120 | .....#.....
121 | ...........
122 | .....#.....
123 | .....#.....
124 | #.####.....
125 | .....###.##
126 | .....#.....
127 | .....#.....
128 | ...........
129 | .....#.....
130 | """
131 |
132 | # Define new configuration with 8 randomly placed agents
133 | grid_config = GridConfig(map=grid, num_agents=8)
134 |
135 | # Create custom Pogema environment
136 | env = pogema_v0(grid_config=grid_config)
137 | ```
138 |
139 |
140 |
141 |
142 | ## Citation
143 | If you use this repository in your research or wish to cite it, please make a reference to our paper:
144 | ```
145 | @inproceedings{skrynnik2025pogema,
146 | title={POGEMA: A Benchmark Platform for Cooperative Multi-Agent Pathfinding},
147 | author={Skrynnik, Alexey and Andreychuk, Anton and Borzilov, Anatolii and Chernyavskiy, Alexander and Yakovlev, Konstantin and Panov, Aleksandr},
148 | booktitle={The Thirteenth International Conference on Learning Representations},
149 | year={2025}
150 | }
151 | ```
152 |
--------------------------------------------------------------------------------
/pogema/generator.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 |
6 | from pogema import GridConfig
7 |
8 |
def generate_obstacles(grid_config: GridConfig, rnd=None):
    """Sample a height x width 0/1 obstacle matrix with the configured density."""
    generator = rnd if rnd is not None else np.random.default_rng(grid_config.seed)
    shape = (grid_config.height, grid_config.width)
    return generator.binomial(1, grid_config.density, shape)
13 |
14 |
def bfs(grid, moves, start_id, free_cell):
    """Label connected components of free cells in-place and return their sizes.

    Every cell equal to ``free_cell`` is overwritten with the id of its
    connected component (ids start at ``start_id``). The returned list maps a
    component id to its cell count; the first ``start_id`` entries are zero
    padding so the list can be indexed directly by component id.

    :param grid: 2D array, modified in place
    :param moves: iterable of (dx, dy) neighbour offsets
    :param start_id: first component id to assign
    :param free_cell: value that marks traversable cells
    :return: list of component sizes indexed by component id
    """
    current_id = start_id

    components = [0 for _ in range(start_id)]

    size_x = len(grid)
    size_y = len(grid[0])

    for x in range(size_x):
        for y in range(size_y):
            if grid[x, y] != free_cell:
                continue

            grid[x, y] = current_id
            components.append(1)

            # Classic BFS flood fill. An index cursor replaces the original
            # ``q.pop(0)``, which is O(n) per dequeue and made the fill
            # quadratic in the component size.
            q = [(x, y)]
            head = 0
            while head < len(q):
                cx, cy = q[head]
                head += 1

                for dx, dy in moves:
                    nx, ny = cx + dx, cy + dy
                    if 0 <= nx < size_x and 0 <= ny < size_y:
                        if grid[nx, ny] == free_cell:
                            grid[nx, ny] = current_id
                            components[current_id] += 1
                            q.append((nx, ny))

            current_id += 1
    return components
46 |
47 |
def placing_fast(order, components, grid, start_id, num_agents):
    """Pair start and finish cells for agents out of a shuffled cell ``order``.

    ``grid[cell]`` is the cell's connected-component id ("color"), so a start
    and its finish always land in the same component and a path between them
    exists.

    :param order: shuffled list of candidate cells
    :param components: component sizes indexed by id (only its length is used)
    :param grid: labelled grid produced by :func:`bfs`
    :param start_id: unused here; kept for signature parity with :func:`placing`
    :param num_agents: number of (start, finish) pairs requested
    :return: (positions_xy, finishes_xy), each at most ``num_agents`` long
    """
    # link_to_next[i] is the index of the next cell in ``order`` sharing the
    # color of order[i] (-1 when none); built by scanning ``order`` backwards.
    link_to_next = [-1 for _ in range(len(order))]
    colors = [-1 for _ in range(len(components))]
    size = len(order)
    for index in range(size):
        reversed_index = len(order) - index - 1
        color = grid[order[reversed_index]]
        link_to_next[reversed_index] = colors[color]
        colors[color] = reversed_index

    positions_xy = []
    finishes_xy = []

    for index in range(len(order)):
        next_index = link_to_next[index]
        if next_index == -1:
            continue

        positions_xy.append(order[index])
        finishes_xy.append(order[next_index])

        # Mark the consumed finish cell so it is not reused as a start later.
        link_to_next[next_index] = -1
        if len(finishes_xy) >= num_agents:
            break
    return positions_xy, finishes_xy
73 |
74 |
def placing(order, components, grid, start_id, num_agents):
    """Pick agent start cells and matching finish cells from shuffled ``order``.

    Walks the shuffled cells once; when a component yields a fresh cell it
    becomes an agent start plus a "request" for a finish in the same
    component, fulfilled by the next cell seen from that component.

    :param order: shuffled list of candidate (x, y) cells
    :param components: per-component free-cell counts; decremented as cells
        are reserved so a component is never over-subscribed
    :param grid: labelled grid from :func:`bfs`; consumed cells are zeroed
    :param start_id: labels below this value mark obstacles/used cells
    :param num_agents: number of agents to place
    :return: (positions_xy, finishes_xy); unfulfilled finishes stay (-1, -1)
    """
    # requests[id] holds agent indices still waiting for a finish cell in
    # component ``id``.
    requests = [[] for _ in range(len(components))]

    done_requests = 0
    positions_xy = []
    finishes_xy = [(-1, -1) for _ in range(num_agents)]
    for x, y in order:
        if grid[x, y] < start_id:
            continue

        id_ = grid[x, y]
        grid[x, y] = 0  # consume the cell

        # Serve a pending finish request from this component first.
        if requests[id_]:
            tt = requests[id_].pop()
            finishes_xy[tt] = x, y
            done_requests += 1
            continue

        # Enough starts already; keep scanning only to fulfil open requests.
        if len(positions_xy) >= num_agents:
            if done_requests >= num_agents:
                break
            continue

        # Reserve two cells (start + future finish) from this component.
        if components[id_] >= 2:
            components[id_] -= 2
            requests[id_].append(len(positions_xy))
            positions_xy.append((x, y))

    return positions_xy, finishes_xy
105 |
def generate_from_possible_positions(grid_config: 'GridConfig'):
    """Choose agent and target positions from user-provided candidate lists.

    :param grid_config: config providing ``possible_agents_xy``,
        ``possible_targets_xy``, ``num_agents`` and ``seed``
    :raises OverflowError: when either candidate list is shorter than num_agents
    :return: (agents_xy, targets_xy) — num_agents positions sampled without
        replacement from each list
    """
    if len(grid_config.possible_agents_xy) < grid_config.num_agents or len(grid_config.possible_targets_xy) < grid_config.num_agents:
        raise OverflowError(f"Can't create task. Not enough possible positions for {grid_config.num_agents} agents.")
    rng = np.random.default_rng(grid_config.seed)
    # Bug fix: shuffle copies so the user's config lists are not mutated in
    # place (repeated calls previously reshuffled the already-shuffled lists).
    agents_xy = list(grid_config.possible_agents_xy)
    targets_xy = list(grid_config.possible_targets_xy)
    rng.shuffle(agents_xy)
    rng.shuffle(targets_xy)
    return agents_xy[:grid_config.num_agents], targets_xy[:grid_config.num_agents]
113 |
114 |
def generate_positions_and_targets_fast(obstacles, grid_config):
    """Generate start/finish cells for all agents on the given obstacle map."""
    config = grid_config
    labelled = obstacles.copy()

    first_component_id = max(config.FREE, config.OBSTACLE) + 1

    component_sizes = bfs(labelled, tuple(config.MOVES), first_component_id, free_cell=config.FREE)
    height, width = obstacles.shape
    candidates = [(x, y) for x in range(height) for y in range(width) if labelled[x, y] >= first_component_id]
    np.random.default_rng(config.seed).shuffle(candidates)

    return placing(order=candidates, components=component_sizes, grid=labelled,
                   start_id=first_component_id, num_agents=config.num_agents)
127 |
def generate_from_possible_targets(rnd_generator, possible_positions, position):
    """Draw a random position from ``possible_positions`` that differs from ``position``."""
    candidate = tuple(rnd_generator.choice(possible_positions, 1)[0])
    while candidate == position:
        candidate = tuple(rnd_generator.choice(possible_positions, 1)[0])
    return candidate
133 |
def generate_new_target(rnd_generator, point_to_component, component_to_points, position):
    """Draw a new target from the connected component containing ``position``,
    never returning the position itself."""
    points = component_to_points[point_to_component[position]]
    candidate = tuple(*rnd_generator.choice(points, 1))
    while candidate == position:
        candidate = tuple(*rnd_generator.choice(points, 1))
    return candidate
141 |
142 |
def get_components(grid_config, obstacles, positions_xy, target_xy):
    """Label connected components and return point<->component mappings.

    ``positions_xy`` and ``target_xy`` are accepted for signature
    compatibility; the visible body does not use them.
    """
    config = grid_config
    labelled = obstacles.copy()

    first_component_id = max(config.FREE, config.OBSTACLE) + 1
    bfs(labelled, tuple(config.MOVES), first_component_id, free_cell=config.FREE)
    height, width = obstacles.shape

    comp_to_points = defaultdict(list)
    point_to_comp = {}
    for x in range(height):
        for y in range(width):
            component_id = labelled[x, y]
            comp_to_points[component_id].append((x, y))
            point_to_comp[(x, y)] = component_id
    return comp_to_points, point_to_comp
158 |
159 |
def time_it(func, num_iterations):
    """Benchmark ``func`` over freshly generated grids; return elapsed seconds."""
    started_at = time.monotonic()
    for index in range(num_iterations):
        grid_config = GridConfig(num_agents=64, size=64, seed=index)
        obstacles = generate_obstacles(grid_config)
        result = func(obstacles, grid_config, )
        # Print a single sample result so the output can be eyeballed.
        if index == 0 and num_iterations > 1:
            print(result)
    return time.monotonic() - started_at
171 |
172 |
def main():
    """Benchmark fast position/target generation (first call is a warm-up)."""
    num_iterations = 1000
    time_it(generate_positions_and_targets_fast, num_iterations=1)
    print('fast:', time_it(generate_positions_and_targets_fast, num_iterations=num_iterations))


if __name__ == '__main__':
    main()
181 |
--------------------------------------------------------------------------------
/pogema/wrappers/metrics.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | from gymnasium import Wrapper
5 |
6 |
class AbstractMetric(Wrapper):
    """Base wrapper for episode metrics.

    Subclasses implement ``_compute_stats`` and return a dict (or None) that
    is merged into ``infos[0]['metrics']`` on the step where it is returned.
    """

    def _compute_stats(self, step, is_on_goal, finished):
        """Return a metric dict or None; called once per env step."""
        raise NotImplementedError

    def __init__(self, env):
        super().__init__(env)
        self._current_step = 0  # steps since the start of the current episode

    def step(self, action):
        obs, reward, terminated, truncated, infos = self.env.step(action)
        # The episode ends only when the flag is set for every agent.
        finished = all(truncated) or all(terminated)

        # ``self.was_on_goal`` comes from the wrapped env (per-agent flags).
        metric = self._compute_stats(self._current_step, self.was_on_goal, finished)
        self._current_step += 1
        if finished:
            self._current_step = 0  # restart the counter for the next episode

        if metric:
            if 'metrics' not in infos[0]:
                infos[0]['metrics'] = {}
            # All metrics are reported via the first agent's info dict.
            infos[0]['metrics'].update(**metric)

        return obs, reward, terminated, truncated, infos
30 |
31 |
class LifeLongAverageThroughputMetric(AbstractMetric):
    """Average number of goals reached per step over a lifelong episode."""

    def __init__(self, env):
        super().__init__(env)
        self._solved_instances = 0  # goals reached since the episode started

    def _compute_stats(self, step, is_on_goal, finished):
        # Count every on-goal flag as one solved task. Replaces a manual
        # enumerate loop with an unused index; now consistent with
        # CSRMetric/ISRMetric, which also use sum(is_on_goal).
        self._solved_instances += sum(is_on_goal)
        if finished:
            result = {'avg_throughput': self._solved_instances / self.grid_config.max_episode_steps}
            self._solved_instances = 0
            return result
46 |
47 |
class NonDisappearCSRMetric(AbstractMetric):
    """CSR for the non-disappearing setting: 1.0 iff every agent ends on its goal."""

    def _compute_stats(self, step, is_on_goal, finished):
        if not finished:
            return None
        return {'CSR': float(all(is_on_goal))}
53 |
54 |
class NonDisappearISRMetric(AbstractMetric):
    """ISR for the non-disappearing setting: fraction of agents on goal at episode end."""

    def _compute_stats(self, step, is_on_goal, finished):
        if not finished:
            return None
        return {'ISR': float(sum(is_on_goal)) / self.get_num_agents()}
60 |
61 |
class NonDisappearEpLengthMetric(AbstractMetric):
    """Episode length (in steps) reported once the episode finishes."""

    def _compute_stats(self, step, is_on_goal, finished):
        if not finished:
            return None
        return {'ep_length': step + 1}
67 |
68 |
class EpLengthMetric(AbstractMetric):
    """Mean per-agent number of steps until each agent first reaches its goal."""

    def __init__(self, env):
        super().__init__(env)
        # First step index at which each agent was on its goal (None = not yet).
        self._solve_time = [None for _ in range(self.get_num_agents())]

    def _compute_stats(self, step, is_on_goal, finished):
        for idx, on_goal in enumerate(is_on_goal):
            # Record only the first arrival; agents that never finish get the
            # final step. (Nested ifs flattened for consistency with
            # SumOfCostsAndMakespanMetric.)
            if self._solve_time[idx] is None and (on_goal or finished):
                self._solve_time[idx] = step

        if finished:
            result = {'ep_length': sum(self._solve_time) / self.get_num_agents() + 1}
            self._solve_time = [None for _ in range(self.get_num_agents())]
            return result
84 |
85 |
class CSRMetric(AbstractMetric):
    """Cooperative success rate: 1.0 iff every agent reached its goal.

    NOTE(review): ``_solved_instances`` accumulates on-goal flags every step,
    so this assumes agents disappear (stop reporting on-goal) once done and
    each agent contributes at most once — confirm against the wrapped env.
    """

    def __init__(self, env):
        super().__init__(env)
        self._solved_instances = 0

    def _compute_stats(self, step, is_on_goal, finished):
        self._solved_instances += sum(is_on_goal)
        if finished:
            results = {'CSR': float(self._solved_instances == self.get_num_agents())}
            self._solved_instances = 0
            return results
97 |
98 |
class ISRMetric(AbstractMetric):
    """Individual success rate: fraction of agents that reached their goals.

    NOTE(review): like CSRMetric, this accumulates on-goal flags every step
    and therefore assumes the disappear-on-goal setting — confirm.
    """

    def __init__(self, env):
        super().__init__(env)
        self._solved_instances = 0

    def _compute_stats(self, step, is_on_goal, finished):
        self._solved_instances += sum(is_on_goal)
        if finished:
            results = {'ISR': self._solved_instances / self.get_num_agents()}
            self._solved_instances = 0
            return results
110 |
111 |
class SumOfCostsAndMakespanMetric(AbstractMetric):
    """Sum-of-costs and makespan for the non-disappearing setting."""

    def __init__(self, env):
        super().__init__(env)
        # Step at which each agent last settled on its goal (None = off goal).
        self._solve_time = [None for _ in range(self.get_num_agents())]

    def _compute_stats(self, step, is_on_goal, finished):
        for idx, on_goal in enumerate(is_on_goal):
            if self._solve_time[idx] is None and (on_goal or finished):
                self._solve_time[idx] = step
            # Leaving the goal invalidates the previously recorded arrival.
            if not on_goal and not finished:
                self._solve_time[idx] = None

        if finished:
            # +num_agents / +1 turn 0-based step indices into step counts.
            result = {'SoC': sum(self._solve_time) + self.get_num_agents(), 'makespan': max(self._solve_time) + 1}
            self._solve_time = [None for _ in range(self.get_num_agents())]
            return result
128 |
129 |
class AgentsDensityWrapper(Wrapper):
    """Tracks the mean density of agents inside each agent's field of view."""

    def __init__(self, env):
        super().__init__(env)
        self._avg_agents_density = None  # per-step mean densities for the episode

    def count_agents(self, observations):
        """Append this step's mean (over agents) local agent density."""
        densities = []
        for observation in observations:
            obstacle_layer = observation['obstacles']
            traversable = np.size(obstacle_layer) - np.count_nonzero(obstacle_layer)
            densities.append(np.count_nonzero(observation['agents']) / traversable)
        self._avg_agents_density.append(np.mean(densities))

    def step(self, actions):
        observations, rewards, terminated, truncated, infos = self.env.step(actions)
        self.count_agents(observations)
        if all(terminated) or all(truncated):
            # Report the episode-average density on the first agent's info.
            metrics = infos[0].setdefault('metrics', {})
            metrics.update(avg_agents_density=float(np.mean(self._avg_agents_density)))
        return observations, rewards, terminated, truncated, infos

    def reset(self, **kwargs):
        self._avg_agents_density = []
        observations, info = self.env.reset(**kwargs)
        self.count_agents(observations)
        return observations, info
156 |
157 |
class RuntimeMetricWrapper(Wrapper):
    """Measures the acting policy's wall time, excluding env.step itself.

    Reports ``runtime`` = time since reset minus the cumulative time spent
    inside the wrapped env's step calls.
    """

    def __init__(self, env):
        super().__init__(env)
        self._start_time = None  # monotonic timestamp taken at reset
        self._env_step_time = None  # cumulative seconds spent in env.step

    def step(self, actions):
        env_step_start = time.monotonic()
        observations, rewards, terminated, truncated, infos = self.env.step(actions)
        env_step_end = time.monotonic()
        self._env_step_time += env_step_end - env_step_start
        if all(terminated) or all(truncated):
            # Everything outside env.step is attributed to the acting policy.
            final_time = time.monotonic() - self._start_time - self._env_step_time
            if 'metrics' not in infos[0]:
                infos[0]['metrics'] = {}
            infos[0]['metrics'].update(runtime=final_time)
        return observations, rewards, terminated, truncated, infos

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        self._start_time = time.monotonic()
        self._env_step_time = 0.0
        return obs
181 |
--------------------------------------------------------------------------------
/pogema/svg_animation/animation_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import cycle
3 | from gymnasium import logger, Wrapper
4 |
5 | from pogema import GridConfig
6 | from pogema.svg_animation.animation_drawer import AnimationConfig, SvgSettings, GridHolder, AnimationDrawer
7 | from pogema.wrappers.persistence import PersistentWrapper, AgentState
8 |
9 |
class AnimationMonitor(Wrapper):
    """
    Defines the animation, which saves the episode as SVG.
    """

    def __init__(self, env, animation_config=AnimationConfig()):
        # NOTE(review): the default AnimationConfig() instance is shared across
        # calls; safe only while configs are never mutated in place — confirm.
        self._working_radius = env.grid_config.obs_radius - 1
        # PersistentWrapper records per-step agent states; the negative offset
        # maps padded grid coordinates back onto the unpadded map.
        env = PersistentWrapper(env, xy_offset=-self._working_radius)

        super().__init__(env)

        self.history = env.get_history()

        self.svg_settings: SvgSettings = SvgSettings()
        self.animation_config: AnimationConfig = animation_config

        self._episode_idx = 0

    def step(self, action):
        """
        Performs one step and, at episode end, optionally saves the animation.
        :param action: current actions
        :return: obs, reward, terminated, truncated, info
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Supports both the multi-agent API (list/tuple of flags) and the
        # single-agent API (scalar bool/int).
        multi_agent_terminated = isinstance(terminated, (list, tuple)) and all(terminated)
        single_agent_terminated = isinstance(terminated, (bool, int)) and terminated
        multi_agent_truncated = isinstance(truncated, (list, tuple)) and all(truncated)
        single_agent_truncated = isinstance(truncated, (bool, int)) and truncated

        if multi_agent_terminated or single_agent_terminated or multi_agent_truncated or single_agent_truncated:
            save_tau = self.animation_config.save_every_idx_episode
            if save_tau:
                # NOTE(review): `(self._episode_idx + 1) % save_tau` is truthy
                # when the index is NOT a multiple of save_tau, so for
                # save_tau > 1 this saves every episode except each
                # save_tau-th. If periodic saving was intended, the condition
                # should likely be `... % save_tau == 0` — confirm before changing.
                if (self._episode_idx + 1) % save_tau or save_tau == 1:
                    if not os.path.exists(self.animation_config.directory):
                        logger.info(f"Creating pogema monitor directory {self.animation_config.directory}", )
                        os.makedirs(self.animation_config.directory, exist_ok=True)

                    path = os.path.join(self.animation_config.directory,
                                        self.pick_name(self.grid_config, self._episode_idx))
                    self.save_animation(path)

        return obs, reward, terminated, truncated, info

    @staticmethod
    def pick_name(grid_config: GridConfig, episode_idx=None, zfill_ep=5):
        """
        Picks a name for the SVG file.
        :param grid_config: configuration of the grid
        :param episode_idx: idx of the episode
        :param zfill_ep: zfill for the episode number
        :return: file name ending in '.svg'
        """
        gc = grid_config
        name = 'pogema'
        if episode_idx is not None:
            name += f'-ep{str(episode_idx).zfill(zfill_ep)}'
        # With a config: encode map name and seed; otherwise mark as '-render'.
        if gc:
            if gc.map_name:
                name += f'-{gc.map_name}'
            if gc.seed is not None:
                name += f'-seed{gc.seed}'
        else:
            name += '-render'
        return name + '.svg'

    def reset(self, **kwargs):
        """
        Resets the environment and resets the current positions of agents and targets
        :param kwargs:
        :return: obs: whatever the wrapped env's reset returns (typically (obs, info))
        """
        obs = self.env.reset(**kwargs)

        self._episode_idx += 1
        self.history = self.env.get_history()

        return obs

    def save_animation(self, name='render.svg', animation_config: AnimationConfig = AnimationConfig()):
        """
        Saves the animation.
        :param name: name of the file
        :param animation_config: animation configuration
        :return: None
        """
        wr = self._working_radius
        # Trim the artificial padding border (wr may be 0 for obs_radius == 1).
        if wr > 0:
            obstacles = self.env.get_obstacles(ignore_borders=False)[wr:-wr, wr:-wr]
        else:
            obstacles = self.env.get_obstacles(ignore_borders=False)
        history: list[list[AgentState]] = self.env.decompress_history(self.history)

        svg_settings = SvgSettings()
        # Assign each agent a colour from the repeating palette.
        colors_cycle = cycle(svg_settings.colors)
        agents_colors = {index: next(colors_cycle) for index in range(self.grid_config.num_agents)}

        # Duplicate each agent's final state so the last step gets its own frame.
        for agent_idx in range(self.grid_config.num_agents):
            history[agent_idx].append(history[agent_idx][-1])

        episode_length = len(history[0])
        # Change episode length for egocentric environment
        if animation_config.egocentric_idx is not None and self.grid_config.on_target == 'finish':
            # Cut every history at the step where the egocentric agent finished.
            episode_length = history[animation_config.egocentric_idx][-1].step + 1
            for agent_idx in range(self.grid_config.num_agents):
                history[agent_idx] = history[agent_idx][:episode_length]

        grid_holder = GridHolder(
            # NOTE(review): width comes from the row count and height from the
            # column count — confirm this matches GridHolder's convention.
            width=len(obstacles), height=len(obstacles[0]),
            obstacles=obstacles,
            episode_length=episode_length,
            history=history,
            obs_radius=self.grid_config.obs_radius,
            on_target=self.grid_config.on_target,
            colors=agents_colors,
            config=animation_config,
            svg_settings=svg_settings
        )

        animation = AnimationDrawer().create_animation(grid_holder)
        with open(name, "w") as f:
            f.write(animation.render())
133 |
134 |
def main():
    """Demo entry point: renders example SVG animations into ./renders."""
    from pogema import GridConfig, pogema_v0, AnimationMonitor, BatchAStarAgent, AnimationConfig

    for egocentric_idx in [0, 1]:
        for on_target in ['nothing', 'restart', 'finish']:
            grid = """
            ....#..
            ..#....
            .......
            .......
            #.#.#..
            #.#.#..
            """
            grid_config = GridConfig(size=32, num_agents=2, obs_radius=2, seed=8, on_target=on_target,
                                     max_episode_steps=16,
                                     density=0.1, map=grid, observation_type="POMAPF")
            env = AnimationMonitor(pogema_v0(grid_config=grid_config),
                                   AnimationConfig(save_every_idx_episode=None))

            obs, _ = env.reset()
            terminated = [False]
            truncated = [False]

            # Roll out one full episode with the A* baseline policy.
            agent = BatchAStarAgent()
            while not (all(terminated) or all(truncated)):
                obs, _, terminated, truncated, _ = env.step(agent.act(obs))

            anim_folder = 'renders'
            os.makedirs(anim_folder, exist_ok=True)

            env.save_animation(f'{anim_folder}/anim-{on_target}.svg')
            env.save_animation(f'{anim_folder}/anim-{on_target}-ego-{egocentric_idx}.svg',
                               AnimationConfig(egocentric_idx=egocentric_idx))
            env.save_animation(f'{anim_folder}/anim-static.svg', AnimationConfig(static=True))
            env.save_animation(f'{anim_folder}/anim-static-ego.svg', AnimationConfig(egocentric_idx=0, static=True))
            env.save_animation(f'{anim_folder}/anim-static-no-agents.svg',
                               AnimationConfig(show_agents=False, static=True))
172 |
173 |
# Allows regenerating the demo animations by running this module directly.
if __name__ == '__main__':
    main()
176 |
--------------------------------------------------------------------------------
/tests/test_deterministic_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from heapq import heappop, heappush
3 | from pogema import GridConfig, pogema_v0, AnimationMonitor
4 |
5 | # from pogema.animation import AnimationMonitor
6 |
7 | INF = 1000000007
8 |
9 |
class Node:
    """A* search node: grid cell (i, j) with g-cost, heuristic h and priority f = g + h."""

    def __init__(self, coord=(INF, INF), g: int = 0, h: int = 0):
        self.i, self.j = coord
        self.g = g
        self.h = h
        self.f = g + h

    def __lt__(self, other):
        """Order by f-value; break ties by preferring the smaller g."""
        if self.f != other.f:
            return self.f < other.f
        return self.g < other.g
19 |
20 |
class AStar:
    """Single-agent A* planner over the agent's egocentric obstacle grid.

    Coordinates are (row, col) tuples in the observation frame.
    """

    def __init__(self):
        self.start = (0, 0)
        self.goal = (0, 0)
        self.max_steps = 500  # hard cap on node expansions per search
        self.OPEN = list()    # heap of Node, ordered by f then g (see Node.__lt__)
        self.CLOSED = dict()  # child position -> parent position; also marks "seen"
        self.obstacles = set()
        self.other_agents = set()
        self.best_node = Node(self.start, 0, self.h(self.start))

    def h(self, node):
        # Manhattan distance to the current goal.
        return abs(node[0] - self.goal[0]) + abs(node[1] - self.goal[1])

    def get_neighbours(self, u):
        # 4-connected neighbours that are not known obstacles.
        neighbors = []
        for d in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            if (u[0] + d[0], u[1] + d[1]) not in self.obstacles:
                neighbors.append((u[0] + d[0], u[1] + d[1]))
        return neighbors

    def compute_shortest_path(self):
        """Expand nodes until the goal is popped, OPEN empties, or max_steps is hit."""
        u = Node()
        steps = 0
        while len(self.OPEN) > 0 and steps < self.max_steps and (u.i, u.j) != self.goal:
            u = heappop(self.OPEN)
            if self.best_node.h > u.h:
                self.best_node = u  # remember the node closest to the goal so far
            steps += 1
            for n in self.get_neighbours((u.i, u.j)):
                # Cells are marked CLOSED at push time, so each cell is pushed once.
                if n not in self.CLOSED and n not in self.other_agents:
                    heappush(self.OPEN, Node(n, u.g + 1, self.h(n)))
                    self.CLOSED[n] = (u.i, u.j)

    def get_next_node(self):
        # Walk parent pointers back from the goal to find the first move;
        # if the goal was never reached, return start (i.e. wait in place).
        next_node = self.start
        if self.goal in self.CLOSED:
            next_node = self.goal
            while self.CLOSED[next_node] != self.start:
                next_node = self.CLOSED[next_node]
        return next_node

    def update_obstacles(self, obs, other_agents):
        """Rebuild the obstacle and other-agent sets from binary observation grids."""
        obstacles = np.nonzero(obs)
        self.obstacles.clear()
        for k in range(len(obstacles[0])):
            self.obstacles.add((obstacles[0][k], obstacles[1][k]))
        self.other_agents.clear()
        agents = np.nonzero(other_agents)
        for k in range(len(agents[0])):
            self.other_agents.add((agents[0][k], agents[1][k]))

    def reset(self):
        # Clear search state and seed OPEN with the start node.
        self.CLOSED = dict()
        self.OPEN = list()
        heappush(self.OPEN, Node(self.start, 0, self.h(self.start)))
        self.best_node = Node(self.start, 0, self.h(self.start))

    def update_path(self, start, goal):
        """Run a fresh search from `start` to `goal`."""
        self.start = start
        self.goal = goal
        self.reset()
        self.compute_shortest_path()
84 |
85 |
class DeterministicPolicy:
    """A*-based greedy policy; a seeded RNG occasionally injects random moves,
    so rollouts are reproducible for a fixed seed."""

    def __init__(self, random_seed=42, random_rate=0.2):
        self.agents = None  # one AStar planner per agent, created lazily in act()
        self.actions = {tuple(GridConfig().MOVES[i]): i for i in
                        range(len(GridConfig().MOVES))}  # make a dictionary to translate coordinates of action into id
        self.obs_radius = GridConfig().obs_radius  # overwritten from the first observation in act()
        self._rnd = np.random.RandomState(random_seed)
        self._random_rate = random_rate

    def get_goal(self, obs):
        """Return the goal cell read from the target channel obs[2].

        If the goal cell coincides with an obstacle in obs[0] (presumably a
        goal projected onto the observation border — TODO confirm), shift it
        by one cell. NOTE(review): at goal_i == 0 the `-= 1` yields -1, which
        wraps via negative indexing — confirm the intended direction.
        """
        goal = np.nonzero(obs[2])
        goal_i = goal[0][0]
        goal_j = goal[1][0]
        if obs[0][goal_i][goal_j]:
            goal_i -= 1 if goal_i == 0 else 0
            goal_i += 1 if goal_i == self.obs_radius * 2 else 0
            goal_j -= 1 if goal_j == 0 else 0
            goal_j += 1 if goal_j == self.obs_radius * 2 else 0
        return goal_i, goal_j

    def act(self, obs) -> list:
        """Plan one step for every agent and return a list of action ids."""
        if self.agents is None:
            # Infer the observation radius from the first observation's width.
            self.obs_radius = len(obs[0][0]) // 2
            self.agents = [AStar() for _ in range(len(obs))]  # create a planner for each of the agents
        actions = []
        for k in range(len(obs)):
            start = (self.obs_radius, self.obs_radius)  # the agent sits at the observation centre
            goal = self.get_goal(obs[k])
            if start == goal:  # don't waste time on the agents that have already reached their goals
                actions.append(0)  # just add useless action to save the order and length of the actions
                continue
            self.agents[k].update_obstacles(obs[k][0], obs[k][1])
            self.agents[k].update_path(start, goal)
            next_node = self.agents[k].get_next_node()
            actions.append(self.actions[(next_node[0] - start[0], next_node[1] - start[1])])
        # With probability random_rate replace the action with a random move.
        # NOTE(review): randint(1, 4) draws from {1, 2, 3} (upper bound is
        # exclusive), so action 4 is never sampled — the reference test values
        # below depend on this exact RNG sequence, so do not "fix" it casually.
        for idx, action in enumerate(actions):
            if self._rnd.random() < self._random_rate:
                actions[idx] = self._rnd.randint(1, 4)
        return actions
125 |
126 |
def run_policy(gc: GridConfig, save_animation=False):
    """Generator: runs episodes forever with DeterministicPolicy,
    yielding the metrics dict of agent 0 after each episode."""
    policy = DeterministicPolicy()
    env = pogema_v0(grid_config=gc)
    if save_animation:
        env = AnimationMonitor(env)

    while True:
        obs, info = env.reset()
        terminated = truncated = [False]
        while not all(terminated) and not all(truncated):
            obs, reward, terminated, truncated, info = env.step(policy.act(obs))
        yield info[0]['metrics']
141 |
142 |
def test_life_long():
    """Life-long (on_target='restart') runs reproduce known throughput values."""
    gc = GridConfig(num_agents=20, size=8, obs_radius=4, seed=42, max_episode_steps=64, on_target='restart')
    results_generator = run_policy(gc, save_animation=False)

    # Idiomatic next() instead of calling __next__() directly.
    metrics = next(results_generator)
    assert np.isclose(metrics['avg_throughput'], 1.671875)
    metrics = next(results_generator)
    assert np.isclose(metrics['avg_throughput'], 1.609375)

    gc = GridConfig(num_agents=24, size=8, obs_radius=4, seed=43, max_episode_steps=64, on_target='restart')
    results_generator = run_policy(gc, save_animation=False)

    metrics = next(results_generator)
    assert np.isclose(metrics['avg_throughput'], 0.4375)
157 |
158 |
def test_disappearing():
    """on_target='finish': agents vanish at goal; check ep_length/ISR/CSR."""
    gc = GridConfig(num_agents=20, size=8, obs_radius=2, seed=42, density=0.2, max_episode_steps=32, on_target='finish')
    results_generator = run_policy(gc, save_animation=False)

    # Idiomatic next() instead of calling __next__() directly.
    metrics = next(results_generator)
    assert np.isclose(metrics['ep_length'], 22.95)
    assert np.isclose(metrics['ISR'], 0.5)
    assert np.isclose(metrics['CSR'], 0.0)

    metrics = next(results_generator)
    assert np.isclose(metrics['ep_length'], 15.55)
    assert np.isclose(metrics['ISR'], 0.9)
    assert np.isclose(metrics['CSR'], 0.0)
172 |
173 |
def test_non_disappearing():
    """on_target='nothing': agents stay at goals; check ep_length/ISR/CSR."""
    gc = GridConfig(num_agents=4, size=5, obs_radius=2, seed=3, density=0.2, max_episode_steps=32, on_target='nothing')
    results_generator = run_policy(gc, save_animation=False)

    # Idiomatic next() instead of calling __next__() directly.
    metrics = next(results_generator)
    assert np.isclose(metrics['ep_length'], 21)
    assert np.isclose(metrics['CSR'], 1.0)
    assert np.isclose(metrics['ISR'], 1.0)

    metrics = next(results_generator)
    assert np.isclose(metrics['ep_length'], 14)
    assert np.isclose(metrics['CSR'], 1.0)
    assert np.isclose(metrics['ISR'], 1.0)

    gc = GridConfig(num_agents=7, size=5, obs_radius=2, seed=3, density=0.2, max_episode_steps=32, on_target='nothing')
    results_generator = run_policy(gc, save_animation=False)

    metrics = next(results_generator)
    assert np.isclose(metrics['ep_length'], 32)
    assert np.isclose(metrics['CSR'], 0.0)
    assert np.isclose(metrics['ISR'], 0.71428571428)
195 |
--------------------------------------------------------------------------------
/tests/test_pogema_env.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 |
4 | import numpy as np
5 | import pytest
6 | from tabulate import tabulate
7 |
8 | from pogema import pogema_v0, AnimationMonitor
9 |
10 | from pogema.envs import ActionsSampler
11 | from pogema.grid import GridConfig
12 |
13 |
class ActionMapping:
    """Readable names for the discrete action ids used in the tests."""
    # NOTE(review): the numeric values are assumed to match the ordering of
    # GridConfig().MOVES — confirm against pogema.grid_config.
    noop: int = 0
    up: int = 1
    down: int = 2
    left: int = 3
    right: int = 4
20 |
21 |
def test_moving():
    """Drives agent 0 along a fixed route while agent 1 idles, then checks the
    final reward and termination flags."""
    env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42))
    ac = ActionMapping()
    env.reset()

    route = [ac.right, ac.up, ac.left, ac.down, ac.down, ac.left, ac.left,
             ac.up, ac.up, ac.up, ac.right, ac.up, ac.right, ac.down]
    for move in route:
        env.step([move, ac.noop])
    obs, reward, terminated, truncated, infos = env.step([ac.right, ac.noop])

    assert np.isclose([1.0, 0.0], reward).all()
    assert np.isclose([True, False], terminated).all()
46 |
47 |
def test_types():
    """Default observations are float32 arrays."""
    environment = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42))
    observations, _ = environment.reset()
    assert observations[0].dtype == np.float32
52 |
53 |
def run_episode(grid_config=None, env=None):
    """Run one episode with randomly sampled actions.

    Returns the list of per-step results; each later entry is the 5-tuple
    returned by env.step().

    NOTE(review): when `env` is None it is reset twice (once here and once in
    the tuple assignment below), and the first results entry stores the full
    (obs, info) tuple from reset() rather than just obs. Callers in this file
    only read results[-1], so this is currently harmless — confirm before
    relying on results[0].
    """
    if env is None:
        env = pogema_v0(grid_config)
        env.reset()

    obs, rewards, terminated, truncated, infos = env.reset(), [None], [False], [False], [None]

    results = [[obs, rewards, terminated, truncated, infos]]
    while True:
        results.append(env.step(env.sample_actions()))
        terminated, truncated = results[-1][2], results[-1][3]
        if all(terminated) or all(truncated):
            break
    return results
68 |
69 |
def test_metrics():
    """CSR/ISR values for a few seeded configurations."""
    cases = [
        (GridConfig(num_agents=2, seed=5, size=5, max_episode_steps=64), 0.0, 0.5),
        (GridConfig(num_agents=2, seed=5, size=5, max_episode_steps=512), 1.0, 1.0),
        (GridConfig(num_agents=5, seed=5, size=5, max_episode_steps=64), 0.0, 0.2),
    ]
    for gc, expected_csr, expected_isr in cases:
        *_, infos = run_episode(gc)[-1]
        assert np.isclose(infos[0]['metrics']['CSR'], expected_csr)
        assert np.isclose(infos[0]['metrics']['ISR'], expected_isr)
82 |
83 |
def test_standard_pogema():
    """Smoke test for the default (on_target='finish') environment."""
    config = GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish')
    env = pogema_v0(config)
    env.reset()
    run_episode(env=env)
88 |
89 |
def test_pomapf_observation():
    """POMAPF observations are dicts with local grids and coordinates."""
    env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish',
                               observation_type='POMAPF'))
    obs, info = env.reset()
    for key in ('agents', 'obstacles', 'xy', 'target_xy'):
        assert key in obs[0]
    run_episode(env=env)
99 |
100 |
def test_mapf_observation():
    """MAPF observations expose global grid and coordinates."""
    env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish',
                               observation_type='MAPF'))
    obs, info = env.reset()
    for key in ('global_obstacles', 'global_xy', 'global_target_xy'):
        assert key in obs[0]
    run_episode(env=env)
109 |
110 |
def test_standard_pogema_animation():
    """AnimationMonitor smoke test on a 'finish' environment."""
    base_env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='finish'))
    monitored = AnimationMonitor(base_env)
    monitored.reset()
    run_episode(env=monitored)
116 |
117 |
def test_gym_pogema_animation():
    """The gymnasium-registered env works under AnimationMonitor."""
    import gymnasium
    env = gymnasium.make('Pogema-v0',
                         grid_config=GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42,
                                                on_target='finish'))
    env = AnimationMonitor(env)
    env.reset()
    # Single-agent gymnasium API: terminated/truncated are plain bools here.
    # (Removed the unused `done = False` local.)
    while True:
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            break
130 |
131 |
def test_non_disappearing_pogema():
    """Smoke test: agents remain on goals with on_target='nothing'."""
    config = GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='nothing')
    env = pogema_v0(config)
    env.reset()
    run_episode(env=env)
136 |
137 |
def test_non_disappearing_pogema_no_seed():
    """Smoke test: on_target='nothing' with an unseeded (random) grid."""
    config = GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=None, on_target='nothing')
    env = pogema_v0(config)
    env.reset()
    run_episode(env=env)
142 |
143 |
def test_non_disappearing_pogema_animation():
    """AnimationMonitor smoke test on a 'nothing' environment."""
    base_env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='nothing'))
    monitored = AnimationMonitor(base_env)
    monitored.reset()
    run_episode(env=monitored)
149 |
150 |
def test_life_long_pogema():
    """Smoke test: life-long mode (on_target='restart')."""
    config = GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='restart')
    env = pogema_v0(config)
    env.reset()
    run_episode(env=env)
155 |
156 |
def test_life_long_pogema_empty_seed():
    """Smoke test: life-long mode with an unseeded grid."""
    config = GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=None, on_target='restart')
    env = pogema_v0(config)
    env.reset()
    run_episode(env=env)
161 |
162 |
def test_life_long_pogema_animation():
    """AnimationMonitor smoke test in life-long mode."""
    base_env = pogema_v0(GridConfig(num_agents=2, size=6, obs_radius=2, density=0.3, seed=42, on_target='restart'))
    monitored = AnimationMonitor(base_env)
    monitored.reset()
    run_episode(env=monitored)
168 |
169 |
def test_custom_positions_and_num_agents():
    """num_agents may take any value up to the number of provided positions."""
    grid = """
    ....
    ....
    """
    gc = GridConfig(
        map=grid,
        agents_xy=[[0, 0], [0, 1], [0, 2], [0, 3]],
        targets_xy=[[1, 0], [1, 1], [1, 2], [1, 3]],
    )

    for num_agents in range(1, 5):
        gc.num_agents = num_agents
        env = pogema_v0(grid_config=gc)
        env.reset()
        assert len(env.get_agents_xy()) == num_agents
        assert len(env.get_targets_xy()) == num_agents
187 |
188 |
def test_custom_positions_and_empty_num_agents():
    """Unset num_agents defaults to the number of provided positions."""
    grid = """
    ....
    ....
    """
    gc = GridConfig(
        map=grid,
        agents_xy=[[0, 0], [0, 1], [0, 2], [0, 3]],
        targets_xy=[[1, 0], [1, 1], [1, 2], [1, 3]],
    )
    env = pogema_v0(grid_config=gc)
    env.reset()
    assert len(env.get_agents_xy()) == len(gc.agents_xy)
202 |
203 |
def test_persistent_env(num_steps=100):
    """Persistent env: stepping back to the initial state and replaying the
    same seeded action sequence must reproduce the first run exactly."""
    seed = 42

    env = pogema_v0(
        grid_config=GridConfig(on_target='finish', seed=seed, num_agents=8, density=0.132, size=8, obs_radius=2,
                               persistent=True))

    env.reset()
    action_sampler = ActionsSampler(env.action_space.n, seed=seed)

    first_run_observations = []

    def state_repr(observations, rewards, terminates, truncates, infos):
        # Flatten the whole transition into a single comparable vector.
        return np.concatenate([np.array(observations).flatten(), terminates, truncates, np.array(rewards), ])

    for current_step in range(num_steps):
        actions = action_sampler.sample_actions(dim=env.get_num_agents())
        obs, reward, terminated, truncated, info = env.step(actions)

        first_run_observations.append(state_repr(obs, reward, terminated, truncated, info))
        if all(terminated) or all(truncated):
            break

    # resetting the environment to the initial state using backward steps
    for current_step in range(num_steps):
        if not env.step_back():
            break

    # A fresh sampler with the same seed reproduces the same action sequence.
    action_sampler = ActionsSampler(env.action_space.n, seed=seed)

    second_run_observations = []
    for current_step in range(num_steps):
        actions = action_sampler.sample_actions(dim=env.get_num_agents())
        obs, reward, terminated, truncated, info = env.step(actions)
        second_run_observations.append(state_repr(obs, reward, terminated, truncated, info))
        # Step-by-step comparison localizes a mismatch to the failing step.
        assert np.isclose(first_run_observations[current_step], second_run_observations[current_step]).all()
        if all(terminated) or all(truncated):
            break
    assert np.isclose(first_run_observations, second_run_observations).all()
243 |
244 |
def test_steps_per_second_throughput():
    """Prints an (agent-wise) steps-per-second table; acts as a perf smoke test."""
    rows = []
    for on_target in ['finish', 'nothing', 'restart']:
        for num_agents in [1, 32, 64]:
            for size in [32, 64]:
                gc = GridConfig(obs_radius=5, seed=42, max_episode_steps=1024,
                                size=size, num_agents=num_agents, on_target=on_target)

                started = time.monotonic()
                run_episode(grid_config=gc)
                elapsed = time.monotonic() - started
                sps = gc.max_episode_steps / elapsed
                rows.append([on_target, num_agents, size, sps * gc.num_agents])
    print('\n' + tabulate(rows, headers=['on_target', 'num_agents', 'size', 'SPS (individual)'], tablefmt='grid'))
259 |
--------------------------------------------------------------------------------
/pogema/grid_config.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Optional, Union
3 | from pydantic import validator, root_validator
4 |
5 | from pogema.utils import CommonSettings
6 |
7 | from typing_extensions import Literal
8 |
9 |
class GridConfig(CommonSettings, ):
    """Validated configuration of a POGEMA grid environment (pydantic v1 model).

    The validators below keep `size`/`width`/`height` consistent, bounds-check
    explicit positions, and can derive positions, dimensions and density from
    a string map.
    """
    # Behaviour on goal achievement: 'finish' removes the agent, 'nothing'
    # keeps it in place, 'restart' assigns the next goal (life-long setting).
    on_target: Literal['finish', 'nothing', 'restart'] = 'finish'
    seed: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    size: int = 8  # kept equal to max(width, height) by the validators
    density: float = 0.3  # obstacle density; recomputed when `map` is provided
    obs_radius: int = 5
    agents_xy: Optional[list] = None
    targets_xy: Optional[list] = None
    num_agents: Optional[int] = None
    possible_agents_xy: Optional[list] = None
    possible_targets_xy: Optional[list] = None
    collision_system: Literal['block_both', 'priority', 'soft'] = 'priority'
    persistent: bool = False
    observation_type: Literal['POMAPF', 'MAPF', 'default'] = 'default'
    map: Optional[Union[list, str]] = None

    map_name: Optional[str] = None

    # NOTE(review): non-Optional Literal with a None default — pydantic v1
    # tolerates this, but Optional[Literal[...]] would state the contract.
    integration: Literal['SampleFactory', 'PyMARL', 'rllib', 'gymnasium', 'PettingZoo'] = None
    max_episode_steps: int = 64
    auto_reset: Optional[bool] = None

    @root_validator
    def validate_dimensions_and_positions(cls, values):
        """Require width/height as a pair, keep `size` in sync with them,
        and bounds-check any explicit agent/target positions."""
        width_provided = values.get('width') is not None
        height_provided = values.get('height') is not None

        if width_provided and not height_provided:
            raise ValueError("Invalid dimension configuration. Please provide height.")
        elif not width_provided and height_provided:
            raise ValueError("Invalid dimension configuration. Please provide width.")

        if not width_provided and not height_provided:
            values['width'] = values.get('size', 8)
            values['height'] = values.get('size', 8)
        # `size` mirrors the larger dimension so square-based code keeps working.
        if 'size' not in values or values.get('size') != max(values.get('width'), values.get('height')):
            values['size'] = max(values.get('width'), values.get('height'))


        width = values.get('width')
        height = values.get('height')

        if width is not None and height is not None:
            agents_xy = values.get('agents_xy')
            if agents_xy is not None:
                cls.check_positions(agents_xy, width, height)

            targets_xy = values.get('targets_xy')
            if targets_xy is not None:
                first_element = targets_xy[0]
                # Either one goal per agent, or a sequence of goals per agent.
                if isinstance(first_element[0], (list, tuple)):
                    for agent_goals in targets_xy:
                        cls.check_positions(agent_goals, width, height)
                else:
                    cls.check_positions(targets_xy, width, height)

        return values

    @validator('seed')
    def seed_initialization(cls, v):
        """Seed must be absent or a non-negative int below sys.maxsize."""
        assert v is None or (0 <= v < sys.maxsize), "seed must be in [0, " + str(sys.maxsize) + ']'
        return v

    @staticmethod
    def _validate_dimension(v, field_name):
        """Shared range check for size/width/height (size needs at least 2)."""
        if v is not None:
            if field_name == 'size':
                assert 2 <= v <= 4096, f"{field_name} must be in [2, 4096]"
            else:
                assert 1 <= v <= 4096, f"{field_name} must be in [1, 4096]"
        return v

    @validator('size')
    def size_restrictions(cls, v):
        return cls._validate_dimension(v, 'size')

    @validator('width')
    def width_restrictions(cls, v):
        return cls._validate_dimension(v, 'width')

    @validator('height')
    def height_restrictions(cls, v):
        return cls._validate_dimension(v, 'height')

    @validator('density')
    def density_restrictions(cls, v):
        assert 0.0 <= v <= 1, "density must be in [0, 1]"
        return v

    @validator('agents_xy')
    def agents_xy_validation(cls, v, values):
        """Each agent position must be an [x, y] pair of ints."""
        if v is not None:
            if not isinstance(v, (list, tuple)):
                raise ValueError("agents_xy must be a list")
            for position in v:
                if not isinstance(position, (list, tuple)) or len(position) != 2:
                    raise ValueError("Position must be a list/tuple of length 2")
                if not all(isinstance(coord, int) for coord in position):
                    raise ValueError("Position coordinates must be integers")
        return v

    @validator('targets_xy')
    def targets_xy_validation(cls, v, values):
        """Accepts one [x, y] goal per agent, or (required for
        on_target='restart') a sequence of goals per agent."""
        if v is not None:
            if not v or not isinstance(v, (list, tuple)):
                raise ValueError("targets_xy must be a list")

            first_element = v[0]
            if not isinstance(first_element, (list, tuple)):
                raise ValueError("Invalid targets_xy format")

            if isinstance(first_element[0], (list, tuple)):
                # Goal-sequence format: a list of goal lists, one per agent.
                for agent_goals in v:
                    if not isinstance(agent_goals, (list, tuple)) or len(agent_goals) < 2:
                        raise ValueError("Each agent must have at least two goals in the sequence")
                    for position in agent_goals:
                        if not isinstance(position, (list, tuple)) or len(position) != 2:
                            raise ValueError("Position must be a list/tuple of length 2")
                        if not all(isinstance(coord, int) for coord in position):
                            raise ValueError("Position coordinates must be integers")
            else:
                on_target = values.get('on_target', 'finish')
                if on_target == 'restart':
                    raise ValueError("on_target='restart' requires goal sequences, not single goals. Use format: targets_xy: [[[x1,y1],[x2,y2]], [[x3,y3],[x4,y4]]]")
                for position in v:
                    if not isinstance(position, (list, tuple)) or len(position) != 2:
                        raise ValueError("Position must be a list/tuple of length 2")
                    if not all(isinstance(coord, int) for coord in position):
                        raise ValueError("Position coordinates must be integers")
        return v

    @staticmethod
    def check_positions(v, width, height):
        """Bounds-check a flat list of [x, y] positions against the map size.

        NOTE(review): x is checked against `height` and y against `width`
        (row/column convention), and the error message prints what looks like
        inclusive bounds while the check is exclusive — confirm intent before
        changing either.
        """
        for position in v:
            if not isinstance(position, (list, tuple)) or len(position) != 2:
                raise ValueError("Position must be a list/tuple of length 2")
            x, y = position
            if not isinstance(x, int) or not isinstance(y, int):
                raise ValueError("Position coordinates must be integers")
            if not (0 <= x < height and 0 <= y < width):
                raise IndexError(f"Position is out of bounds! {position} is not in [{0}, {height}] x [{0}, {width}]")


    @validator('num_agents', always=True)
    def num_agents_must_be_positive(cls, v, values):
        """Default num_agents to len(agents_xy) (or 1), then range-check."""
        if v is None:
            if values['agents_xy']:
                v = len(values['agents_xy'])
            else:
                v = 1
        assert 1 <= v <= 10000000, "num_agents must be in [1, 10000000]"
        return v

    @validator('obs_radius')
    def obs_radius_must_be_positive(cls, v):
        assert 1 <= v <= 128, "obs_radius must be in [1, 128]"
        return v

    @validator('map', always=True)
    def map_validation(cls, v, values):
        """Parse a string map into a list-of-rows grid, then derive positions,
        dimensions and density from it."""
        if v is None:
            return None
        if isinstance(v, str):
            v, agents_xy, targets_xy, possible_agents_xy, possible_targets_xy = cls.str_map_to_list(v, values['FREE'],
                                                                                                    values['OBSTACLE'])
            # Positions may come from the map or from explicit fields, not both.
            if agents_xy and targets_xy and values.get('agents_xy') is not None and values.get(
                    'targets_xy') is not None:
                raise KeyError("""Can't create task. Please provide agents_xy and targets_xy only once.
                Either with parameters or with a map.""")
            if (agents_xy or targets_xy) and (possible_agents_xy or possible_targets_xy):
                raise KeyError("""Can't create task. Mark either possible locations or precise ones.""")
            elif agents_xy and targets_xy:
                values['agents_xy'] = agents_xy
                values['targets_xy'] = targets_xy
                values['num_agents'] = len(agents_xy)
            elif (values.get('agents_xy') is None or values.get(
                    'targets_xy') is None) and possible_agents_xy and possible_targets_xy:
                values['possible_agents_xy'] = possible_agents_xy
                values['possible_targets_xy'] = possible_targets_xy

        # Dimensions and density are dictated by the parsed map.
        height = len(v)
        width = 0
        area = 0
        for line in v:
            width = max(width, len(line))
            area += len(line)

        values['size'] = max(width, height)
        values['width'] = width
        values['height'] = height
        values['density'] = sum([sum(line) for line in v]) / area

        return v

    @validator('possible_agents_xy')
    def possible_agents_xy_validation(cls, v):
        # Accepted as-is; populated by map_validation from map symbols.
        return v

    @validator('possible_targets_xy')
    def possible_targets_xy_validation(cls, v):
        # Accepted as-is; populated by map_validation from map symbols.
        return v

    @staticmethod
    def str_map_to_list(str_map, free, obstacle):
        """Convert a text map into (grid, agents_xy, targets_xy,
        possible_agents_xy, possible_targets_xy).

        Symbols: '.' free cell, '#' obstacle, '@' possible agent spawn,
        '$' possible target; a lowercase letter is an agent start and the
        matching uppercase letter is its target.
        """
        obstacles = []
        agents = {}
        targets = {}
        possible_agents_xy = []
        possible_targets_xy = []
        special_chars = {'@', '$', '!'}

        for row_idx, line in enumerate(str_map.split()):
            row = []
            for col_idx, char in enumerate(line):
                position = (row_idx, col_idx)

                if char == '.':
                    row.append(free)
                    possible_agents_xy.append(position)
                    possible_targets_xy.append(position)
                elif char == '#':
                    row.append(obstacle)
                elif char in special_chars:
                    row.append(free)
                    if char == '@':
                        possible_agents_xy.append(position)
                    elif char == '$':
                        possible_targets_xy.append(position)
                elif 'A' <= char <= 'Z':
                    targets[char.lower()] = position
                    row.append(free)
                    possible_agents_xy.append(position)
                    possible_targets_xy.append(position)
                elif 'a' <= char <= 'z':
                    agents[char.lower()] = position
                    row.append(free)
                    possible_agents_xy.append(position)
                    possible_targets_xy.append(position)
                else:
                    raise KeyError(f"Unsupported symbol '{char}' at line {row_idx}")

            if row:
                # All rows must be the same length; the conditional expression
                # skips the check for the first row.
                assert len(obstacles[-1]) == len(row) if obstacles else True, f"Wrong string size for row {row_idx};"
                obstacles.append(row)

        # Sorting by letter pairs agent k with its matching target k.
        agents_xy = [[x, y] for _, (x, y) in sorted(agents.items())]
        targets_xy = [[x, y] for _, (x, y) in sorted(targets.items())]

        assert len(targets_xy) == len(agents_xy), "Mismatch in number of agents and targets."

        # Without special markers every free cell was collected above; treat
        # that as "no restriction" by returning None for both lists.
        if not any(char in special_chars for char in str_map):
            possible_agents_xy, possible_targets_xy = None, None

        return obstacles, agents_xy, targets_xy, possible_agents_xy, possible_targets_xy

    def update_config(self, **kwargs):
        """Re-validate this config with `kwargs` applied on top of the current
        values, copying the validated fields back onto this instance.

        Dropping width/height (or size) lets the root validator re-derive them
        from the field(s) the caller actually overrides.
        """
        current_values = self.dict()

        if 'size' in kwargs:
            current_values.pop('width', None)
            current_values.pop('height', None)
        elif 'width' in kwargs or 'height' in kwargs:
            current_values.pop('size', None)
        current_values.update(kwargs)
        new_instance = GridConfig(**current_values)

        for field_name, field_value in new_instance.__dict__.items():
            setattr(self, field_name, field_value)
280 |
--------------------------------------------------------------------------------
/pogema/grid.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import warnings
3 |
4 | import numpy as np
5 |
6 | from pogema.generator import generate_obstacles, generate_positions_and_targets_fast, \
7 | get_components, generate_from_possible_positions
8 | from .grid_config import GridConfig
9 | from .grid_registry import in_registry, get_grid
10 | from .utils import render_grid
11 |
12 |
class Grid:
    """Occupancy-grid world for multi-agent pathfinding.

    Holds the obstacle matrix, per-agent start/finish cells, and an occupancy
    matrix of currently active agents. Coordinates are (row, col) and include
    the artificial border offset of ``obs_radius`` unless a method says
    otherwise (``ignore_borders=True``).
    """

    def __init__(self, grid_config: "GridConfig", add_artificial_border: bool = True, num_retries=10):
        """Create obstacles and place agents/targets according to the config.

        :param grid_config: grid settings (seed, map, num_agents, density, ...).
        :param add_artificial_border: pad the grid with an obs_radius-wide frame.
        :param num_retries: regeneration attempts for random grids whose
            agent/target placement failed.
        :raises IndexError: explicit agents_xy/targets_xy are inconsistent.
        :raises OverflowError: no valid placement could be produced.
        """
        self.config = grid_config
        self.rnd = np.random.default_rng(grid_config.seed)
        if self.config.map is None:
            self.obstacles = generate_obstacles(self.config)
        else:
            self.obstacles = np.array([np.array(line) for line in self.config.map])
        # A registered map name takes priority over generated/explicit maps.
        if in_registry(self.config.map_name):
            self.obstacles = get_grid(self.config.map_name).get_obstacles()
        self.obstacles = self.obstacles.astype(np.int32)

        if grid_config.targets_xy and grid_config.agents_xy:
            self.starts_xy = grid_config.agents_xy

            # targets_xy may hold a sequence of goals per agent; use the first.
            if isinstance(grid_config.targets_xy[0][0], (list, tuple)):
                self.finishes_xy = [sequence[0] for sequence in grid_config.targets_xy]
            else:
                self.finishes_xy = grid_config.targets_xy

            if len(self.starts_xy) != len(self.finishes_xy):
                raise IndexError("Can't create task. Please provide agents_xy and targets_xy of the same size.")
            if grid_config.num_agents > len(self.starts_xy):
                raise IndexError(f"Not enough agents_xy and targets_xy to place {grid_config.num_agents} agents")
            self.starts_xy = self.starts_xy[:grid_config.num_agents]
            self.finishes_xy = self.finishes_xy[:grid_config.num_agents]
            # Explicit positions win over the map: free any blocked start/finish.
            for start_xy, finish_xy in zip(self.starts_xy, self.finishes_xy):
                s_x, s_y = start_xy
                f_x, f_y = finish_xy
                if self.config.map is not None and self.obstacles[s_x, s_y] == grid_config.OBSTACLE:
                    warnings.warn(f"There is an obstacle on a start point ({s_x}, {s_y}), replacing with free cell",
                                  Warning, stacklevel=2)
                    self.obstacles[s_x, s_y] = grid_config.FREE
                if self.config.map is not None and self.obstacles[f_x, f_y] == grid_config.OBSTACLE:
                    warnings.warn(f"There is an obstacle on a finish point ({f_x}, {f_y}), replacing with free cell",
                                  Warning, stacklevel=2)
                    self.obstacles[f_x, f_y] = grid_config.FREE
        elif grid_config.possible_agents_xy and grid_config.possible_targets_xy:
            self.starts_xy, self.finishes_xy = generate_from_possible_positions(self.config)
        else:
            self.starts_xy, self.finishes_xy = generate_positions_and_targets_fast(self.obstacles, self.config)

        # Random placement can fail; retry (with fresh obstacles for random
        # maps) until it succeeds or the retry budget runs out.
        if len(self.starts_xy) != len(self.finishes_xy):
            for attempt in range(num_retries):
                if len(self.starts_xy) == len(self.finishes_xy):
                    warnings.warn(f'Created valid configuration only with {attempt} attempts.', Warning, stacklevel=2)
                    break
                if self.config.map is None:
                    self.obstacles = generate_obstacles(self.config)
                self.starts_xy, self.finishes_xy = generate_positions_and_targets_fast(self.obstacles, self.config)

        if not self.starts_xy or not self.finishes_xy or len(self.starts_xy) != len(self.finishes_xy):
            raise OverflowError(
                "Can't create task. Please check grid grid_config, especially density, num_agent and map.")

        if add_artificial_border:
            self.add_artificial_border()

        # Occupancy matrix: 1 where an (active) agent currently stands.
        filled_positions = np.zeros(self.obstacles.shape)
        for x, y in self.starts_xy:
            filled_positions[x, y] = 1

        self.positions = filled_positions
        self.positions_xy = self.starts_xy
        self._initial_xy = deepcopy(self.starts_xy)
        # NOTE: this dict shadows the is_active() method defined below on every
        # instance, so that method is effectively unreachable on instances.
        self.is_active = {agent_id: True for agent_id in range(self.config.num_agents)}

    def add_artificial_border(self):
        """Pad the grid with an obs_radius-wide frame and shift all positions.

        The innermost ring of the frame is always solid obstacle; the rest is
        either empty or random, depending on ``config.empty_outside``.
        """
        gc = self.config
        r = gc.obs_radius
        if gc.empty_outside:
            filled_obstacles = np.zeros(np.array(self.obstacles.shape) + r * 2)
        else:
            filled_obstacles = self.rnd.binomial(1, gc.density, np.array(self.obstacles.shape) + r * 2)

        height, width = filled_obstacles.shape
        # Solid ring immediately outside the original grid.
        filled_obstacles[r - 1, r - 1:width - r + 1] = gc.OBSTACLE
        filled_obstacles[r - 1:height - r + 1, r - 1] = gc.OBSTACLE
        filled_obstacles[height - r, r - 1:width - r + 1] = gc.OBSTACLE
        filled_obstacles[r - 1:height - r + 1, width - r] = gc.OBSTACLE
        filled_obstacles[r:height - r, r:width - r] = self.obstacles

        self.obstacles = filled_obstacles

        # All stored coordinates move by the border width.
        self.starts_xy = [(x + r, y + r) for x, y in self.starts_xy]
        self.finishes_xy = [(x + r, y + r) for x, y in self.finishes_xy]

    def get_obstacles(self, ignore_borders=False):
        """Return a copy of the obstacle matrix, optionally without the border."""
        gc = self.config
        if ignore_borders:
            return self.obstacles[gc.obs_radius:-gc.obs_radius, gc.obs_radius:-gc.obs_radius].copy()
        return self.obstacles.copy()

    @staticmethod
    def _cut_borders_xy(positions, obs_radius):
        """Shift coordinates back by the artificial border width."""
        return [[x - obs_radius, y - obs_radius] for x, y in positions]

    @staticmethod
    def _filter_inactive(pos, active_flags):
        """Keep only the positions whose matching flag is truthy."""
        return [pos for idx, pos in enumerate(pos) if active_flags[idx]]

    def get_grid_config(self):
        """Return a deep copy of the grid configuration."""
        return deepcopy(self.config)

    def _prepare_positions(self, positions, only_active, ignore_borders):
        """Apply the only_active / ignore_borders views to a position list."""
        gc = self.config

        if only_active:
            positions = self._filter_inactive(positions, [idx for idx, active in self.is_active.items() if active])

        if ignore_borders:
            positions = self._cut_borders_xy(positions, gc.obs_radius)

        return positions

    def get_agents_xy(self, only_active=False, ignore_borders=False):
        """Return current agent coordinates, optionally filtered/unpadded."""
        return self._prepare_positions(deepcopy(self.positions_xy), only_active, ignore_borders)

    @staticmethod
    def to_relative(coordinates, offset):
        """Return coordinates expressed relative to per-agent offsets."""
        result = deepcopy(coordinates)
        for idx, _ in enumerate(result):
            x, y = result[idx]
            dx, dy = offset[idx]
            result[idx] = x - dx, y - dy
        return result

    def get_agents_xy_relative(self):
        """Agent coordinates relative to each agent's initial position."""
        return self.to_relative(self.positions_xy, self._initial_xy)

    def get_targets_xy_relative(self):
        """Target coordinates relative to each agent's initial position."""
        return self.to_relative(self.finishes_xy, self._initial_xy)

    def get_targets_xy(self, only_active=False, ignore_borders=False):
        """Return target coordinates, optionally filtered/unpadded."""
        return self._prepare_positions(deepcopy(self.finishes_xy), only_active, ignore_borders)

    def _normalize_coordinates(self, coordinates):
        """Map border-padded (x, y) into [0, 1] relative to the grid size."""
        gc = self.config

        x, y = coordinates

        x -= gc.obs_radius
        y -= gc.obs_radius

        x /= gc.height - 1
        y /= gc.width - 1

        return x, y

    def get_state(self, ignore_borders=False, as_dict=False):
        """Return the full state: normalized agent/target coords + obstacles.

        Bug fix: ``ignore_borders`` used to be passed positionally into
        get_agents_xy/get_targets_xy, where it landed on the ``only_active``
        parameter instead of ``ignore_borders``.
        """
        agents_xy = list(map(self._normalize_coordinates, self.get_agents_xy(ignore_borders=ignore_borders)))
        targets_xy = list(map(self._normalize_coordinates, self.get_targets_xy(ignore_borders=ignore_borders)))

        obstacles = self.get_obstacles(ignore_borders)

        if as_dict:
            return {"obstacles": obstacles, "agents_xy": agents_xy, "targets_xy": targets_xy}

        return np.concatenate(list(map(lambda x: np.array(x).flatten(), [agents_xy, targets_xy, obstacles])))

    def get_observation_shape(self):
        """Shape of a per-agent observation: (channels, side, side)."""
        full_radius = self.config.obs_radius * 2 + 1
        return 2, full_radius, full_radius

    def get_num_actions(self):
        """Number of discrete actions (one per configured move)."""
        return len(self.config.MOVES)

    def get_obstacles_for_agent(self, agent_id):
        """Obstacle window of side 2*obs_radius+1 centered on the agent."""
        x, y = self.positions_xy[agent_id]
        r = self.config.obs_radius
        return self.obstacles[x - r:x + r + 1, y - r:y + r + 1].astype(np.float32)

    def get_positions(self, agent_id):
        """Occupancy window of side 2*obs_radius+1 centered on the agent."""
        x, y = self.positions_xy[agent_id]
        r = self.config.obs_radius
        return self.positions[x - r:x + r + 1, y - r:y + r + 1].astype(np.float32)

    def get_target(self, agent_id):
        """Unit direction vector from the agent to its target (0, 0 on goal)."""
        x, y = self.positions_xy[agent_id]
        fx, fy = self.finishes_xy[agent_id]
        if x == fx and y == fy:
            return 0.0, 0.0
        rx, ry = fx - x, fy - y
        dist = np.sqrt(rx ** 2 + ry ** 2)
        return rx / dist, ry / dist

    def get_square_target(self, agent_id):
        """One-hot target location inside the observation window.

        Targets outside the window are clamped to its edge.
        """
        c = self.config
        full_size = self.config.obs_radius * 2 + 1
        result = np.zeros((full_size, full_size))
        x, y = self.positions_xy[agent_id]
        fx, fy = self.finishes_xy[agent_id]
        dx, dy = x - fx, y - fy

        dx = min(dx, c.obs_radius) if dx >= 0 else max(dx, -c.obs_radius)
        dy = min(dy, c.obs_radius) if dy >= 0 else max(dy, -c.obs_radius)
        result[c.obs_radius - dx, c.obs_radius - dy] = 1
        return result.astype(np.float32)

    def render(self, mode='human'):
        """Render obstacles, agents and targets via pogema.utils.render_grid."""
        render_grid(self.obstacles, self.positions_xy, self.finishes_xy, self.is_active, mode=mode)

    def move_agent_to_cell(self, agent_id, x, y):
        """Teleport an agent to (x, y), validating both source and destination.

        :raises KeyError: the agent is not currently on the occupancy matrix.
        :raises ValueError: the destination cell is blocked or occupied.
        """
        if self.positions[self.positions_xy[agent_id]] == self.config.FREE:
            raise KeyError("Agent {} is not in the map".format(agent_id))
        self.positions[self.positions_xy[agent_id]] = self.config.FREE
        if self.obstacles[x, y] != self.config.FREE or self.positions[x, y] != self.config.FREE:
            raise ValueError(f"Can't force agent to blocked position {x} {y}")
        self.positions_xy[agent_id] = x, y
        self.positions[self.positions_xy[agent_id]] = self.config.OBSTACLE

    def has_obstacle(self, x, y):
        """True when cell (x, y) is an obstacle."""
        return self.obstacles[x, y] == self.config.OBSTACLE

    def move_without_checks(self, agent_id, action):
        """Apply a move unconditionally (caller guarantees the cell is free)."""
        x, y = self.positions_xy[agent_id]
        dx, dy = self.config.MOVES[action]
        self.positions[x, y] = self.config.FREE
        self.positions[x+dx, y+dy] = self.config.OBSTACLE
        self.positions_xy[agent_id] = (x+dx, y+dy)

    def move(self, agent_id, action):
        """Apply a move only when the destination is free of obstacles/agents."""
        x, y = self.positions_xy[agent_id]
        dx, dy = self.config.MOVES[action]
        if self.obstacles[x + dx, y + dy] == self.config.FREE:
            if self.positions[x + dx, y + dy] == self.config.FREE:
                self.positions[x, y] = self.config.FREE
                x += dx
                y += dy
                self.positions[x, y] = self.config.OBSTACLE
                self.positions_xy[agent_id] = (x, y)

    def on_goal(self, agent_id):
        """True when the agent stands on its target cell."""
        return self.positions_xy[agent_id] == self.finishes_xy[agent_id]

    def is_active(self, agent_id):
        # NOTE(review): unreachable on instances — __init__ rebinds
        # ``self.is_active`` to a dict, shadowing this method; calling it on an
        # instance raises TypeError. Kept for interface compatibility.
        return self.is_active[agent_id]

    def hide_agent(self, agent_id):
        """Deactivate an agent and free its cell; returns False if already hidden."""
        if not self.is_active[agent_id]:
            return False
        self.is_active[agent_id] = False

        self.positions[self.positions_xy[agent_id]] = self.config.FREE

        return True

    def show_agent(self, agent_id):
        """Re-activate a hidden agent at its stored cell.

        :raises KeyError: the stored cell is already occupied.
        """
        if self.is_active[agent_id]:
            return False

        self.is_active[agent_id] = True
        if self.positions[self.positions_xy[agent_id]] == self.config.OBSTACLE:
            raise KeyError("The cell is already occupied")
        self.positions[self.positions_xy[agent_id]] = self.config.OBSTACLE
        return True
275 |
276 |
class GridLifeLong(Grid):
    """Life-long MAPF grid: additionally computes connectivity components and
    warns when an agent's start and goal are in different components."""

    def __init__(self, grid_config: GridConfig, add_artificial_border: bool = True, num_retries=10):

        super().__init__(grid_config, add_artificial_border, num_retries)

        # Connectivity components of the free cells, mapped in both directions:
        # component id -> points, and point -> component id.
        self.component_to_points, self.point_to_component = get_components(grid_config, self.obstacles,
                                                                           self.positions_xy, self.finishes_xy)

        for i in range(len(self.positions_xy)):
            position, target = self.positions_xy[i], self.finishes_xy[i]
            if self.point_to_component[position] != self.point_to_component[target]:
                # NOTE(review): the message claims the goal is changed, but no
                # reassignment is visible here — presumably get_components()
                # already relocates unreachable goals; confirm in pogema.generator.
                warnings.warn(f"The start point ({position[0]}, {position[1]}) and the goal"
                              f" ({target[0]}, {target[1]}) are in different components. The goal is changed.",
                              Warning, stacklevel=2)
291 |
--------------------------------------------------------------------------------
/pogema/svg_animation/animation_drawer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import typing
3 | from dataclasses import dataclass
4 |
5 | from pogema import GridConfig
6 | from pogema.svg_animation.svg_objects import Line, RectangleHref, Animation, Circle, Rectangle
7 |
8 |
@dataclass
class AnimationConfig:
    """Settings that control how and where SVG animations are produced."""
    directory: str = 'renders/'  # output folder for rendered SVG files
    static: bool = False  # draw a single frame instead of animating
    show_agents: bool = True  # draw agent/target markers at all
    egocentric_idx: typing.Optional[int] = None  # highlight this agent's field of view
    uid: typing.Optional[str] = None  # custom identifier used in the file name
    save_every_idx_episode: typing.Optional[int] = 1  # save each N-th episode; None disables saving
    show_grid_lines: bool = True  # draw the background cell grid
18 |
19 |
@dataclass
class SvgSettings:
    """Geometry and palette constants for SVG rendering (pixel units)."""
    r: int = 35  # radius of agent/target circles
    stroke_width: int = 10
    scale_size: int = 100  # pixels per grid cell
    time_scale: float = 0.25  # seconds of animation per simulation step
    draw_start: int = 100  # offset of the first cell center from the origin
    rx: int = 15  # corner radius for rounded rectangles

    obstacle_color: str = '#84A1AE'
    ego_color: str = '#c1433c'  # color of the egocentric agent
    ego_other_color: str = '#6e81af'  # other agents in egocentric mode
    shaded_opacity: float = 0.2  # opacity outside the ego field of view
    egocentric_shaded: bool = True  # shade cells outside the ego view
    stroke_dasharray: int = 25  # dash pattern of the field-of-view frame

    # Per-agent fill colors, cycled by agent index.
    colors: tuple = (
        '#c1433c',
        '#2e6f9e',
        '#6e81af',
        '#00b9c8',
        '#72D5C8',
        '#0ea08c',
        '#8F7B66',
    )
45 |
46 |
@dataclass
class GridHolder:
    """Aggregated episode data handed to AnimationDrawer.create_animation."""
    obstacles: typing.Any = None  # 2D obstacle matrix
    episode_length: typing.Optional[int] = None
    height: typing.Optional[int] = None
    width: typing.Optional[int] = None
    colors: typing.Optional[dict] = None  # agent index -> fill color
    history: typing.Optional[list] = None  # per-agent list of recorded step states
    obs_radius: typing.Optional[int] = None
    grid_config: typing.Optional[GridConfig] = None
    on_target: typing.Optional[str] = None  # e.g. 'restart' or 'wait'
    config: typing.Optional[AnimationConfig] = None
    svg_settings: typing.Optional[SvgSettings] = None
60 |
61 |
62 | class Drawing:
63 |
    def __init__(self, height, width, svg_settings):
        """Collect SVG elements for a canvas of the given pixel size.

        :param height: canvas height in pixels.
        :param width: canvas width in pixels.
        :param svg_settings: SvgSettings with the drawing constants.
        """
        self.height = height
        self.width = width
        self.origin = (0, 0)  # anchor of the SVG view box
        self.elements = []  # SVG objects in draw (stacking) order
        self.svg_settings = svg_settings
70 |
    def add_element(self, element):
        """Append an SVG element; elements render in insertion order."""
        self.elements.append(element)
73 |
74 | def render(self):
75 | scale = max(self.height, self.width) / 512
76 | scaled_width = math.ceil(self.width / scale)
77 | scaled_height = math.ceil(self.height / scale)
78 |
79 | dx, dy = self.origin
80 | view_box = (dx, dy - self.height, self.width, self.height)
81 |
82 | svg_header = f'''
83 | ')
98 | return "\n".join(elements_svg)
99 |
100 |
class AnimationDrawer:
    """Builds an SVG Drawing (obstacles, agents, targets and their SMIL
    animations) from a GridHolder with a recorded episode history.

    Axis note: grid coordinates are (row, col); the drawer renders a transposed
    and vertically flipped view, which is why svg widths derive from gh.height
    and heights from gh.width throughout.
    """

    def create_animation(self, grid_holder: "GridHolder"):
        """Assemble the complete drawing for one episode."""
        gh = grid_holder
        # Transposed canvas (see class note), plus one cell of margin.
        render_width = gh.height * gh.svg_settings.scale_size + gh.svg_settings.scale_size
        render_height = gh.width * gh.svg_settings.scale_size + gh.svg_settings.scale_size
        drawing = Drawing(width=render_width, height=render_height, svg_settings=SvgSettings())
        obstacles = self.create_obstacles(gh)

        agents = []
        targets = []

        if gh.config.show_agents:
            agents = self.create_agents(gh)
            targets = self.create_targets(gh)

            if not gh.config.static:
                self.animate_agents(agents, gh)
                self.animate_targets(targets, gh)
        if gh.config.show_grid_lines:
            grid_lines = self.create_grid_lines(gh, render_width, render_height)
            for line in grid_lines:
                drawing.add_element(line)
        for obj in [*obstacles, *agents, *targets]:
            drawing.add_element(obj)

        if gh.config.egocentric_idx is not None:
            field_of_view = self.create_field_of_view(grid_holder=gh)
            if not gh.config.static:
                self.animate_obstacles(obstacles=obstacles, grid_holder=gh)
                self.animate_field_of_view(field_of_view, gh)
            drawing.add_element(field_of_view)

        return drawing

    @staticmethod
    def fix_point(x, y, length):
        """Map grid (row, col) to the drawer's flipped coordinate pair."""
        return length - y - 1, x

    @staticmethod
    def check_in_radius(x1, y1, x2, y2, r) -> bool:
        """True when (x1, y1) lies in the square of Chebyshev radius r around (x2, y2)."""
        return x2 - r <= x1 <= x2 + r and y2 - r <= y1 <= y2 + r

    @staticmethod
    def create_grid_lines(grid_holder: "GridHolder", render_width, render_height):
        """Create the background grid lines, one cell-boundary line per row/column."""
        gh = grid_holder
        offset = 0
        stroke_settings = {'class': 'line'}
        grid_lines = []
        for i in range(-1, grid_holder.height + 1):
            x = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2
            grid_lines.append(Line(x1=x, y1=offset, x2=x, y2=render_height - offset, **stroke_settings))

        for i in range(-1, grid_holder.width + 1):
            y = i * gh.svg_settings.scale_size + gh.svg_settings.scale_size / 2
            grid_lines.append(Line(x1=offset, y1=y, x2=render_width - offset, y2=y, **stroke_settings))

        return grid_lines

    @staticmethod
    def create_field_of_view(grid_holder):
        """Create the dashed rectangle framing the egocentric agent's view."""
        gh: GridHolder = grid_holder
        ego_idx = gh.config.egocentric_idx
        # Rectangle is centered on the ego agent's initial position.
        x, y = gh.history[ego_idx][0].get_xy()
        cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size
        cy = gh.svg_settings.draw_start + (gh.width - x - 1) * gh.svg_settings.scale_size

        dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2
        result = Rectangle(
            x=cx - dr + gh.svg_settings.r, y=cy - dr + gh.svg_settings.r,
            width=2 * dr - 2 * gh.svg_settings.r, height=2 * dr - 2 * gh.svg_settings.r,
            stroke=gh.svg_settings.ego_color, stroke_width=gh.svg_settings.stroke_width,
            fill='none', rx=gh.svg_settings.rx, stroke_dasharray=gh.svg_settings.stroke_dasharray
        )

        return result

    def animate_field_of_view(self, view, grid_holder):
        """Attach movement and visibility animations to the field-of-view frame."""
        gh: GridHolder = grid_holder
        x_path = []
        y_path = []
        ego_idx = grid_holder.config.egocentric_idx
        for state in gh.history[ego_idx]:
            x, y = state.get_xy()
            dr = (grid_holder.obs_radius + 1) * gh.svg_settings.scale_size - gh.svg_settings.stroke_width * 2
            cx = gh.svg_settings.draw_start + y * gh.svg_settings.scale_size
            cy = -gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size
            x_path.append(str(cx - dr + gh.svg_settings.r))
            y_path.append(str(cy - dr + gh.svg_settings.r))

        visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[ego_idx]]

        view.add_animation(self.compressed_anim('x', x_path, gh.svg_settings.time_scale))
        view.add_animation(self.compressed_anim('y', y_path, gh.svg_settings.time_scale))
        view.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale))

    def animate_agents(self, agents, grid_holder):
        """Attach per-step movement/visibility (and ego-shading) animations."""
        gh: GridHolder = grid_holder
        ego_idx = gh.config.egocentric_idx

        for agent_idx, agent in enumerate(agents):
            x_path = []
            y_path = []
            opacity = []
            for idx, agent_state in enumerate(gh.history[agent_idx]):
                x, y = agent_state.get_xy()

                x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size))
                y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size))

                # In egocentric mode, dim agents outside the ego's view radius.
                if ego_idx is not None:
                    ego_x, ego_y = gh.history[ego_idx][idx].get_xy()
                    if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
                        opacity.append('1.0')
                    else:
                        opacity.append(str(gh.svg_settings.shaded_opacity))

            visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]

            agent.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale))
            agent.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale))
            agent.add_animation(self.compressed_anim('visibility', visibility, gh.svg_settings.time_scale))
            if opacity:
                agent.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale))

    @classmethod
    def compressed_anim(cls, attr_name, tokens, time_scale, rep_cnt='indefinite'):
        """Build an SMIL <animate> with run-length-compressed values/keyTimes."""
        tokens, times = cls.compress_tokens(tokens)
        cumulative = [0, ]
        for t in times:
            cumulative.append(cumulative[-1] + t)
        # keyTimes must be fractions of the total duration in [0, 1].
        times = [str(round(value / cumulative[-1], 10)) for value in cumulative]
        # Duplicate the first token so values and keyTimes have equal length.
        tokens = [tokens[0]] + tokens
        return Animation(
            attributeName=attr_name, dur=f'{time_scale * (-1 + cumulative[-1])}s',
            values=";".join(tokens), repeatCount=rep_cnt, keyTimes=";".join(times)
        )

    @staticmethod
    def wisely_add(token, cnt, tokens, times):
        """Append a run of `cnt` equal tokens as at most two keyframes."""
        if cnt > 1:
            tokens += [token, token]
            times += [1, cnt - 1]
        else:
            tokens.append(token)
            times.append(cnt)

    @classmethod
    def compress_tokens(cls, input_tokens: list):
        """Run-length encode a token sequence into (tokens, durations)."""
        tokens = []
        times = []
        if input_tokens:
            cur_idx = 0
            cnt = 1
            for idx in range(1, len(input_tokens)):
                if input_tokens[idx] == input_tokens[cur_idx]:
                    cnt += 1
                else:
                    cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
                    cnt = 1
                    cur_idx = idx
            cls.wisely_add(input_tokens[cur_idx], cnt, tokens, times)
        return tokens, times

    def animate_targets(self, targets, grid_holder):
        """Attach movement/visibility animations to target markers."""
        gh: GridHolder = grid_holder
        ego_idx = gh.config.egocentric_idx

        for agent_idx, target in enumerate(targets):
            # In egocentric mode every drawn target follows the ego agent's goal.
            target_idx = ego_idx if ego_idx is not None else agent_idx

            x_path = []
            y_path = []

            for state in gh.history[target_idx]:
                x, y = state.get_target_xy()
                x_path.append(str(gh.svg_settings.draw_start + y * gh.svg_settings.scale_size))
                y_path.append(str(-gh.svg_settings.draw_start + -(gh.width - x - 1) * gh.svg_settings.scale_size))

            # NOTE(review): visibility tracks agent_idx while paths track
            # target_idx; looks intentional for egocentric mode — confirm.
            visibility = ['visible' if state.is_active() else 'hidden' for state in gh.history[agent_idx]]

            # Targets only move in life-long modes where goals are reassigned.
            if gh.on_target in ('restart', 'wait'):
                target.add_animation(self.compressed_anim('cy', y_path, gh.svg_settings.time_scale))
                target.add_animation(self.compressed_anim('cx', x_path, gh.svg_settings.time_scale))
            target.add_animation(self.compressed_anim("visibility", visibility, gh.svg_settings.time_scale))

    def create_obstacles(self, grid_holder):
        """Create one rectangle element per obstacle cell (row-major order)."""
        gh = grid_holder
        result = []

        for i in range(gh.height):
            for j in range(gh.width):
                x, y = self.fix_point(i, j, gh.width)

                if gh.obstacles[x][y]:
                    obs_settings = {}
                    # NOTE(review): only height is passed — RectangleHref
                    # presumably supplies a default width; confirm in svg_objects.
                    obs_settings.update(
                        x=gh.svg_settings.draw_start + i * gh.svg_settings.scale_size - gh.svg_settings.r,
                        y=gh.svg_settings.draw_start + j * gh.svg_settings.scale_size - gh.svg_settings.r,
                        height=gh.svg_settings.r * 2,
                    )

                    # Dim obstacles the ego agent cannot initially see.
                    if gh.config.egocentric_idx is not None and gh.svg_settings.egocentric_shaded:
                        initial_positions = [agent_states[0].get_xy() for agent_states in gh.history]
                        ego_x, ego_y = initial_positions[gh.config.egocentric_idx]
                        if not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
                            obs_settings.update(opacity=gh.svg_settings.shaded_opacity)

                    result.append(RectangleHref(**obs_settings))

        return result

    def animate_obstacles(self, obstacles, grid_holder):
        """Animate obstacle opacity: a cell stays visible once the ego saw it."""
        gh: GridHolder = grid_holder
        obstacle_idx = 0

        # Iteration order must match create_obstacles so indices line up.
        for i in range(gh.height):
            for j in range(gh.width):
                x, y = self.fix_point(i, j, gh.width)
                if not gh.obstacles[x][y]:
                    continue
                opacity = []
                seen = set()
                for agent_state in gh.history[gh.config.egocentric_idx]:
                    ego_x, ego_y = agent_state.get_xy()
                    if self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius):
                        seen.add((x, y))
                    if (x, y) in seen:
                        opacity.append(str(1.0))
                    else:
                        opacity.append(str(gh.svg_settings.shaded_opacity))

                obstacle = obstacles[obstacle_idx]
                obstacle.add_animation(self.compressed_anim('opacity', opacity, gh.svg_settings.time_scale))

                obstacle_idx += 1

    def create_agents(self, grid_holder):
        """Create one circle per initially-active agent, colored per index."""
        # NOTE(review): filtering inactive agents here shifts list indices
        # relative to colors/ego_idx when early agents start inactive — confirm.
        initial_positions = [state[0].get_xy() for state in grid_holder.history if state[0].is_active()]
        agents = []
        gh: GridHolder = grid_holder
        ego_idx = grid_holder.config.egocentric_idx

        for idx, (x, y) in enumerate(initial_positions):
            circle_settings = {
                'cx': gh.svg_settings.draw_start + y * gh.svg_settings.scale_size,
                'cy': gh.svg_settings.draw_start + (grid_holder.width - x - 1) * gh.svg_settings.scale_size,
                'r': gh.svg_settings.r, 'fill': grid_holder.colors[idx], 'class': 'agent',
            }

            # Egocentric recoloring: ego gets its own color, the rest share one.
            if ego_idx is not None:
                ego_x, ego_y = initial_positions[ego_idx]
                is_out_of_radius = not self.check_in_radius(x, y, ego_x, ego_y, grid_holder.obs_radius)
                circle_settings['fill'] = gh.svg_settings.ego_other_color
                if idx == ego_idx:
                    circle_settings['fill'] = gh.svg_settings.ego_color
                elif is_out_of_radius and gh.svg_settings.egocentric_shaded:
                    circle_settings['opacity'] = gh.svg_settings.shaded_opacity

            agents.append(Circle(**circle_settings))

        return agents

    @staticmethod
    def create_targets(grid_holder):
        """Create target rings for agents that are active at some point."""
        gh: GridHolder = grid_holder
        targets = []
        for agent_idx, agent_states in enumerate(gh.history):

            tx, ty = agent_states[0].get_target_xy()
            x, y = ty, gh.width - tx - 1

            if not any([agent_state.is_active() for agent_state in gh.history[agent_idx]]):
                continue

            circle_settings = {"class": 'target'}
            circle_settings.update(
                cx=gh.svg_settings.draw_start + x * gh.svg_settings.scale_size, r=gh.svg_settings.r,
                cy=gh.svg_settings.draw_start + y * gh.svg_settings.scale_size, stroke=gh.colors[agent_idx],
            )

            # Egocentric mode draws only the ego agent's target.
            if gh.config.egocentric_idx is not None:
                if gh.config.egocentric_idx != agent_idx:
                    continue

                circle_settings.update(stroke=gh.svg_settings.ego_color)
            target = Circle(**circle_settings)
            targets.append(target)
        return targets
396 |
--------------------------------------------------------------------------------
/tests/test_grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pydantic import ValidationError
3 |
4 | from pogema import GridConfig
5 | from pogema.grid import Grid
6 | import pytest
7 |
8 | from pogema.integrations.make_pogema import pogema_v0
9 |
10 |
def test_obstacle_creation():
    """Obstacle layouts are reproducible for fixed seeds."""
    cases = [
        (GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.2),
         [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
          [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]),
        (GridConfig(seed=3, obs_radius=1, size=4, num_agents=1, density=0.4),
         [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
          [1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
          [1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
          [1.0, 1.0, 0.0, 0.0, 0.0, 1.0],
          [1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
          [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]),
    ]
    for config, expected in cases:
        assert np.isclose(Grid(config).obstacles, expected).all()
32 |
33 |
def test_initial_positions():
    """Agent spawn cells are deterministic for a fixed seed."""
    cases = [
        (GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.2),
         [(2, 4)]),
        (GridConfig(seed=1, obs_radius=2, size=12, num_agents=10, density=0.2),
         [(13, 10), (7, 4), (4, 3), (2, 11), (12, 6), (8, 11), (6, 8), (2, 12), (2, 10), (9, 11)]),
    ]
    for config, expected in cases:
        assert np.isclose(Grid(config).positions_xy, expected).all()
42 |
43 |
def test_goals():
    """Target cells are deterministic for a fixed seed."""
    cases = [
        (GridConfig(seed=1, obs_radius=2, size=5, num_agents=1, density=0.4),
         [(5, 2)]),
        (GridConfig(seed=2, obs_radius=2, size=12, num_agents=10, density=0.2),
         [(11, 10), (8, 11), (2, 13), (3, 5), (12, 6), (9, 12), (9, 6), (9, 2), (10, 2), (6, 11)]),
    ]
    for config, expected in cases:
        assert np.isclose(Grid(config).finishes_xy, expected).all()
52 |
53 |
def test_overflow():
    """Impossible placements must raise OverflowError."""
    impossible = [
        GridConfig(seed=1, obs_radius=2, size=4, num_agents=100, density=0.0),  # too many agents
        GridConfig(seed=1, obs_radius=2, size=4, num_agents=1, density=1.0),  # fully blocked grid
    ]
    for config in impossible:
        with pytest.raises(OverflowError):
            Grid(config)
60 |
61 |
def test_overflow_warning():
    """Hard random configurations eventually warn about placement retries."""
    config_kwargs = dict(obs_radius=2, size=4, num_agents=6, density=0.3)
    with pytest.warns(Warning):
        for _ in range(1000):
            Grid(GridConfig(**config_kwargs), num_retries=10000)
66 |
67 |
def test_edge_cases():
    """Invalid configs fail pydantic validation; impossible grids overflow."""
    invalid_kwargs = [
        dict(seed=1, obs_radius=2, size=1, num_agents=1, density=0.4),  # grid too small
        dict(seed=1, obs_radius=2, size=4, num_agents=0, density=0.4),  # no agents
        dict(seed=1, obs_radius=2, size=4, num_agents=1, density=2.0),  # density out of range
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(ValidationError):
            GridConfig(**kwargs)

    with pytest.raises(OverflowError):
        Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=1, density=1.0))  # fully blocked
80 |
81 |
def test_edge_cases_for_custom_map():
    """A 1x3 free map cannot host two or more agent/target pairs."""
    tiny_map = [[0, 0, 0]]
    for seed, num_agents in ((1, 2), (2, 4)):
        with pytest.raises(OverflowError):
            Grid(GridConfig(seed=seed, obs_radius=2, size=4, num_agents=num_agents, map=tiny_map))
88 |
89 |
def test_custom_map():
    """List maps are framed by the obs_radius-wide obstacle border."""
    cases = [
        (
            [
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
            ],
            [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ),
        (
            [
                [0, 1, 0],
                [0, 1, 0],
                [0, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
            ],
            [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ),
        (
            [
                [0, 0, 1, 0, 0],
                [1, 0, 0, 0, 0],
                [0, 1, 0, 0, 1],
            ],
            [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
             [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
             [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        ),
    ]
    for custom_map, expected in cases:
        grid = Grid(GridConfig(seed=1, obs_radius=2, size=4, num_agents=2, map=custom_map))
        assert np.isclose(grid.obstacles, expected).all()
139 |
140 |
def test_overflow_for_custom_map():
    """Requesting more agents than the custom map can host raises OverflowError."""
    cramped_map = [
        [0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    with pytest.raises(OverflowError):
        Grid(GridConfig(obs_radius=2, size=4, num_agents=5, density=0.3, map=cramped_map), num_retries=100)
149 |
150 |
def test_str_custom_map():
    """String maps determine num_agents, density and size of the resulting config."""
    map_str = """
    .a...#.....
    .....#.....
    ..C.....b..
    .....#.....
    .....#.....
    #.####.....
    .....###.##
    .....#.....
    .c...#.....
    .B.......A.
    .....#.....
    """
    grid = Grid(GridConfig(obs_radius=2, size=4, density=0.3, map=map_str))
    assert grid.config.num_agents == 3
    assert np.isclose(0.1404958, grid.config.density)
    assert np.isclose(11, grid.config.size)

    # Single-row map without agent markers: num_agents comes from the config.
    map_str = """.....#...."""
    grid = Grid(GridConfig(seed=2, num_agents=3, map=map_str))
    assert grid.config.num_agents == 3
    assert np.isclose(0.1, grid.config.density)
    assert np.isclose(10, grid.config.size)
175 |
176 |
def test_custom_starts_and_finishes_random():
    """Explicit starts/targets survive env creation modulo the obs_radius offset."""
    agents_xy = [(x, x) for x in range(8)]
    targets_xy = [(x, x) for x in range(8, 16)]
    grid_config = GridConfig(seed=12, size=16, num_agents=8, agents_xy=agents_xy, targets_xy=targets_xy)
    env = pogema_v0(grid_config=grid_config)
    env.reset()
    offset = grid_config.obs_radius
    # Grid coordinates are padded by obs_radius; removing the padding restores the inputs.
    assert [(x - offset, y - offset) for x, y in env.grid.positions_xy] == agents_xy
    assert [(x - offset, y - offset) for x, y in env.grid.finishes_xy] == targets_xy
186 |
187 |
def test_out_of_bounds_for_custom_positions():
    """Border cells are valid; coordinates outside [0, size) raise IndexError."""
    Grid(GridConfig(seed=12, size=17, agents_xy=[[0, 16]], targets_xy=[[16, 0]]))

    invalid_cases = [
        ([[0, 17]], [[0, 0]]),
        ([[0, 0]], [[0, 17]]),
        ([[-1, 0]], [[0, 0]]),
        ([[0, 0]], [[0, -1]]),
    ]
    for agents_xy, targets_xy in invalid_cases:
        with pytest.raises(IndexError):
            GridConfig(seed=12, size=17, agents_xy=agents_xy, targets_xy=targets_xy)
199 |
200 |
def test_duplicated_params():
    """Explicit agent/target coordinates plus a map that already encodes them is rejected."""
    with pytest.raises(KeyError):
        GridConfig(agents_xy=[[0, 0]], targets_xy=[[0, 0]], map="Aa")
205 |
206 |
def test_custom_grid_with_empty_agents_and_targets():
    """A plain map with no encoded positions works when agents/targets are left unset."""
    Grid(GridConfig(agents_xy=None, targets_xy=None, map="....", num_agents=1))
210 |
211 |
def test_custom_grid_with_specific_positions():
    """Maps with restricted start (@) and target ($) cells enforce agent capacity."""
    wide_map = """
    !!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!
    !@@!@@!##########!
    !@@!@@!$$$$$$$$$$!
    !!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!
    !@@!@@!##########!
    !@@!@@!$$$$$$$$$$!
    !!!!!!!!!!!!!!!!!!
    """
    narrow_map = """
    !!!!!!!!!!!
    !@@!@@!$$$$
    !@@!@@!####
    !@@!@@!$$$$
    !!!!!!!!!!!
    !@@!@@!$$$$
    !@@!@@!####
    !@@!@@!$$$$
    !!!!!!!!!!!
    """
    # (map, capacity, extra config kwargs): capacity agents fit, capacity + 1 overflows.
    for encoded_map, capacity, extra_kwargs in (
            (wide_map, 24, {'size': 4}),
            (narrow_map, 16, {}),
    ):
        Grid(GridConfig(obs_radius=2, num_agents=capacity, map=encoded_map, **extra_kwargs))
        with pytest.raises(OverflowError):
            Grid(GridConfig(obs_radius=2, num_agents=capacity + 1, map=encoded_map, **extra_kwargs))

    # Mixing restricted cells with letter-encoded agents/targets is rejected.
    mixed_map = """
    !!!!!!!!!!!
    !@@!@@!.Ab.
    !@@!@@!####
    !@@!@@!.aB.

    """
    with pytest.raises(KeyError):
        Grid(GridConfig(obs_radius=2, map=mixed_map))
252 |
253 |
def test_restricted_grid():
    """Restricted layout holds exactly 24 agents; the 25th makes reset overflow."""
    layout = """
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
    !@@!@@!##########!##########!##########!@@!@@!
    !@@!@@!$$$$$$$$$$!$$$$$$$$$$!$$$$$$$$$$!@@!@@!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    """
    env = pogema_v0(grid_config=GridConfig(map=layout, num_agents=24, seed=0, obs_radius=2))
    env.reset()

    with pytest.raises(OverflowError):
        env = pogema_v0(grid_config=GridConfig(map=layout, num_agents=25, seed=0, obs_radius=2))
        env.reset()
268 |
269 |
def test_rectangular_grid_basic():
    """Explicit width/height are kept and size mirrors the width."""
    cfg = GridConfig(width=12, height=8)
    assert np.isclose(cfg.width, 12)
    assert np.isclose(cfg.height, 8)
    assert np.isclose(cfg.size, 12)
275 |
276 |
def test_rectangular_grid_backward_compatibility():
    """A square config defined via the legacy `size` exposes matching width/height."""
    cfg = GridConfig(size=10)
    for dimension in (cfg.size, cfg.width, cfg.height):
        assert np.isclose(dimension, 10)
282 |
283 |
def test_rectangular_grid_mixed_config():
    """Explicit width/height take precedence over `size`; size follows the width."""
    cfg = GridConfig(size=8, width=12, height=6)
    assert np.isclose(cfg.width, 12)
    assert np.isclose(cfg.height, 6)
    assert np.isclose(cfg.size, 12)
289 |
290 |
def test_rectangular_grid_validation():
    """Width and height must be given together; `size` alone is also valid."""
    for partial_kwargs in ({'width': 12}, {'height': 8}):
        with pytest.raises(ValueError):
            GridConfig(**partial_kwargs)

    # Valid combinations must not raise.
    GridConfig(width=12, height=8)
    GridConfig(size=10)
    GridConfig(size=10, width=12, height=8)
301 |
302 |
def test_rectangular_grid_position_validation():
    """Coordinates are validated against height (first axis) and width (second axis)."""
    cfg = GridConfig(width=12, height=8, agents_xy=[[0, 11], [7, 0]], targets_xy=[[7, 11], [0, 0]])
    assert len(cfg.agents_xy) == 2
    assert len(cfg.targets_xy) == 2

    # [8, 0] exceeds height=8; [0, 12] exceeds width=12.
    for bad_agents in ([[8, 0]], [[0, 12]]):
        with pytest.raises(IndexError):
            GridConfig(width=12, height=8, agents_xy=bad_agents, targets_xy=[[0, 0]])
313 |
314 |
def test_rectangular_grid_creation():
    """Grid construction preserves the rectangular dimensions from the config."""
    rect_grid = Grid(GridConfig(width=12, height=8, seed=1, num_agents=2))

    assert np.isclose(rect_grid.config.width, 12)
    assert np.isclose(rect_grid.config.height, 8)
    assert np.isclose(rect_grid.config.size, 12)
322 |
323 |
def test_goal_sequences_validation():
    """targets_xy accepts flat positions or per-agent goal sequences, never a mix."""
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]],
        ],
    )
    assert np.isclose(len(cfg.targets_xy), 2)
    assert np.isclose(len(cfg.targets_xy[0]), 3)
    assert np.isclose(len(cfg.targets_xy[1]), 2)

    # Flat (single-goal) format is still accepted.
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[[7, 7], [6, 6]],
    )
    assert np.isclose(len(cfg.targets_xy), 2)

    failing_cases = [
        # Mixed flat position and goal sequence.
        (ValueError, dict(agents_xy=[[0, 0], [1, 1]], targets_xy=[[[2, 2], [3, 3]], [4, 4]])),
        # Single-element sequence.
        (ValueError, dict(agents_xy=[[0, 0]], targets_xy=[[[2, 2]]])),
        # Non-integer coordinate inside a sequence.
        (ValueError, dict(agents_xy=[[0, 0]], targets_xy=[[[2.5, 2], [3, 3]]])),
        # Coordinate outside the 8x8 grid.
        (IndexError, dict(agents_xy=[[0, 0]], targets_xy=[[[2, 2], [10, 10]]])),
    ]
    for expected_error, kwargs in failing_cases:
        with pytest.raises(expected_error):
            GridConfig(width=8, height=8, **kwargs)

    with pytest.raises(ValueError, match="on_target='restart' requires goal sequences"):
        GridConfig(
            width=8, height=8,
            agents_xy=[[0, 0], [1, 1]],
            targets_xy=[[2, 2], [3, 3]],
            on_target='restart',
        )
379 |
380 |
def test_grid_with_goal_sequences():
    """With goal sequences, the grid assigns each agent its first target (padded)."""
    cfg = GridConfig(
        width=8, height=8,
        agents_xy=[[0, 0], [1, 1]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]],
        ],
    )
    grid = Grid(cfg)

    offset = cfg.obs_radius
    first_goals = [(2 + offset, 2 + offset), (2 + offset, 4 + offset)]
    assert np.isclose(grid.finishes_xy, first_goals).all()
398 |
399 |
def test_pogema_lifelong_with_sequences():
    """Lifelong env walks through the provided goal sequences and warns on wrap-around.

    Fixes: `== True` comparison replaced with a plain truthiness assert (PEP 8 /
    flake8 E712), and unused `obs`/`target*` locals removed.
    """
    from pogema.envs import PogemaLifeLong
    import warnings

    config = GridConfig(
        width=8, height=8,
        agents_xy=[[1, 1], [1, 2]],
        targets_xy=[
            [[2, 2], [3, 3], [4, 4]],
            [[2, 4], [3, 5]]
        ],
        on_target='restart'
    )

    env = PogemaLifeLong(grid_config=config)
    env.reset()

    assert env.has_custom_sequences
    assert np.isclose(env.current_goal_indices, [0, 0]).all()

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        # Walk agent 0 through its 3-goal sequence; the cursor wraps back to 0.
        for expected_index in (1, 2, 0):
            env._generate_new_target(0)
            assert np.isclose(env.current_goal_indices[0], expected_index)

        cycling_warnings = [warning for warning in w if "completed all 3 provided targets" in str(warning.message)]
        assert np.isclose(len(cycling_warnings), 1)

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        env._generate_new_target(1)
        env._generate_new_target(1)

        assert np.isclose(len(w), 1)
        assert "completed all 2 provided targets" in str(w[0].message)
        assert "cycling back to the beginning" in str(w[0].message)
444 |
445 |
def test_pogema_lifelong_reset():
    """reset() rewinds every agent's goal-sequence cursor back to zero."""
    from pogema.envs import PogemaLifeLong

    env = PogemaLifeLong(grid_config=GridConfig(
        width=8, height=8,
        agents_xy=[[1, 1], [1, 2]],
        targets_xy=[
            [[2, 2], [3, 3]],
            [[2, 4], [3, 5]]
        ],
        on_target='restart'
    ))
    env.reset()

    for agent_idx in (0, 1):
        env._generate_new_target(agent_idx)
    assert np.isclose(env.current_goal_indices, [1, 1]).all()

    env.reset()
    assert np.isclose(env.current_goal_indices, [0, 0]).all()
468 |
469 |
def test_pogema_lifelong_without_sequences():
    """Without explicit sequences, lifelong targets are generated on the fly.

    Fixes: `== False` comparison replaced with `not ...` (flake8 E712), the
    unused `obs` local removed, and an int length compared with `==` instead
    of float-tolerance `np.isclose`.
    """
    from pogema.envs import PogemaLifeLong

    config = GridConfig(
        width=8, height=8,
        num_agents=2,
        on_target='restart'
    )

    env = PogemaLifeLong(grid_config=config)
    env.reset()

    assert not env.has_custom_sequences

    # Generated targets are (x, y) tuples.
    target = env._generate_new_target(0)
    assert isinstance(target, tuple)
    assert len(target) == 2
487 |
488 |
def test_goal_sequences_position_format():
    """Malformed positions inside a goal sequence raise descriptive ValueErrors."""
    bad_sequences = [
        ("Position must be a list/tuple of length 2", [[[2, 2, 3], [4, 4]]]),
        ("Position coordinates must be integers", [[[2.5, 2], [4, 4]]]),
    ]
    for expected_message, targets_xy in bad_sequences:
        with pytest.raises(ValueError, match=expected_message):
            GridConfig(width=8, height=8, agents_xy=[[0, 0]], targets_xy=targets_xy)
503 |
--------------------------------------------------------------------------------
/pogema/envs.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import warnings
3 | import numpy as np
4 | import gymnasium
5 | from gymnasium.error import ResetNeeded
6 |
7 | from pogema.grid import Grid, GridLifeLong
8 | from pogema.grid_config import GridConfig
9 | from pogema.wrappers.metrics import LifeLongAverageThroughputMetric, NonDisappearEpLengthMetric, \
10 | NonDisappearCSRMetric, NonDisappearISRMetric, EpLengthMetric, ISRMetric, CSRMetric, SumOfCostsAndMakespanMetric
11 | from pogema.wrappers.multi_time_limit import MultiTimeLimit
12 | from pogema.generator import generate_new_target, generate_from_possible_targets
13 | from pogema.wrappers.persistence import PersistentWrapper
14 |
15 |
class ActionsSampler:
    """
    Draws uniformly random discrete actions for a group of agents from a
    seeded numpy generator, so sampling is reproducible for a given seed.
    """

    def __init__(self, num_actions, seed=42):
        # Size of the discrete action space each agent samples from.
        self._num_actions = num_actions
        self._rng = None
        self.update_seed(seed)

    def update_seed(self, seed=None):
        """Re-create the generator; the same seed reproduces the same action stream."""
        self._rng = np.random.default_rng(seed)

    def sample_actions(self, dim=1):
        """Return `dim` actions, each an integer in [0, num_actions)."""
        return self._rng.integers(self._num_actions, size=dim)
31 |
32 |
class PogemaBase(gymnasium.Env):
    """
    Abstract class of the Pogema environment.

    Stores the grid configuration, builds the discrete action space from the
    configured moves, and owns a seeded random-action sampler. Subclasses
    implement `step`/`reset` and create `self.grid` during reset.
    """
    metadata = {"render_modes": ["ansi"], }

    def step(self, action):
        """Advance the environment by one step (implemented by subclasses)."""
        raise NotImplementedError

    def reset(self, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None, ):
        """Reset the environment (implemented by subclasses)."""
        raise NotImplementedError

    def __init__(self, grid_config: GridConfig = GridConfig()):
        # NOTE(review): the default GridConfig() is evaluated once at class-definition
        # time and shared by every caller relying on the default — confirm it is
        # never mutated in place.
        # noinspection PyTypeChecker
        self.grid: Grid = None  # created lazily by subclasses (in reset/_initialize_grid)
        self.grid_config = grid_config

        # One discrete action per move direction defined by the config.
        self.action_space: gymnasium.spaces.Discrete = gymnasium.spaces.Discrete(len(self.grid_config.MOVES))
        self._multi_action_sampler = ActionsSampler(self.action_space.n, seed=self.grid_config.seed)

    def _get_agents_obs(self, agent_id=0):
        """
        Returns the observation of the agent with the given id: the per-agent
        obstacle layer, agent-positions layer and square-target layer stacked
        along a new leading axis.
        :param agent_id: index of the agent to observe for.
        :return: stacked array with the three per-agent layers.
        """
        return np.concatenate([
            self.grid.get_obstacles_for_agent(agent_id)[None],
            self.grid.get_positions(agent_id)[None],
            self.grid.get_square_target(agent_id)[None]
        ])

    def check_reset(self):
        """
        Checks if the reset needed: raises ResetNeeded when the grid has not
        been created yet (reset() was never called).
        :return:
        """
        if self.grid is None:
            raise ResetNeeded("Please reset environment first!")

    def render(self, mode='human'):
        """
        Renders the environment using ascii graphics.
        :param mode:
        :return:
        """
        self.check_reset()
        return self.grid.render(mode=mode)

    def sample_actions(self):
        """
        Samples the random actions for the given number of agents.
        :return: array with one random action per agent.
        """
        return self._multi_action_sampler.sample_actions(dim=self.grid_config.num_agents)

    def get_num_agents(self):
        """
        Returns the number of agents in the environment.
        :return:
        """
        return self.grid_config.num_agents
95 |
96 |
class Pogema(PogemaBase):
    """
    Single-goal Pogema environment: an active agent that reaches its goal
    receives a reward of 1.0 and is then hidden and deactivated.
    """

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        self.was_on_goal = None
        # The local observation window is a square with side 2 * obs_radius + 1.
        full_size = self.grid_config.obs_radius * 2 + 1
        if self.grid_config.observation_type == 'default':
            self.observation_space = gymnasium.spaces.Box(-1.0, 1.0, shape=(3, full_size, full_size))
        elif self.grid_config.observation_type == 'POMAPF':
            self.observation_space: gymnasium.spaces.Dict = gymnasium.spaces.Dict(
                obstacles=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                agents=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                target_xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
            )
        elif self.grid_config.observation_type == 'MAPF':
            self.observation_space: gymnasium.spaces.Dict = gymnasium.spaces.Dict(
                obstacles=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                agents=gymnasium.spaces.Box(0.0, 1.0, shape=(full_size, full_size)),
                xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                target_xy=gymnasium.spaces.Box(low=-1024, high=1024, shape=(2,), dtype=int),
                # global_obstacles=None, # todo define shapes of global state variables
                # global_xy=None,
                # global_target_xy=None,
            )
        else:
            # Bug fix: self.grid is still None here (it is only built in reset()),
            # so the previous `self.grid.config.observation_type` access raised
            # AttributeError instead of the intended ValueError.
            raise ValueError(f"Unknown observation type: {self.grid_config.observation_type}")

    def step(self, action: list):
        """
        Moves all agents, rewards active agents standing on their goal, then
        hides and deactivates every agent that is on its goal.
        :param action: one action per agent.
        :return: (observations, rewards, terminated, truncated, infos)
        """
        assert len(action) == self.grid_config.num_agents
        rewards = []

        terminated = []

        self.move_agents(action)
        self.update_was_on_goal()

        for agent_idx in range(self.grid_config.num_agents):

            on_goal = self.grid.on_goal(agent_idx)
            if on_goal and self.grid.is_active[agent_idx]:
                rewards.append(1.0)
            else:
                rewards.append(0.0)
            terminated.append(on_goal)

        # Agents that reached their goal disappear from the grid.
        for agent_idx in range(self.grid_config.num_agents):
            if self.grid.on_goal(agent_idx):
                self.grid.hide_agent(agent_idx)
                self.grid.is_active[agent_idx] = False

        infos = self._get_infos()

        observations = self._obs()
        truncated = [False] * self.grid_config.num_agents
        return observations, rewards, terminated, truncated, infos

    def _initialize_grid(self):
        """Creates a fresh grid from the stored configuration."""
        self.grid: Grid = Grid(grid_config=self.grid_config)

    def update_was_on_goal(self):
        """Caches, per agent, whether it stands on its goal while still active."""
        self.was_on_goal = [self.grid.on_goal(agent_idx) and self.grid.is_active[agent_idx]
                            for agent_idx in range(self.grid_config.num_agents)]

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None, ):
        """
        Re-creates the grid and returns the initial observations (and infos,
        unless return_info is False).

        NOTE(review): `seed` is stored on the grid only *after* it has been
        generated, so it does not influence the generated layout — confirm
        whether reseeding before _initialize_grid() is the intended behavior.
        """
        self._initialize_grid()
        self.update_was_on_goal()

        if seed is not None:
            self.grid.seed = seed

        if return_info:
            return self._obs(), self._get_infos()
        return self._obs()

    def _obs(self):
        """Builds per-agent observations in the configured observation format."""
        if self.grid_config.observation_type == 'default':
            return [self._get_agents_obs(index) for index in range(self.grid_config.num_agents)]
        elif self.grid_config.observation_type == 'POMAPF':
            return self._pomapf_obs()

        elif self.grid_config.observation_type == 'MAPF':
            # MAPF observations extend POMAPF ones with global state.
            results = self._pomapf_obs()
            global_obstacles = self.grid.get_obstacles()
            global_agents_xy = self.grid.get_agents_xy()
            global_targets_xy = self.grid.get_targets_xy()

            for agent_idx in range(self.grid_config.num_agents):
                result = results[agent_idx]
                result.update(global_obstacles=global_obstacles)
                result['global_xy'] = global_agents_xy[agent_idx]
                result['global_target_xy'] = global_targets_xy[agent_idx]

            return results
        else:
            # Consistent with __init__: report the configured observation type.
            raise ValueError(f"Unknown observation type: {self.grid_config.observation_type}")

    def _pomapf_obs(self):
        """Per-agent dict observations: local obstacle/agent layers plus relative coordinates."""
        results = []
        agents_xy_relative = self.grid.get_agents_xy_relative()
        targets_xy_relative = self.grid.get_targets_xy_relative()

        for agent_idx in range(self.grid_config.num_agents):
            result = {'obstacles': self.grid.get_obstacles_for_agent(agent_idx),
                      'agents': self.grid.get_positions(agent_idx),
                      'xy': agents_xy_relative[agent_idx],
                      'target_xy': targets_xy_relative[agent_idx]}

            results.append(result)
        return results

    def _get_infos(self):
        """Per-agent info dicts; currently only the is_active flag."""
        infos = [dict() for _ in range(self.grid_config.num_agents)]
        for agent_idx in range(self.grid_config.num_agents):
            infos[agent_idx]['is_active'] = self.grid.is_active[agent_idx]
        return infos

    def _revert_action(self, agent_idx, used_cells, cell, actions):
        """
        Cancels the move of agent_idx (action 0 = stay) and releases its claim
        on `cell`. If another agent had claimed the cell this agent falls back
        to, that agent's move is recursively cancelled as well. Used by the
        'soft' collision system.
        """
        actions[agent_idx] = 0
        used_cells[cell].remove(agent_idx)
        new_cell = self.grid.positions_xy[agent_idx]
        if new_cell in used_cells and len(used_cells[new_cell]) > 0:
            used_cells[new_cell].append(agent_idx)
            return self._revert_action(used_cells[new_cell][0], used_cells, new_cell, actions)
        else:
            used_cells.setdefault(new_cell, []).append(agent_idx)
            return actions, used_cells

    def move_agents(self, actions):
        """
        Applies the given actions to all active agents under the configured
        collision system: 'priority', 'block_both' or 'soft'.
        """
        if self.grid.config.collision_system == 'priority':
            # Lower-indexed agents move first and implicitly win conflicts.
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    self.grid.move(agent_idx, actions[agent_idx])
        elif self.grid.config.collision_system == 'block_both':
            # Any destination claimed twice — or currently occupied — blocks the movers.
            used_cells = {}
            agents_xy = self.grid.get_agents_xy()
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid_config.MOVES[actions[agent_idx]]
                    used_cells[x + dx, y + dy] = 'blocked' if (x + dx, y + dy) in used_cells else 'visited'
                    used_cells[x, y] = 'blocked'
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    x, y = agents_xy[agent_idx]
                    dx, dy = self.grid_config.MOVES[actions[agent_idx]]
                    if used_cells.get((x + dx, y + dy), None) != 'blocked':
                        self.grid.move(agent_idx, actions[agent_idx])
        elif self.grid.config.collision_system == 'soft':
            # Pass 1: record claimed cells and traversed edges. An edge used in
            # both directions means two agents try to swap cells.
            used_cells = dict()
            used_edges = dict()
            agents_xy = self.grid.get_agents_xy()
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    used_cells.setdefault((x + dx, y + dy), []).append(agent_idx)
                    used_edges[x, y, x + dx, y + dy] = [agent_idx]
                    if dx != 0 or dy != 0:
                        used_edges.setdefault((x + dx, y + dy, x, y), []).append(agent_idx)
            # Pass 2: cancel swap (edge) conflicts — both agents stay in place.
            for agent_idx, (x, y) in enumerate(agents_xy):
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    if len(used_edges[x, y, x + dx, y + dy]) > 1:
                        used_cells[x + dx, y + dy].remove(agent_idx)
                        used_cells.setdefault((x, y), []).append(agent_idx)
                        actions[agent_idx] = 0
            # Pass 3: resolve vertex conflicts and obstacle hits by reverting moves.
            for agent_idx in reversed(range(len(agents_xy))):
                x, y = agents_xy[agent_idx]
                if self.grid.is_active[agent_idx]:
                    dx, dy = self.grid.config.MOVES[actions[agent_idx]]
                    if len(used_cells[x + dx, y + dy]) > 1 or self.grid.has_obstacle(x + dx, y + dy):
                        actions, used_cells = self._revert_action(agent_idx, used_cells, (x + dx, y + dy), actions)
            # Pass 4: all remaining moves are conflict-free; apply them directly.
            for agent_idx in range(self.grid_config.num_agents):
                if self.grid.is_active[agent_idx]:
                    self.grid.move_without_checks(agent_idx, actions[agent_idx])
        else:
            raise ValueError('Unknown collision system: {}'.format(self.grid.config.collision_system))

    def get_agents_xy_relative(self):
        """Delegates to the grid: per-agent relative coordinates."""
        return self.grid.get_agents_xy_relative()

    def get_targets_xy_relative(self):
        """Delegates to the grid: per-agent relative target coordinates."""
        return self.grid.get_targets_xy_relative()

    def get_obstacles(self, ignore_borders=False):
        """Delegates to the grid: global obstacle matrix."""
        return self.grid.get_obstacles(ignore_borders=ignore_borders)

    def get_agents_xy(self, only_active=False, ignore_borders=False):
        """Delegates to the grid: global agent coordinates."""
        return self.grid.get_agents_xy(only_active=only_active, ignore_borders=ignore_borders)

    def get_targets_xy(self, only_active=False, ignore_borders=False):
        """Delegates to the grid: global target coordinates."""
        return self.grid.get_targets_xy(only_active=only_active, ignore_borders=ignore_borders)

    def get_state(self, ignore_borders=False, as_dict=False):
        """Delegates to the grid: full grid state."""
        return self.grid.get_state(ignore_borders=ignore_borders, as_dict=as_dict)
290 |
291 |
class PogemaLifeLong(Pogema):
    """
    Lifelong variant: whenever an agent reaches its goal it is immediately
    assigned a new one — taken from a user-provided per-agent sequence or
    generated on the fly — and the episode never terminates on goal completion.
    """

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        # Per-agent cursor into the custom goal sequences (only meaningful
        # when has_custom_sequences is True).
        self.current_goal_indices = [0] * grid_config.num_agents
        self.has_custom_sequences = grid_config.targets_xy is not None

    def _initialize_grid(self):
        self.grid: GridLifeLong = GridLifeLong(grid_config=self.grid_config)

        # Independent per-agent RNG streams derived from the main seed, so
        # target generation is reproducible per agent.
        main_rng = np.random.default_rng(self.grid_config.seed)
        seeds = main_rng.integers(np.iinfo(np.int32).max, size=self.grid_config.num_agents)
        self.random_generators = [np.random.default_rng(seed) for seed in seeds]

    def get_lifelong_targets_xy(self, ignore_borders=False):
        """
        Returns, for every agent, the sequence of targets it will visit.

        With custom sequences this is the configured list (shifted by
        obs_radius unless ignore_borders). Otherwise the sequences are
        pre-generated with RNGs seeded exactly like _initialize_grid(), until
        the accumulated Manhattan distance covers max_episode_steps.
        :param ignore_borders: if True, return coordinates without the border offset.
        """
        if self.has_custom_sequences:
            if ignore_borders:
                return self.grid_config.targets_xy
            else:
                return [[[x + self.grid_config.obs_radius, y + self.grid_config.obs_radius] for x, y in sequence]
                        for sequence in self.grid_config.targets_xy]

        sequences = []

        # Mirror the RNG construction of _initialize_grid() so the predicted
        # sequences match what the environment will actually generate.
        main_rng = np.random.default_rng(self.grid_config.seed)
        seeds = main_rng.integers(np.iinfo(np.int32).max, size=self.grid_config.num_agents)
        temp_generators = [np.random.default_rng(seed) for seed in seeds]

        for agent_idx in range(self.grid_config.num_agents):
            agent_sequence = []
            start_pos = self.get_agents_xy(ignore_borders=ignore_borders)[agent_idx]
            initial_target = self.get_targets_xy(ignore_borders=ignore_borders)[agent_idx]
            agent_sequence.append(initial_target)
            current_pos = initial_target
            # Running Manhattan-distance total; generation stops once it
            # reaches the episode length.
            total_distance = abs(start_pos[0] - initial_target[0]) + abs(start_pos[1] - initial_target[1])

            while total_distance < self.grid_config.max_episode_steps:
                # The generators expect border-padded coordinates.
                if ignore_borders:
                    generator_pos = (current_pos[0] + self.grid_config.obs_radius,
                                     current_pos[1] + self.grid_config.obs_radius)
                else:
                    generator_pos = tuple(current_pos)

                if self.grid_config.possible_targets_xy is not None:
                    new_goal = generate_from_possible_targets(
                        temp_generators[agent_idx],
                        self.grid_config.possible_targets_xy,
                        generator_pos
                    )
                    # Offset is added here unless map-space output was requested.
                    if ignore_borders:
                        goal_coords = list(new_goal)
                    else:
                        goal_coords = [new_goal[0] + self.grid_config.obs_radius,
                                       new_goal[1] + self.grid_config.obs_radius]
                else:
                    new_goal = generate_new_target(
                        temp_generators[agent_idx],
                        self.grid.point_to_component,
                        self.grid.component_to_points,
                        generator_pos
                    )
                    # Offset is stripped here when map-space output was requested.
                    if ignore_borders:
                        goal_coords = [new_goal[0] - self.grid_config.obs_radius,
                                       new_goal[1] - self.grid_config.obs_radius]
                    else:
                        goal_coords = list(new_goal)

                agent_sequence.append(goal_coords)
                total_distance += abs(current_pos[0] - goal_coords[0]) + abs(current_pos[1] - goal_coords[1])
                current_pos = goal_coords
            sequences.append(agent_sequence)
        return sequences

    def reset(self, seed: Optional[int] = None, return_info: bool = True, options: Optional[dict] = None):
        """Rewinds all goal-sequence cursors, then performs the regular reset."""
        self.current_goal_indices = [0] * self.grid_config.num_agents
        return super().reset(seed=seed, return_info=return_info, options=options)

    def _generate_new_target(self, agent_idx):
        """
        Returns the next target for the given agent (border-padded coordinates)
        and advances its sequence cursor. Emits a UserWarning each time a
        custom sequence wraps around to its beginning.
        """
        if self.has_custom_sequences:
            agent_targets = self.grid_config.targets_xy[agent_idx]
            current_idx = self.current_goal_indices[agent_idx]
            next_target = agent_targets[(current_idx + 1) % len(agent_targets)]

            self.current_goal_indices[agent_idx] = (current_idx + 1) % len(agent_targets)

            # Cursor wrapped to 0 right after the last target: sequence exhausted.
            if self.current_goal_indices[agent_idx] == 0 and current_idx == len(agent_targets) - 1:
                warnings.warn(
                    f"Agent {agent_idx} has completed all {len(agent_targets)} provided targets and "
                    f"is cycling back to the beginning. Provide more targets for the "
                    f"{self.grid_config.max_episode_steps} episode length.",
                    UserWarning,
                    stacklevel=2
                )

            return (next_target[0] + self.grid_config.obs_radius,
                    next_target[1] + self.grid_config.obs_radius)
        elif self.grid_config.possible_targets_xy is not None:
            new_goal = generate_from_possible_targets(self.random_generators[agent_idx],
                                                      self.grid_config.possible_targets_xy,
                                                      self.grid.positions_xy[agent_idx])
            return (new_goal[0] + self.grid_config.obs_radius, new_goal[1] + self.grid_config.obs_radius)
        else:
            return generate_new_target(self.random_generators[agent_idx],
                                       self.grid.point_to_component,
                                       self.grid.component_to_points,
                                       self.grid.positions_xy[agent_idx])

    def step(self, action: list):
        """
        Moves all agents; an active agent standing on its goal earns 1.0 and is
        immediately given a new target. `terminated` is always all-False here.
        :param action: one action per agent.
        :return: (observations, rewards, terminated, truncated, infos)
        """
        assert len(action) == self.grid_config.num_agents
        rewards = []

        infos = [dict() for _ in range(self.grid_config.num_agents)]

        self.move_agents(action)
        self.update_was_on_goal()

        for agent_idx in range(self.grid_config.num_agents):
            on_goal = self.grid.on_goal(agent_idx)
            if on_goal and self.grid.is_active[agent_idx]:
                rewards.append(1.0)
            else:
                rewards.append(0.0)

            if self.grid.on_goal(agent_idx):
                self.grid.finishes_xy[agent_idx] = self._generate_new_target(agent_idx)

        for agent_idx in range(self.grid_config.num_agents):
            infos[agent_idx]['is_active'] = self.grid.is_active[agent_idx]

        obs = self._obs()

        terminated = [False] * self.grid_config.num_agents
        truncated = [False] * self.grid_config.num_agents
        return obs, rewards, terminated, truncated, infos
425 |
426 |
class PogemaCoopFinish(Pogema):
    """
    Cooperative variant: every agent is rewarded and the episode terminates
    only when all agents stand on their goals simultaneously.
    """

    def __init__(self, grid_config=GridConfig(num_agents=2)):
        super().__init__(grid_config)
        self.num_agents = self.grid_config.num_agents
        self.is_multiagent = True

    def _initialize_grid(self):
        self.grid: Grid = Grid(grid_config=self.grid_config)

    def step(self, action: list):
        """
        Moves all agents and checks whether the joint task is solved.
        :param action: one action per agent.
        :return: (observations, rewards, terminated, truncated, infos)
        """
        assert len(action) == self.grid_config.num_agents
        num_agents = self.grid_config.num_agents

        self.move_agents(action)
        self.update_was_on_goal()

        # The task is solved only when every agent is on its goal at once.
        solved = all(self.was_on_goal)
        infos = [{'is_active': self.grid.is_active[agent_idx]} for agent_idx in range(num_agents)]

        obs = self._obs()

        shared_reward = 1.0 if solved else 0.0
        return obs, [shared_reward] * num_agents, [solved] * num_agents, [False] * num_agents, infos
454 |
455 |
def _make_pogema(grid_config):
    """
    Builds the environment matching `grid_config.on_target`, applies the step
    limit, and wraps it with either persistence or the metric wrappers.
    :raises KeyError: for an unknown on_target option.
    """
    env_classes = {
        'restart': PogemaLifeLong,
        'nothing': PogemaCoopFinish,
        'finish': Pogema,
    }
    if grid_config.on_target not in env_classes:
        raise KeyError(f'Unknown on_target option: {grid_config.on_target}')
    env = env_classes[grid_config.on_target](grid_config=grid_config)

    env = MultiTimeLimit(env, grid_config.max_episode_steps)
    if env.grid_config.persistent:
        env = PersistentWrapper(env)
    else:
        # adding metrics wrappers — the set depends on the termination mode
        if grid_config.on_target == 'restart':
            env = LifeLongAverageThroughputMetric(env)
        elif grid_config.on_target == 'nothing':
            env = NonDisappearISRMetric(env)
            env = NonDisappearCSRMetric(env)
            env = NonDisappearEpLengthMetric(env)
            env = SumOfCostsAndMakespanMetric(env)
        else:  # 'finish' — the only remaining option after the check above
            env = ISRMetric(env)
            env = CSRMetric(env)
            env = EpLengthMetric(env)

    return env
486 |
--------------------------------------------------------------------------------