├── test
├── __init__.py
├── test_vector
│ ├── __init__.py
│ ├── test_pettingzoo_to_vec_gymnasium_wrappers.py
│ ├── test_render.py
│ ├── test_env_is_wrapped.py
│ ├── test_aec_vector_identity_env.py
│ ├── test_aec_vector_values.py
│ ├── test_vector_dict.py
│ ├── test_pettingzoo_to_vec.py
│ └── test_gym_vector.py
├── test_utils
│ ├── test_basic_transforms
│ │ ├── __init__.py
│ │ ├── test_flatten.py
│ │ ├── test_dtype.py
│ │ ├── test_color_reduction.py
│ │ ├── test_reshape.py
│ │ ├── test_normalized_obs.py
│ │ └── test_resize.py
│ ├── test_action_transforms
│ │ └── test_homogenize.py
│ ├── test_frame_stack.py
│ └── test_agent_indicator.py
├── test_autodep.py
├── pettingzoo_env_test.py
├── dummy_gym_env.py
├── vec_env_test.py
├── dummy_aec_env.py
├── parallel_env_test.py
├── gym_unwrapped_test.py
├── gym_mock_test.py
├── aec_unwrapped_test.py
├── generated_agents_test.py
└── pettingzoo_api_test.py
├── supersuit
├── utils
│ ├── __init__.py
│ ├── action_transforms
│ │ ├── __init__.py
│ │ └── homogenize_ops.py
│ ├── convert_box.py
│ ├── make_defaultdict.py
│ ├── basic_transforms
│ │ ├── dtype.py
│ │ ├── __init__.py
│ │ ├── flatten.py
│ │ ├── reshape.py
│ │ ├── resize.py
│ │ ├── color_reduction.py
│ │ └── normalize_obs.py
│ ├── frame_skip.py
│ ├── accumulator.py
│ ├── obs_delay.py
│ ├── base_aec_wrapper.py
│ ├── wrapper_chooser.py
│ ├── frame_stack.py
│ └── agent_indicator.py
├── vector
│ ├── utils
│ │ ├── __init__.py
│ │ ├── space_wrapper.py
│ │ └── shared_array.py
│ ├── __init__.py
│ ├── sb_vector_wrapper.py
│ ├── constructors.py
│ ├── single_vec_env.py
│ ├── sb3_vector_wrapper.py
│ ├── vector_constructors.py
│ ├── concat_vec_env.py
│ ├── markov_vector_wrapper.py
│ └── multiproc_vec.py
├── generic_wrappers
│ ├── utils
│ │ ├── __init__.py
│ │ ├── base_modifier.py
│ │ └── shared_wrapper_util.py
│ ├── delay_observations.py
│ ├── __init__.py
│ ├── max_observation.py
│ ├── sticky_actions.py
│ ├── basic_wrappers.py
│ ├── nan_wrappers.py
│ ├── frame_stack.py
│ └── frame_skip.py
├── aec_vector
│ ├── __init__.py
│ ├── create.py
│ ├── base_aec_vec_env.py
│ └── vector_env.py
├── lambda_wrappers
│ ├── __init__.py
│ ├── reward_lambda.py
│ ├── action_lambda.py
│ └── observation_lambda.py
├── multiagent_wrappers
│ ├── __init__.py
│ ├── agent_indication.py
│ ├── padding_wrappers.py
│ └── black_death.py
└── __init__.py
├── .github
├── FUNDING.yml
└── workflows
│ ├── pre-commit.yml
│ ├── linux-test.yml
│ └── build-publish.yml
├── .gitignore
├── supersuit-text.png
├── setup.py
├── .pre-commit-config.yaml
├── README.md
├── pyproject.toml
└── LICENSE
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/supersuit/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/test_vector/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/supersuit/vector/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/supersuit/utils/action_transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: Farama-Foundation
2 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | pettingzoo
3 | dist/*
4 | build/*
5 | SuperSuit.egg-info/*
6 | venv/
7 |
--------------------------------------------------------------------------------
/supersuit-text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Farama-Foundation/SuperSuit/HEAD/supersuit-text.png
--------------------------------------------------------------------------------
/supersuit/aec_vector/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_aec_vec_env import VectorAECEnv # NOQA
2 | from .create import vectorize_aec_env_v0 # NOQA
3 |
--------------------------------------------------------------------------------
/supersuit/lambda_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .action_lambda import action_lambda_v1 # NOQA
2 | from .observation_lambda import observation_lambda_v0 # NOQA
3 | from .reward_lambda import reward_lambda_v0 # NOQA
4 |
--------------------------------------------------------------------------------
/supersuit/multiagent_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent_indication import agent_indicator_v0 # NOQA
2 | from .black_death import black_death_v3 # NOQA
3 | from .padding_wrappers import pad_action_space_v0, pad_observations_v0 # NOQA
4 |
--------------------------------------------------------------------------------
/supersuit/utils/convert_box.py:
--------------------------------------------------------------------------------
1 | from gymnasium.spaces import Box
2 |
3 |
def convert_box(convert_obs_fn, old_box):
    """Return a new ``Box`` whose bounds are ``old_box``'s bounds passed through ``convert_obs_fn``.

    The dtype of the new space is taken from the converted lower bound.
    """
    low, high = (convert_obs_fn(bound) for bound in (old_box.low, old_box.high))
    return Box(low=low, high=high, dtype=low.dtype)
8 |
--------------------------------------------------------------------------------
/supersuit/vector/__init__.py:
--------------------------------------------------------------------------------
1 | from .concat_vec_env import ConcatVecEnv # NOQA
2 | from .constructors import MakeCPUAsyncConstructor # NOQA
3 | from .markov_vector_wrapper import MarkovVectorEnv # NOQA
4 | from .multiproc_vec import ProcConcatVec # NOQA
5 | from .single_vec_env import SingleVecEnv # NOQA
6 |
--------------------------------------------------------------------------------
/test/test_autodep.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import supersuit
4 |
5 |
def test_bad_import():
    """Both ways of reaching ``action_lambda_v0`` must raise ``DeprecatedWrapper``."""
    # Path 1: ``from supersuit import <name>``.
    with pytest.raises(supersuit.DeprecatedWrapper):
        from supersuit import action_lambda_v0  # noqa: F401
    # Path 2: plain attribute access on the package.
    with pytest.raises(supersuit.DeprecatedWrapper):
        supersuit.action_lambda_v0
11 |
--------------------------------------------------------------------------------
/supersuit/utils/make_defaultdict.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import defaultdict
3 |
4 |
def make_defaultdict(d):
    """Copy ``d`` into a ``defaultdict`` whose factory is the type of ``d``'s first value.

    If ``d`` has no values, a warning is emitted and a plain empty dict is returned.
    """
    try:
        sample = next(iter(d.values()))
    except StopIteration:
        warnings.warn("No agents left in the environment!")
        return {}
    result = defaultdict(type(sample))
    result.update(d.items())
    return result
14 |
--------------------------------------------------------------------------------
/test/pettingzoo_env_test.py:
--------------------------------------------------------------------------------
1 | from pettingzoo.mpe import simple_spread_v3
2 | from pettingzoo.test import parallel_api_test
3 |
4 | from supersuit.multiagent_wrappers import pad_action_space_v0
5 |
6 |
def test_pad_actuon_space():
    """Run the PettingZoo parallel API test on simple_spread wrapped in pad_action_space_v0."""
    base_env = simple_spread_v3.parallel_env(max_cycles=25, continuous_actions=True)
    padded_env = pad_action_space_v0(base_env)

    parallel_api_test(padded_env, num_cycles=100)
12 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/dtype.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import convert_box
4 |
5 |
def check_param(obs_space, new_dtype):
    """Validate ``new_dtype`` by attempting to build a numpy dtype from it.

    ``obs_space`` is not used. Any exception raised by ``np.dtype`` propagates to the caller.
    """
    np.dtype(new_dtype)  # type argument must be convertible to a numpy dtype
8 |
9 |
def change_obs_space(obs_space, param):
    """Build the transformed space by applying ``change_observation`` to the bounds of ``obs_space``."""

    def transform(bound):
        return change_observation(bound, obs_space, param)

    return convert_box(transform, obs_space)
12 |
13 |
def change_observation(obs, obs_space, new_dtype):
    """Return ``obs`` cast to ``new_dtype``; ``obs_space`` is not used."""
    return obs.astype(new_dtype)
17 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from gymnasium.spaces import Box
2 |
3 |
def convert_box(convert_obs_fn, old_box):
    """Return a new ``Box`` whose bounds are ``old_box``'s bounds passed through ``convert_obs_fn``.

    The dtype of the new space is taken from the converted lower bound.
    """
    low, high = (convert_obs_fn(bound) for bound in (old_box.low, old_box.high))
    return Box(low=low, high=high, dtype=low.dtype)
8 |
9 |
10 | from . import color_reduction # NOQA
11 | from . import dtype # NOQA
12 | from . import flatten # NOQA
13 | from . import normalize_obs # NOQA
14 | from . import reshape # NOQA
15 | from . import resize # NOQA
16 |
--------------------------------------------------------------------------------
/test/dummy_gym_env.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 |
3 |
class DummyEnv(gymnasium.Env):
    """Minimal gymnasium environment that always hands back one fixed observation."""

    def __init__(self, observation, observation_space, action_space):
        super().__init__()
        self._observation = observation
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self, seed=None, options=None):
        """Return the stored observation and an empty info dict; the arguments are ignored."""
        info = {}
        return self._observation, info

    def step(self, action):
        """Return the stored observation with reward 1, not terminated, not truncated, empty info."""
        reward, terminated, truncated = 1, False, False
        return self._observation, reward, terminated, truncated, {}
16 |
--------------------------------------------------------------------------------
/supersuit/vector/utils/space_wrapper.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import numpy as np
3 |
4 |
class SpaceWrapper:
    """Record the ``shape`` and numpy ``dtype`` used to store values of a space.

    Only ``Discrete`` and ``Box`` spaces are supported.

    Raises:
        AssertionError: if ``space`` is neither ``Discrete`` nor ``Box``.
    """

    def __init__(self, space):
        if isinstance(space, gymnasium.spaces.Discrete):
            # Discrete values are treated as scalar int64.
            self.shape = ()
            self.dtype = np.dtype(np.int64)
        elif isinstance(space, gymnasium.spaces.Box):
            self.shape = space.shape
            self.dtype = np.dtype(space.dtype)
        else:
            # Raised explicitly rather than via ``assert False`` so the check is
            # not stripped under ``python -O`` (which would leave shape/dtype unset).
            raise AssertionError("ProcVectorEnv only support Box and Discrete types")
15 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/flatten.py:
--------------------------------------------------------------------------------
1 | from . import convert_box
2 |
3 |
def check_param(obs_space, should_flatten):
    """Validate the flatten parameter.

    ``obs_space`` is not used.

    Raises:
        AssertionError: if ``should_flatten`` is not a ``bool``.
    """
    # Explicit raise (same exception type as before) so validation survives ``python -O``.
    if not isinstance(should_flatten, bool):
        raise AssertionError(f"should_flatten must be bool. It is {should_flatten}")
8 |
9 |
def change_obs_space(obs_space, param):
    """Build the transformed space by applying ``change_observation`` to the bounds of ``obs_space``."""

    def transform(bound):
        return change_observation(bound, obs_space, param)

    return convert_box(transform, obs_space)
12 |
13 |
def change_observation(obs, obs_space, should_flatten):
    """Return ``obs.flatten()`` when ``should_flatten`` is truthy, otherwise ``obs`` itself."""
    return obs.flatten() if should_flatten else obs
18 |
--------------------------------------------------------------------------------
/supersuit/utils/frame_skip.py:
--------------------------------------------------------------------------------
def check_transform_frameskip(frame_skip):
    """Validate ``frame_skip`` and normalise it to a ``(low, high)`` tuple.

    Accepted forms:
      * a tuple of two ints with ``1 <= low <= high`` - returned unchanged;
      * a single int ``n >= 1`` - returned as ``(n, n)``.

    Raises:
        AssertionError: for any other value, including an int smaller than one.
    """
    message = (
        "frame_skip must be an int or a tuple of two ints, where the first values "
        "is at least one and not greater than the second"
    )
    if isinstance(frame_skip, tuple):
        if (
            len(frame_skip) == 2
            and isinstance(frame_skip[0], int)
            and isinstance(frame_skip[1], int)
            and 1 <= frame_skip[0] <= frame_skip[1]
        ):
            return frame_skip
    elif isinstance(frame_skip, int) and frame_skip >= 1:
        # A bare int must obey the same lower bound as the first tuple element;
        # previously 0 and negative ints slipped through as (n, n).
        return (frame_skip, frame_skip)
    # Explicit raise so the check is not removed under ``python -O``.
    raise AssertionError(message)
16 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/delay_observations.py:
--------------------------------------------------------------------------------
1 | from supersuit.utils.obs_delay import Delayer
2 |
3 | from .utils.base_modifier import BaseModifier
4 | from .utils.shared_wrapper_util import shared_wrapper
5 |
6 |
def delay_observations_v0(env, delay):
    """Wrap ``env`` so every observation is routed through a ``Delayer`` built with ``delay``."""

    class DelayObsModifier(BaseModifier):
        def reset(self, seed=None, options=None):
            # A new Delayer is created on each reset, dropping anything queued before.
            self.delayer = Delayer(self.observation_space, delay)

        def modify_obs(self, obs):
            delayed = self.delayer.add(obs)
            return super().modify_obs(delayed)

    return shared_wrapper(env, DelayObsModifier)
17 |
--------------------------------------------------------------------------------
/test/vec_env_test.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import numpy as np
3 |
4 | from supersuit import gym_vec_env_v0
5 |
6 |
def test_vec_env_args():
    """After one step with different actions, envs 0 and 1 must share no observation component."""
    n_copies = 8
    vec_env = gym_vec_env_v0(gymnasium.make("Acrobot-v1"), n_copies)
    vec_env.reset()
    # The first env takes action 0; every other env takes action 1.
    actions = [0] + [1] * (vec_env.num_envs - 1)
    obs, _rew, _terminations, _truncations, _infos = vec_env.step(actions)
    assert not np.any(np.equal(obs[0], obs[1]))
16 |
17 |
def test_all_vec_env_fns():
    """Construct the vector env with the third positional argument set to False, then True."""
    env = gymnasium.make("Acrobot-v1")
    for third_arg in (False, True):
        gym_vec_env_v0(env, 8, third_arg)
23 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/utils/base_modifier.py:
--------------------------------------------------------------------------------
class BaseModifier:
    """Pass-through modifier: every hook returns its argument unchanged.

    The hooks also remember what they were given: ``modify_obs`` stores the
    observation in ``cur_obs``, and the two ``*_space`` hooks store the spaces
    in ``observation_space`` / ``action_space``.
    """

    def __init__(self):
        """Nothing to initialise."""

    def reset(self, seed=None, options=None):
        """No-op; the arguments are ignored."""
        return None

    def modify_obs(self, obs):
        """Remember ``obs`` as the latest observation and return it."""
        self.cur_obs = obs
        return self.cur_obs

    def get_last_obs(self):
        """Return the observation stored by the most recent ``modify_obs`` call."""
        return self.cur_obs

    def modify_obs_space(self, obs_space):
        """Remember ``obs_space`` and return it."""
        self.observation_space = obs_space
        return self.observation_space

    def modify_action(self, act):
        """Return ``act`` unchanged."""
        return act

    def modify_action_space(self, act_space):
        """Remember ``act_space`` and return it."""
        self.action_space = act_space
        return self.action_space
25 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_wrappers import clip_reward_v0 # NOQA
2 | from .basic_wrappers import (
3 | clip_actions_v0,
4 | color_reduction_v0,
5 | dtype_v0,
6 | flatten_v0,
7 | normalize_obs_v0,
8 | reshape_v0,
9 | resize_v1,
10 | scale_actions_v0,
11 | )
12 | from .delay_observations import delay_observations_v0 # NOQA
13 | from .frame_skip import frame_skip_v0 # NOQA
14 | from .frame_stack import frame_stack_v1, frame_stack_v2 # NOQA
15 | from .max_observation import max_observation_v0 # NOQA
16 | from .nan_wrappers import nan_noop_v0, nan_random_v0, nan_zeros_v0 # NOQA
17 | from .sticky_actions import sticky_actions_v0 # NOQA
18 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
3 | name: pre-commit
4 | on:
5 | pull_request:
6 | push:
7 | branches: [master]
8 |
9 | permissions:
10 | contents: read # to fetch code (actions/checkout)
11 |
12 | jobs:
13 | pre-commit:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v3
17 | - uses: actions/setup-python@v4
18 | - run: python -m pip install pre-commit
19 | - run: python -m pre_commit --version
20 | - run: python -m pre_commit install
21 | - run: python -m pre_commit run --all-files --show-diff-on-failure
22 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/max_observation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from supersuit.utils.accumulator import Accumulator
4 |
5 | from .utils.base_modifier import BaseModifier
6 | from .utils.shared_wrapper_util import shared_wrapper
7 |
8 |
def max_observation_v0(env, memory):
    """Wrap ``env`` so each observation is the ``np.maximum`` reduction over an ``Accumulator`` of size ``memory``."""
    int(memory)  # memory must be convertible to an int; the result is discarded

    class MaxObsModifier(BaseModifier):
        def reset(self, seed=None, options=None):
            # A fresh accumulator is created on every reset.
            self.accumulator = Accumulator(self.observation_space, memory, np.maximum)

        def modify_obs(self, obs):
            self.accumulator.add(obs)
            reduced = self.accumulator.get()
            return BaseModifier.modify_obs(self, reduced)

    return shared_wrapper(env, MaxObsModifier)
21 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_flatten.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.flatten import change_observation, check_param
6 |
7 |
# Shared fixtures: a float32 Box of shape (4, 4, 3) and a float64 observation
# whose last axis holds the values 0, 1, 2 at every position.
test_obs_space = Box(
    low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32
)
test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3)
12 |
13 |
def test_param_check():
    """A bool parameter is accepted; a non-bool one raises AssertionError."""
    check_param(test_obs_space, True)
    with pytest.raises(AssertionError):
        check_param(test_obs_space, 6)
18 |
19 |
def test_change_observation():
    """Flattening a (4, 4, 3) observation yields a 1-d array of 48 elements."""
    expected_len = 4 * 4 * 3
    flattened = change_observation(test_obs, test_obs_space, True)
    assert flattened.shape == (expected_len,)
23 |
--------------------------------------------------------------------------------
/supersuit/utils/accumulator.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from functools import reduce
3 |
4 | import numpy as np
5 |
6 |
class Accumulator:
    """Keep the last ``memory`` observations and reduce them pairwise with ``reduction``.

    ``obs_space`` is accepted but not used. The reduced value is cached in
    ``maxed_val`` until the next ``add``.
    """

    def __init__(self, obs_space, memory, reduction):
        self.memory = memory
        self.reduction = reduction
        self._obs_buffer = deque()
        self.maxed_val = None

    def add(self, in_obs):
        """Append a copy of ``in_obs``, drop the oldest entry if over capacity, and clear the cache."""
        buffer = self._obs_buffer
        buffer.append(np.copy(in_obs))
        if len(buffer) > self.memory:
            buffer.popleft()
        self.maxed_val = None

    def get(self):
        """Return the reduction over the buffered observations, computing it only when not cached."""
        if self.maxed_val is not None:
            return self.maxed_val
        self.maxed_val = reduce(self.reduction, self._obs_buffer)
        return self.maxed_val
24 |
--------------------------------------------------------------------------------
/supersuit/vector/utils/shared_array.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 |
3 | import numpy as np
4 |
5 |
class SharedArray:
    """Numpy view (``np_arr``) over a ``multiprocessing.Array`` created with ``lock=False``.

    Pickling transfers the underlying shared buffer together with dtype and
    shape; the numpy view is rebuilt on unpickling.
    """

    def __init__(self, shape, dtype):
        c_type = np.ctypeslib.as_ctypes_type(dtype)
        n_items = int(np.prod(shape))
        self.shared_arr = mp.Array(c_type, n_items, lock=False)
        self.dtype = dtype
        self.shape = shape
        self._set_np_arr()

    def _set_np_arr(self):
        """(Re)create ``np_arr`` as a reshaped view of the shared buffer."""
        flat_view = np.frombuffer(self.shared_arr, dtype=self.dtype)
        self.np_arr = flat_view.reshape(self.shape)

    def __getstate__(self):
        return (self.shared_arr, self.dtype, self.shape)

    def __setstate__(self, state):
        shared_arr, dtype, shape = state
        self.shared_arr = shared_arr
        self.dtype = dtype
        self.shape = shape
        self._set_np_arr()
26 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_dtype.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.dtype import change_observation, check_param
6 |
7 |
# Shared fixtures: a float32 Box of shape (4, 4, 3) and a float64 observation
# whose last axis holds the values 0, 1, 2 at every position.
test_obs_space = Box(
    low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32
)
test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3)
12 |
13 |
def test_param_check():
    """A numpy scalar type and a dtype instance are accepted; the int 6 raises TypeError."""
    for valid_dtype in (np.uint8, np.dtype("uint8")):
        check_param(test_obs_space, valid_dtype)
    with pytest.raises(TypeError):
        check_param(test_obs_space, 6)
19 |
20 |
def test_change_observation():
    """Casting the float64 fixture to float32 produces a float32 array."""
    target = np.float32
    converted = change_observation(test_obs, test_obs_space, target)
    assert converted.dtype == target
24 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/reshape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import convert_box
4 |
5 |
def check_param(obs_space, shape):
    """Validate a target ``shape`` for reshaping observations of ``obs_space``.

    Raises:
        AssertionError: if ``shape`` is not a tuple, contains a non-int element,
            or has a different number of elements than ``obs_space.shape``.
    """
    # Explicit raises (same exception type as before) so validation survives ``python -O``.
    if not isinstance(shape, tuple):
        raise AssertionError(f"shape must be tuple. It is {shape}")
    if not all(isinstance(el, int) for el in shape):
        raise AssertionError(f"shape must be tuple of ints, is: {shape}")
    if np.prod(shape) != np.prod(obs_space.shape):
        raise AssertionError(
            f"new shape {shape} must have as many elements as original shape {obs_space.shape}"
        )
16 |
17 |
def change_obs_space(obs_space, param):
    """Build the transformed space by applying ``change_observation`` to the bounds of ``obs_space``."""

    def transform(bound):
        return change_observation(bound, obs_space, param)

    return convert_box(transform, obs_space)
20 |
21 |
def change_observation(obs, obs_space, shape):
    """Return ``obs`` reshaped to ``shape``; ``obs_space`` is not used."""
    return obs.reshape(shape)
25 |
--------------------------------------------------------------------------------
/supersuit/aec_vector/create.py:
--------------------------------------------------------------------------------
1 | import cloudpickle
2 | from pettingzoo import AECEnv
3 |
4 | from .async_vector_env import AsyncAECVectorEnv
5 | from .vector_env import SyncAECVectorEnv
6 |
7 |
def vectorize_aec_env_v0(aec_env, num_envs, num_cpus=0):
    """Create a vectorized AEC environment made of ``num_envs`` copies of ``aec_env``.

    Args:
        aec_env: a PettingZoo ``AECEnv`` exposing ``possible_agents``.
        num_envs: number of copies to create.
        num_cpus: 0 or 1 selects ``SyncAECVectorEnv``; any other value selects
            ``AsyncAECVectorEnv`` and is passed to it.

    Raises:
        AssertionError: if ``aec_env`` is not an ``AECEnv`` or lacks ``possible_agents``.
    """
    # Explicit raises keep the checks active under ``python -O``; the first
    # message now names this function instead of an unrelated one.
    if not isinstance(aec_env, AECEnv):
        raise AssertionError("vectorize_aec_env_v0 takes in a pettingzoo AECEnv.")
    if not hasattr(aec_env, "possible_agents"):
        raise AssertionError(
            "environment passed to vectorize_aec_env must have possible_agents attribute."
        )

    def env_fn():
        # Serialise and deserialise the env so each call yields a new object.
        return cloudpickle.loads(cloudpickle.dumps(aec_env))

    env_list = [env_fn] * num_envs

    if num_cpus in (0, 1):
        return SyncAECVectorEnv(env_list)
    return AsyncAECVectorEnv(env_list, num_cpus)
25 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_color_reduction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.color_reduction import (
6 | change_observation,
7 | check_param,
8 | )
9 |
10 |
# Fixtures: a valid 3-channel space, an invalid 4-channel space, and an
# observation whose last axis holds the values 0, 1, 2 at every position.
test_obs_space = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3))
bad_test_obs_space = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 4))
test_obs = np.zeros([4, 4, 3]) + np.arange(3)
14 |
15 |
def test_param_check():
    """An unknown reduction name or a 4-channel space is rejected; "G" on a 3-channel space passes."""
    rejected = (
        (test_obs_space, "bob"),
        (bad_test_obs_space, "R"),
    )
    for space, reduction in rejected:
        with pytest.raises(AssertionError):
            check_param(space, reduction)
    check_param(test_obs_space, "G")
22 |
23 |
def test_change_observation():
    """Selecting "B" keeps the last channel, which is 2 everywhere in the fixture."""
    blue = change_observation(test_obs, test_obs_space, "B")
    expected = np.full((4, 4), 2.0)
    assert np.all(np.equal(blue, expected))
27 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/sticky_actions.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 |
3 | from .utils.base_modifier import BaseModifier
4 | from .utils.shared_wrapper_util import shared_wrapper
5 |
6 |
def sticky_actions_v0(env, repeat_action_probability):
    """Wrap ``env`` so the previous action is repeated with probability ``repeat_action_probability``.

    The probability must satisfy ``0 <= p < 1``.
    """
    assert 0 <= repeat_action_probability < 1

    class StickyActionsModifier(BaseModifier):
        def reset(self, seed=None, options=None):
            # Re-seed the generator and forget the previous action.
            rng, _ = gymnasium.utils.seeding.np_random(seed)
            self.np_random = rng
            self.old_action = None

        def modify_action(self, action):
            # The random draw only happens once a previous action exists.
            repeat = (
                self.old_action is not None
                and self.np_random.uniform() < repeat_action_probability
            )
            chosen = self.old_action if repeat else action
            self.old_action = chosen
            return chosen

    return shared_wrapper(env, StickyActionsModifier)
28 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_reshape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.reshape import change_observation, check_param
6 |
7 |
# Shared fixtures: a float32 Box of shape (4, 4, 3) and a float64 observation
# whose last axis holds the values 0, 1, 2 at every position.
test_obs_space = Box(
    low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32
)
test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3)
12 |
13 |
def test_param_check():
    """(8, 6) is accepted; a wrong element count, a non-tuple and a non-int element are rejected."""
    check_param(test_obs_space, (8, 6))
    for bad_shape in ((8, 7), "bob", ("bob", 5)):
        with pytest.raises(AssertionError):
            check_param(test_obs_space, bad_shape)
22 |
23 |
def test_change_observation():
    """Reshaping to a 2-d and to a 1-d target produces exactly the requested shape."""
    for target_shape in ((8, 6), (4 * 4 * 3,)):
        reshaped = change_observation(test_obs, test_obs_space, target_shape)
        assert reshaped.shape == target_shape
29 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | from setuptools import setup
4 |
5 |
# Absolute path of the directory containing this setup.py.
CWD = pathlib.Path(__file__).absolute().parent
7 |
8 |
def get_version():
    """Gets the supersuit version.

    Scans ``supersuit/__init__.py`` for the first line starting with
    ``__version__`` and returns its last whitespace-separated token with
    surrounding double quotes removed.

    Raises:
        RuntimeError: if no such line exists.
    """
    init_file = CWD / "supersuit" / "__init__.py"
    for raw_line in init_file.read_text().splitlines():
        if not raw_line.startswith("__version__"):
            continue
        last_token = raw_line.strip().split()[-1]
        return last_token.strip().strip('"')
    raise RuntimeError("bad version data in __init__.py")
18 |
19 |
def get_description():
    """Gets the description from the readme.

    Returns every README line that precedes the second line starting with
    ``##``, i.e. the text up to (but excluding) the second such header.
    """
    # Resolve the README next to setup.py (as get_version does for __init__.py)
    # rather than relative to the process working directory, and read it as
    # UTF-8 regardless of the platform's default encoding.
    readme_path = CWD / "README.md"
    kept_lines = []
    header_count = 0
    with open(readme_path, encoding="utf-8") as fh:
        for line in fh:
            if line.startswith("##"):
                header_count += 1
            if header_count >= 2:
                break
            kept_lines.append(line)
    return "".join(kept_lines)
33 |
34 |
35 | setup(name="supersuit", version=get_version(), long_description=get_description())
36 |
--------------------------------------------------------------------------------
/supersuit/utils/obs_delay.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 |
5 |
class Delayer:
    """FIFO queue that returns each observation ``delay`` calls after it was added.

    While fewer than ``delay + 1`` observations are queued, ``add`` returns a
    zero-filled placeholder shaped like the incoming observation.
    """

    def __init__(self, obs_space, delay):
        self.delay = delay
        self.obs_queue = deque()
        self.obs_space = obs_space

    def add(self, in_obs):
        """Queue ``in_obs`` and return the delayed observation or a placeholder.

        Raises:
            TypeError: when a placeholder is needed and ``in_obs`` is neither an
                ndarray nor a dict with "observation" and "action_mask" keys.
        """
        queue = self.obs_queue
        queue.append(in_obs)
        if len(queue) > self.delay:
            return queue.popleft()

        if isinstance(in_obs, np.ndarray):
            return np.zeros_like(in_obs)

        is_masked_obs = (
            isinstance(in_obs, dict)
            and "observation" in in_obs.keys()
            and "action_mask" in in_obs.keys()
        )
        if not is_masked_obs:
            raise TypeError(
                "Observation must be of type np.ndarray or dictionary with keys 'observation' and 'action_mask'"
            )
        # Placeholder: zeroed observation with an all-ones action mask.
        return {
            "observation": np.zeros_like(in_obs["observation"]),
            "action_mask": np.ones_like(in_obs["action_mask"]),
        }
32 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/resize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import convert_box
4 |
5 |
def check_param(obs_space, resize):
    """Validate the ``(xsize, ysize, linear_interp)`` resize parameter.

    Raises:
        AssertionError: if either size is not a positive int, if
            ``linear_interp`` is not a bool, or if ``obs_space.shape`` is not
            2- or 3-dimensional.
    """
    xsize, ysize, linear_interp = resize
    # Explicit raises (same exception type as before) so validation survives ``python -O``.
    if not all(isinstance(ds, int) and ds > 0 for ds in [xsize, ysize]):
        raise AssertionError("resize x and y sizes must be integers greater than zero.")
    if not isinstance(linear_interp, bool):
        raise AssertionError("resize linear_interp parameter must be bool.")
    if len(obs_space.shape) not in (2, 3):
        raise AssertionError(
            f"resize requires a 2d or 3d observation space. Shape is {obs_space.shape}"
        )
15 |
16 |
def change_obs_space(obs_space, param):
    """Build the transformed space by applying ``change_observation`` to the bounds of ``obs_space``."""

    def transform(bound):
        return change_observation(bound, obs_space, param)

    return convert_box(transform, obs_space)
19 |
20 |
def change_observation(obs, obs_space, resize):
    """Rescale ``obs`` to ``(xsize, ysize)`` with tinyscaler.

    ``resize`` is ``(xsize, ysize, linear_interp)``; bilinear interpolation is
    used when ``linear_interp`` is truthy, nearest otherwise. A 2-d ``obs``
    gets a trailing axis of length 1 before scaling, and the result is reduced
    back to 2-d when ``obs_space`` is 2-d.
    """
    import tinyscaler

    xsize, ysize, linear_interp = resize
    mode = "bilinear" if linear_interp else "nearest"

    if len(obs.shape) == 2:
        obs = obs.reshape(obs.shape + (1,))

    contiguous = np.ascontiguousarray(obs)
    scaled = tinyscaler.scale(src=contiguous, size=(xsize, ysize), mode=mode)

    if len(obs_space.shape) != 2:
        return scaled
    return scaled.reshape(scaled.shape[:2])
34 |
--------------------------------------------------------------------------------
/supersuit/multiagent_wrappers/agent_indication.py:
--------------------------------------------------------------------------------
1 | from pettingzoo.utils.env import AECEnv, ParallelEnv
2 |
3 | from supersuit import observation_lambda_v0
4 | from supersuit.utils import agent_indicator as agent_ider
5 |
6 |
def agent_indicator_v0(env, type_only=False):
    """Wrap ``env`` with an observation lambda that adds an agent indicator to each observation.

    ``env`` must be an ``AECEnv`` or ``ParallelEnv`` with a ``possible_agents``
    attribute. ``type_only`` is forwarded to ``get_indicator_map``.
    """
    assert isinstance(
        env, (AECEnv, ParallelEnv)
    ), "agent_indicator_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(
        env, "possible_agents"
    ), "environment passed to agent indicator wrapper must have the possible_agents attribute."

    indicator_map = agent_ider.get_indicator_map(env.possible_agents, type_only)
    # Number of distinct indicator values across all agents.
    num_indicators = len(set(indicator_map.values()))

    agent_ider.check_params(
        [env.observation_space(agent) for agent in env.possible_agents]
    )

    def add_indicator(obs, obs_space, agent):
        indicator_data = (indicator_map[agent], num_indicators)
        return agent_ider.change_observation(obs, obs_space, indicator_data)

    def extend_space(obs_space):
        return agent_ider.change_obs_space(obs_space, num_indicators)

    return observation_lambda_v0(env, add_indicator, extend_space)
30 |
--------------------------------------------------------------------------------
/supersuit/vector/sb_vector_wrapper.py:
--------------------------------------------------------------------------------
1 | from stable_baselines.common.vec_env.base_vec_env import VecEnv
2 |
3 |
class SBVecEnvWrapper(VecEnv):
    """Adapter that exposes a wrapped vector env through the stable_baselines ``VecEnv`` class.

    Stepping, resetting and seeding are delegated to ``venv``; attribute and
    method access on sub-environments is not implemented.
    """

    def __init__(self, venv):
        # The base-class constructor is not called; attributes are copied from ``venv``.
        self.venv = venv
        self.num_envs = venv.num_envs
        self.observation_space = venv.observation_space
        self.action_space = venv.action_space

    def reset(self, seed=None, options=None):
        """Seed the wrapped env when ``seed`` is given, then reset it. ``options`` is ignored."""
        if seed is None:
            return self.venv.reset()
        self.seed(seed=seed)
        return self.venv.reset()

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()

    def step(self, actions):
        return self.venv.step(actions)

    def close(self):
        # Only the reference to the wrapped env is dropped here.
        del self.venv

    def seed(self, seed=None):
        self.venv.seed(seed)

    def get_attr(self, attr_name, indices=None):
        raise NotImplementedError()

    def set_attr(self, attr_name, value, indices=None):
        raise NotImplementedError()

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        raise NotImplementedError()
40 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/color_reduction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import convert_box
4 |
5 |
# Accepted values for the color_reduction parameter.
COLOR_RED_LIST = ["full", "R", "G", "B"]
# Per-channel weights applied along the last axis for the "full" reduction.
GRAYSCALE_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)
8 |
9 |
def check_param(space, color_reduction):
    """Validate the inputs of the color reduction transform.

    Raises:
        AssertionError: if ``color_reduction`` is not a string, is not one of
            ``COLOR_RED_LIST``, or ``space.low`` is not a 3-d array whose
            last dimension has size 3.
    """
    assert isinstance(
        color_reduction, str
    ), f"color_reduction must be str. It is {color_reduction}"
    assert (
        color_reduction in COLOR_RED_LIST
    ), f"color_reduction must be in {COLOR_RED_LIST}"
    assert (
        len(space.low.shape) == 3 and space.low.shape[-1] == 3
    ), f"To apply color_reduction, shape must be a 3d image with last dimension of size 3. Shape is {space.low.shape}"
22 |
23 |
def change_obs_space(obs_space, param):
    """Return the space obtained by passing ``obs_space`` through ``convert_box``
    together with the color reduction selected by ``param``."""

    def reduce_colors(obs):
        return change_observation(obs, obs_space, param)

    return convert_box(reduce_colors, obs_space)
26 |
27 |
def change_observation(obs, obs_space, color_reduction):
    """Reduce the color channels of an ``(H, W, 3)``-indexed observation.

    ``"R"``, ``"G"`` and ``"B"`` select channel 0, 1 or 2 respectively.
    ``"full"`` takes the weighted sum of the channels with
    ``GRAYSCALE_WEIGHTS`` and casts the result to ``uint8``. Any other value
    returns ``obs`` unchanged. ``obs_space`` is not used.
    """
    if color_reduction == "R":
        return obs[:, :, 0]
    elif color_reduction == "G":
        return obs[:, :, 1]
    elif color_reduction == "B":
        return obs[:, :, 2]
    elif color_reduction == "full":
        as_float = obs.astype(np.float32)
        return (as_float @ GRAYSCALE_WEIGHTS).astype(np.uint8)
    return obs
38 |
--------------------------------------------------------------------------------
/test/test_utils/test_action_transforms/test_homogenize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box, Discrete
3 |
4 | from supersuit.utils.action_transforms.homogenize_ops import (
5 | check_homogenize_spaces,
6 | dehomogenize_actions,
7 | homogenize_observations,
8 | homogenize_spaces,
9 | )
10 |
11 |
12 | box_spaces = [
13 | Box(low=np.float32(0), high=np.float32(1), shape=(5, 4)),
14 | Box(low=np.float32(0), high=np.float32(1), shape=(10, 2)),
15 | ]
16 | discrete_spaces = [Discrete(5), Discrete(7)]
17 |
18 |
def test_param_check():
    """The Box and the Discrete fixtures must both pass the homogenize check."""
    for spaces in (box_spaces, discrete_spaces):
        check_homogenize_spaces(spaces)
22 |
23 |
def test_homogenize_spaces():
    """Homogenized spaces take the largest extent seen across the inputs."""
    padded_box = homogenize_spaces(box_spaces)
    padded_discrete = homogenize_spaces(discrete_spaces)
    # (5, 4) and (10, 2) are expected to merge into (10, 4); Discrete(5) and
    # Discrete(7) into 7 actions.
    assert padded_box.shape == (10, 4)
    assert padded_discrete.n == 7
29 |
30 |
def test_dehomogenize_actions():
    """Padded actions are mapped back into each agent's original space."""
    padded_action = np.ones([10, 4])
    box_space = box_spaces[0]
    discrete_space = discrete_spaces[0]
    assert dehomogenize_actions(box_space, padded_action).shape == (5, 4)
    # Action 5 lies outside Discrete(5) and is expected to map to 0.
    assert dehomogenize_actions(discrete_space, 5) == 0
    assert dehomogenize_actions(discrete_space, 4) == 4
36 |
37 |
def test_homogenize_observations():
    """A (5, 4) observation is expected to be padded to the homogenized (10, 4) shape."""
    padded_space = homogenize_spaces(box_spaces)
    small_obs = np.zeros([5, 4])
    padded_obs = homogenize_observations(padded_space, small_obs)
    assert padded_obs.shape == (10, 4)
42 |
--------------------------------------------------------------------------------
/test/test_vector/test_pettingzoo_to_vec_gymnasium_wrappers.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | import numpy as np
3 | import pytest
4 | from gymnasium.wrappers import NormalizeObservation
5 | from pettingzoo.butterfly import pistonball_v6
6 |
7 | import supersuit as ss
8 |
9 |
@pytest.mark.parametrize("env_fn", [pistonball_v6])
def test_vec_env_normalize_obs(env_fn):
    """Gymnasium's ``NormalizeObservation`` must shrink the observation range
    of a concatenated vector env built from a PettingZoo parallel env."""
    vec_env = ss.concat_vec_envs_v1(
        ss.pettingzoo_env_to_vec_env_v1(env_fn.parallel_env()),
        10,
        base_class="gymnasium",
    )
    raw_obs, _raw_info = vec_env.reset()

    # Create a "dummy" class that adds gym.Env as a base class of the env
    # to satisfy the assertion.
    class PatchedEnv(vec_env.__class__, gym.Env):
        pass

    vec_env.__class__ = PatchedEnv

    normalized_env = NormalizeObservation(vec_env)
    scaled_obs, _scaled_info = normalized_env.reset()

    raw_span = np.amax(raw_obs) - np.amin(raw_obs)
    scaled_span = np.amax(scaled_obs) - np.amin(scaled_obs)
    assert raw_span > 1, "Regular observation space should be greater than 1."
    assert (
        scaled_span < 1.0e-4
    ), "Normalized observation space should be smaller than 1.0e-4."
    assert (
        raw_span > scaled_span
    ), "Normalized observation space has more range than regular observation space."
36 |
--------------------------------------------------------------------------------
/test/test_vector/test_render.py:
--------------------------------------------------------------------------------
1 | from pettingzoo.butterfly import pistonball_v6
2 | from stable_baselines3.common.vec_env import VecVideoRecorder
3 |
4 | import supersuit as ss
5 |
6 |
def schedule(episode_idx):
    """Video-recording trigger: print the index and return True for indices up to 1."""
    print(episode_idx)
    should_record = episode_idx <= 1
    return should_record
10 |
11 |
def make_sb3_record_env():
    """Build a single pistonball vector env wrapped in SB3's ``VecVideoRecorder``.

    Videos are written under ``/tmp`` whenever ``schedule`` returns True.
    """
    parallel_env = pistonball_v6.parallel_env(render_mode="rgb_array")
    print(parallel_env.render_mode)
    vec_env = ss.pettingzoo_env_to_vec_env_v1(parallel_env)
    concatenated = ss.concat_vec_envs_v1(
        vec_env, 1, num_cpus=0, base_class="stable_baselines3"
    )
    return VecVideoRecorder(concatenated, "/tmp", schedule)
19 |
20 |
def test_record_video_sb3():
    """Step the recording env 100 times with sampled actions, then close it."""
    recorder = make_sb3_record_env()
    recorder.reset()
    for _ in range(100):
        sampled = [recorder.action_space.sample() for _ in range(recorder.num_envs)]
        recorder.step(sampled)
    recorder.close()
27 |
28 |
def make_env():
    """Return a pistonball parallel env (``rgb_array`` rendering) converted to a vector env."""
    parallel_env = pistonball_v6.parallel_env(render_mode="rgb_array")
    return ss.pettingzoo_env_to_vec_env_v1(parallel_env)
33 |
34 |
def test_vector_render_multiproc():
    """Rendering a multiprocess vector env must yield an ``(H, W, 3)`` array."""
    num_envs = 1
    venv = ss.concat_vec_envs_v1(
        make_env(), num_envs, num_cpus=num_envs, base_class="stable_baselines3"
    )
    venv.reset()
    frame = venv.render()
    venv.reset()
    assert len(frame.shape) == 3 and frame.shape[2] == 3
    venv.reset()
    # A RuntimeError raised while closing is deliberately ignored.
    try:
        venv.close()
    except RuntimeError:
        pass
50 |
--------------------------------------------------------------------------------
/test/test_vector/test_env_is_wrapped.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import pytest
3 | from pettingzoo.mpe import simple_spread_v3
4 |
5 | from supersuit import concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1
6 | from supersuit.generic_wrappers.frame_skip import frame_skip_gym
7 |
8 |
def test_env_is_wrapped_true():
    """Every copy of a frame-skipped env must report the ``frame_skip_gym`` wrapper."""
    num_envs = 3
    skipped_env = frame_skip_gym(gymnasium.make("MountainCarContinuous-v0"), 4)
    venv = concat_vec_envs_v1(skipped_env, num_envs)
    assert venv.env_is_wrapped(frame_skip_gym) == [True] * 3
15 |
16 |
def test_env_is_wrapped_false():
    """An env that was never frame-skipped must not report the wrapper."""
    num_envs = 3
    plain_env = gymnasium.make("MountainCarContinuous-v0")
    venv = concat_vec_envs_v1(plain_env, num_envs)
    assert venv.env_is_wrapped(frame_skip_gym) == [False] * 3
22 |
23 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_env_is_wrapped_cpu():
    """Same check as the single-process case, with the copies spread over 2 CPUs."""
    num_envs = 3
    skipped_env = frame_skip_gym(gymnasium.make("MountainCarContinuous-v0"), 4)
    venv = concat_vec_envs_v1(skipped_env, num_envs, num_cpus=2)
    assert venv.env_is_wrapped(frame_skip_gym) == [True] * 3
33 |
34 |
def test_env_is_wrapped_pettingzoo():
    """A converted PettingZoo env reports one ``False`` per agent per copy (9 in total here)."""
    num_envs = 3
    single_venv = pettingzoo_env_to_vec_env_v1(simple_spread_v3.parallel_env())
    stacked_venv = concat_vec_envs_v1(single_venv, num_envs)
    assert stacked_venv.env_is_wrapped(frame_skip_gym) == [False] * 9
41 |
--------------------------------------------------------------------------------
/supersuit/vector/constructors.py:
--------------------------------------------------------------------------------
1 | from .concat_vec_env import ConcatVecEnv
2 | from .multiproc_vec import ProcConcatVec
3 |
4 |
class call_wrap:
    """Callable object that binds ``data`` to ``fn``.

    Calling the instance returns ``fn(data)``; positional arguments supplied
    at call time are ignored.
    """

    def __init__(self, fn, data):
        self.fn, self.data = fn, data

    def __call__(self, *_ignored):
        bound_fn = self.fn
        return bound_fn(self.data)
12 |
13 |
def MakeCPUAsyncConstructor(max_num_cpus):
    """Return a vector-env constructor that uses at most ``max_num_cpus`` groups.

    For 0 or 1 CPUs this is ``ConcatVecEnv`` itself. Otherwise the returned
    function splits the environment factories into contiguous groups, wraps
    each group in a ``ConcatVecEnv`` factory and hands them to
    ``ProcConcatVec``.
    """
    if max_num_cpus in (0, 1):
        return ConcatVecEnv

    def constructor(env_fn_list, obs_space, act_space):
        # One env is instantiated up front to read its sub-env count and metadata.
        example_env = env_fn_list[0]()
        envs_per_env = getattr(example_env, "num_envs", 1)

        num_fns = len(env_fn_list)
        # Ceiling division: the group size needed to fit all factories.
        envs_per_cpu = (num_fns + max_num_cpus - 1) // max_num_cpus

        groups = []
        start = 0
        while start < num_fns:
            stop = min(num_fns, start + envs_per_cpu)
            groups.append(env_fn_list[start:stop])
            start = stop

        cat_env_fns = [call_wrap(ConcatVecEnv, group) for group in groups]
        total_envs = num_fns * envs_per_env
        return ProcConcatVec(
            cat_env_fns,
            obs_space,
            act_space,
            total_envs,
            example_env.metadata,
        )

    return constructor
44 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_normalized_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.normalize_obs import (
6 | change_obs_space,
7 | change_observation,
8 | check_param,
9 | )
10 |
11 |
12 | high_val = np.array([1, 2, 4])
13 | test_val = np.array([1, 1, 1])
14 | test_obs_space = Box(
15 | low=np.zeros(3, dtype=np.float32), high=high_val.astype(np.float32)
16 | )
17 | bad_test_obs_space = Box(
18 | low=np.zeros(3, dtype=np.int32), high=high_val.astype(np.int32), dtype=np.int32
19 | )
20 | bad_test_obs_space2 = Box(
21 | low=np.zeros(3, dtype=np.float32), high=np.inf * high_val.astype(np.float32)
22 | )
23 |
24 |
def test_param_check():
    """A valid float space and range pass; each invalid combination must assert."""
    check_param(test_obs_space, (2, 3))

    rejected = [
        (test_obs_space, (2, 2)),  # max not greater than min
        (test_obs_space, ("bob", 2)),  # non-numeric bound
        (bad_test_obs_space, (2, 3)),  # integer dtype
        (bad_test_obs_space2, (2, 3)),  # infinite bounds
    ]
    for space, pair in rejected:
        with pytest.raises(AssertionError):
            check_param(space, pair)
35 |
36 |
def test_change_obs_space():
    """The normalized space's upper bound must equal the requested maximum everywhere."""
    new_space = change_obs_space(test_obs_space, (1, 2))
    expected_high = np.array([2, 2, 2])
    assert np.all(np.equal(new_space.high, expected_high))
41 |
42 |
def test_change_observation():
    """An observation of ones in a [0, (1, 2, 4)] space maps to (2, 1.5, 1.25) for range (1, 2)."""
    rescaled = change_observation(test_val, test_obs_space, (1, 2))
    expected = np.array([2, 1.5, 1.25])
    assert np.all(np.equal(rescaled, expected))
50 |
--------------------------------------------------------------------------------
/supersuit/utils/base_aec_wrapper.py:
--------------------------------------------------------------------------------
1 | from pettingzoo.utils.wrappers import OrderEnforcingWrapper as PZBaseWrapper
2 |
3 |
class BaseWrapper(PZBaseWrapper):
    """Base class for AEC wrappers, built on PettingZoo's ``OrderEnforcingWrapper``.

    Subclasses customise behaviour through the ``_check_wrapper_params``,
    ``_modify_spaces``, ``_modify_action``, ``_modify_observation`` and
    ``_update_step`` hooks.
    """

    def __init__(self, env):
        """
        Creates a wrapper around `env`, then runs the parameter check and the
        space modification hooks. Extend this class to create changes to the space.
        """
        super().__init__(env)
        self._check_wrapper_params()
        self._modify_spaces()

    def _check_wrapper_params(self):
        """Hook for validating wrapper parameters; no-op by default."""
        pass

    def _modify_spaces(self):
        """Hook for changing the wrapped spaces; no-op by default."""
        pass

    def _modify_action(self, agent, action):
        """Hook returning the action to forward for ``agent``; must be overridden."""
        raise NotImplementedError()

    def _modify_observation(self, agent, observation):
        """Hook returning the observation to expose for ``agent``; must be overridden."""
        raise NotImplementedError()

    def _update_step(self, agent):
        """Hook called with the selected agent after every reset and step; no-op by default."""
        pass

    def reset(self, seed=None, options=None):
        """Reset the wrapped env, then notify ``_update_step`` of the selected agent."""
        super().reset(seed=seed, options=options)
        self._update_step(self.agent_selection)

    def observe(self, agent):
        """Return the wrapped env's observation for ``agent`` after ``_modify_observation``."""
        # NOTE (original author): the problem is in this call; the obs is
        # sometimes a different size from the obs space.
        raw_obs = super().observe(agent)
        return self._modify_observation(agent, raw_obs)

    def step(self, action):
        """Step the wrapped env, transforming the action unless the agent is already done."""
        current_agent = self.env.agent_selection
        agent_is_done = (
            self.terminations[current_agent] or self.truncations[current_agent]
        )
        if not agent_is_done:
            action = self._modify_action(current_agent, action)

        super().step(action)

        self._update_step(self.agent_selection)
49 |
--------------------------------------------------------------------------------
/.github/workflows/linux-test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 |
18 | runs-on: ubuntu-22.04
19 | strategy:
20 | matrix:
21 | python-version: ['3.9', '3.10', '3.11', '3.12']
22 |
23 | steps:
24 | - uses: actions/checkout@v2
25 | - name: Set up Python ${{ matrix.python-version }}
26 | uses: actions/setup-python@v2
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 | - name: Install ubuntu dependencies
30 | run: |
31 | sudo apt-get install python3-opengl xvfb
32 | - name: Install python dependencies
33 | run: |
34 | pip install -e .[testing]
35 | # Install pettingzoo directly from master until it mints a new version to resolve
36 | # pygame version restriction.
37 | pip install "pettingzoo[all,atari] @ git+https://github.com/Farama-Foundation/PettingZoo.git@master"
38 | AutoROM -v
39 | - name: Test with pytest
40 | run: |
41 | xvfb-run -s "-screen 0 1400x900x24" pytest ./test
42 | - name: Test installation
43 | run: |
44 | python -m pip install --upgrade build
45 | python -m build --sdist
46 | pip install dist/*.tar.gz
47 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.4.0
6 | hooks:
7 | - id: check-symlinks
8 | - id: destroyed-symlinks
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-yaml
12 | - id: check-toml
13 | - id: check-ast
14 | - id: check-added-large-files
15 | - id: check-merge-conflict
16 | - id: check-executables-have-shebangs
17 | - id: check-shebang-scripts-are-executable
18 | - id: detect-private-key
19 | - id: debug-statements
20 | - repo: https://github.com/codespell-project/codespell
21 | rev: v2.2.2
22 | hooks:
23 | - id: codespell
24 | args:
25 | - --ignore-words-list=magent
26 | - repo: https://github.com/PyCQA/flake8
27 | rev: 6.0.0
28 | hooks:
29 | - id: flake8
30 | args:
31 | - '--per-file-ignores=*/__init__.py:F401'
32 | - --ignore=E203,W503,E741
33 | - --max-complexity=30
34 | - --max-line-length=456
35 | - --show-source
36 | - --statistics
37 | - repo: https://github.com/asottile/pyupgrade
38 | rev: v3.3.1
39 | hooks:
40 | - id: pyupgrade
41 | args: ["--py37-plus"]
42 | - repo: https://github.com/PyCQA/isort
43 | rev: 5.12.0
44 | hooks:
45 | - id: isort
46 | args: ["--profile", "black"]
47 | - repo: https://github.com/python/black
48 | rev: 23.1.0
49 | hooks:
50 | - id: black
51 |
--------------------------------------------------------------------------------
/supersuit/utils/basic_transforms/normalize_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box
3 |
4 |
def check_param(obs_space, min_max_pair):
    """Validate the arguments of the observation normalization transform.

    Args:
        obs_space: space exposing ``dtype``, ``low`` and ``high``.
        min_max_pair: ``(min, max)`` tuple describing the target range.

    Raises:
        AssertionError: if the space dtype is not float32/float64, if
            ``min_max_pair`` is not a 2-tuple of values convertible to float
            with ``max > min``, or if the space has non-finite bounds.
    """
    space_dtype = np.dtype(obs_space.dtype)
    assert space_dtype in (
        np.dtype("float32"),
        np.dtype("float64"),
    ), f"observation space dtype must be float32 or float64. It is {space_dtype}"
    assert (
        isinstance(min_max_pair, tuple) and len(min_max_pair) == 2
    ), f"range_scale must be tuple of size 2. It is {min_max_pair}"
    try:
        min_res = float(min_max_pair[0])
        max_res = float(min_max_pair[1])
    except (TypeError, ValueError):
        # float() raises TypeError for None/lists and ValueError for bad
        # strings; both are reported as the same validation failure.
        raise AssertionError(
            "normalize_obs inputs must be numbers. They are {}".format(min_max_pair)
        ) from None
    assert (
        max_res > min_res
    ), "maximum must be greater than minimum value in normalize_obs"
    assert np.all(np.isfinite(obs_space.low)) and np.all(
        np.isfinite(obs_space.high)
    ), "Box observation_space of environment has infinite bounds! Only environments with finite bounds can be passed to normalize_obs_v0"
25 |
26 |
def change_obs_space(obs_space, min_max_pair):
    """Return a ``Box`` with the shape and dtype of ``obs_space`` and bounds ``min_max_pair``."""
    dtype = obs_space.dtype
    lower = np.float64(min_max_pair[0]).astype(dtype)
    upper = np.float64(min_max_pair[1]).astype(dtype)
    return Box(low=lower, high=upper, shape=obs_space.shape, dtype=dtype)
31 |
32 |
def change_observation(obs, obs_space, min_max_pair):
    """Linearly rescale ``obs`` from ``[obs_space.low, obs_space.high]`` to ``min_max_pair``."""
    target_low, target_high = map(float, min_max_pair)
    source_span = obs_space.high - obs_space.low
    target_span = target_high - target_low
    # Position of obs inside the source range, as a fraction per element.
    fraction = (obs - obs_space.low) / source_span
    return fraction * target_span + target_low
39 |
--------------------------------------------------------------------------------
/supersuit/multiagent_wrappers/padding_wrappers.py:
--------------------------------------------------------------------------------
1 | from pettingzoo.utils.env import AECEnv, ParallelEnv
2 |
3 | from supersuit import action_lambda_v1, observation_lambda_v0
4 | from supersuit.utils.action_transforms import homogenize_ops
5 |
6 |
def pad_action_space_v0(env):
    """Wrap ``env`` so that every agent exposes the same, homogenized action space.

    The action spaces of all ``env.possible_agents`` are validated and merged
    with ``homogenize_ops``; incoming actions are converted with
    ``homogenize_ops.dehomogenize_actions`` before reaching the wrapped env.

    Raises:
        AssertionError: if ``env`` is not an ``AECEnv``/``ParallelEnv`` or has
            no ``possible_agents`` attribute.
    """
    assert isinstance(
        env, (AECEnv, ParallelEnv)
    ), "pad_action_space_v0 only accepts an AECEnv or ParallelEnv"
    # The original message named pad_observations here (copy-paste slip).
    assert hasattr(
        env, "possible_agents"
    ), "environment passed to pad_action_space must have a possible_agents list."
    spaces = [env.action_space(agent) for agent in env.possible_agents]
    homogenize_ops.check_homogenize_spaces(spaces)
    padded_space = homogenize_ops.homogenize_spaces(spaces)
    return action_lambda_v1(
        env,
        lambda action, act_space: homogenize_ops.dehomogenize_actions(
            act_space, action
        ),
        lambda act_space: padded_space,
    )
24 |
25 |
def pad_observations_v0(env):
    """Wrap ``env`` so that every agent's observation uses one homogenized space.

    The observation spaces of all ``env.possible_agents`` are validated and
    merged with ``homogenize_ops``; each observation is converted with
    ``homogenize_ops.homogenize_observations``.
    """
    is_pettingzoo_env = isinstance(env, AECEnv) or isinstance(env, ParallelEnv)
    assert (
        is_pettingzoo_env
    ), "pad_observations_v0 only accepts an AECEnv or ParallelEnv"
    assert hasattr(
        env, "possible_agents"
    ), "environment passed to pad_observations must have a possible_agents list."

    agent_spaces = [env.observation_space(agent) for agent in env.possible_agents]
    homogenize_ops.check_homogenize_spaces(agent_spaces)
    padded_space = homogenize_ops.homogenize_spaces(agent_spaces)

    def pad_observation(obs, obs_space):
        return homogenize_ops.homogenize_observations(padded_space, obs)

    def padded_obs_space(obs_space):
        return padded_space

    return observation_lambda_v0(env, pad_observation, padded_obs_space)
43 |
--------------------------------------------------------------------------------
/supersuit/vector/single_vec_env.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import numpy as np
3 |
4 |
class SingleVecEnv:
    """Expose one Gymnasium-style environment through a vector-env interface of size 1.

    Observations, rewards and flags gain a leading axis of length 1, and the
    environment is reset automatically when an episode ends inside ``step``.
    """

    def __init__(self, gym_env_fns, *args):
        """Build the env from the single factory in ``gym_env_fns``; extra args are ignored."""
        assert len(gym_env_fns) == 1
        self.gym_env = gym_env_fns[0]()
        self.render_mode = self.gym_env.render_mode
        self.observation_space = self.gym_env.observation_space
        self.action_space = self.gym_env.action_space
        self.num_envs = 1
        self.metadata = self.gym_env.metadata

    def reset(self, seed=None, options=None):
        """Reset the env and return ``(observations, infos)``.

        ``observations`` has a leading axis of length 1 and ``infos`` is a
        one-element list, matching the layout returned by ``step``.
        """
        # The underlying reset returns an (obs, info) pair; only the
        # observation may be stacked into an array.
        observation, info = self.gym_env.reset(seed=seed, options=options)
        return np.expand_dims(observation, 0), [info]

    def step_async(self, actions):
        """Remember ``actions`` for the next ``step_wait`` call."""
        self._saved_actions = actions

    def step_wait(self):
        """Step with the actions stored by ``step_async``."""
        return self.step(self._saved_actions)

    def render(self):
        """Return the wrapped env's render output."""
        return self.gym_env.render()

    def close(self):
        """Close the wrapped env."""
        self.gym_env.close()

    def step(self, actions):
        """Step the env with ``actions[0]``.

        Returns:
            ``(observations, rewards, terms, truncs, infos)`` where the arrays
            have length 1 and ``infos`` is a one-element list. When the
            episode ends, the returned observation is the first observation
            of the freshly reset env.
        """
        observation, reward, term, trunc, info = self.gym_env.step(actions[0])
        if term or trunc:
            # Auto-reset: keep the step's info, take only the new observation.
            observation, _reset_info = self.gym_env.reset()
        observations = np.expand_dims(observation, 0)
        rewards = np.array([reward], dtype=np.float32)
        terms = np.array([term], dtype=np.uint8)
        truncs = np.array([trunc], dtype=np.uint8)
        infos = [info]
        return observations, rewards, terms, truncs, infos

    def env_is_wrapped(self, wrapper_class):
        """Return ``[True]`` if any ``gymnasium.Wrapper`` layer is a ``wrapper_class``, else ``[False]``."""
        env_tmp = self.gym_env
        while isinstance(env_tmp, gymnasium.Wrapper):
            if isinstance(env_tmp, wrapper_class):
                return [True]
            env_tmp = env_tmp.env
        return [False]
49 |
--------------------------------------------------------------------------------
/supersuit/utils/wrapper_chooser.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | from pettingzoo.utils.conversions import aec_to_parallel, parallel_to_aec
3 | from pettingzoo.utils.env import AECEnv, ParallelEnv
4 |
5 |
class WrapperChooser:
    """Dispatch a wrapper call to the implementation matching the environment type.

    Gymnasium envs use ``gym_wrapper``. PettingZoo AEC and parallel envs use
    their own wrapper when one is defined, otherwise the env is converted to
    the other API, wrapped, and converted back.
    """

    def __init__(self, aec_wrapper=None, gym_wrapper=None, parallel_wrapper=None):
        """Store the wrappers; at least one of ``aec_wrapper``/``parallel_wrapper`` is required."""
        assert (
            aec_wrapper is not None or parallel_wrapper is not None
        ), "either the aec wrapper or the parallel wrapper must be defined for all supersuit environments"
        self.aec_wrapper = aec_wrapper
        self.gym_wrapper = gym_wrapper
        self.parallel_wrapper = parallel_wrapper

    def _display_name(self):
        """Return the name used in error messages.

        An explicitly assigned ``wrapper_name`` attribute wins; otherwise the
        name of the PettingZoo wrapper callable is used.
        """
        explicit = getattr(self, "wrapper_name", None)
        if explicit:
            return explicit
        wrapper = (
            self.aec_wrapper if self.aec_wrapper is not None else self.parallel_wrapper
        )
        return getattr(wrapper, "__name__", type(self).__name__)

    def __call__(self, env, *args, **kwargs):
        """Wrap ``env`` with the matching wrapper, forwarding ``args`` and ``kwargs``.

        Raises:
            ValueError: if ``env`` is a Gymnasium env but no ``gym_wrapper``
                exists, or if ``env`` is neither a Gymnasium nor a PettingZoo env.
        """
        if isinstance(env, gymnasium.Env):
            if self.gym_wrapper is None:
                # wrapper_name is never set by this class, so reading it
                # directly would raise AttributeError instead of ValueError.
                raise ValueError(
                    f"{self._display_name()} does not apply to gymnasium environments, pettingzoo environments only"
                )
            return self.gym_wrapper(env, *args, **kwargs)
        elif isinstance(env, AECEnv):
            if self.aec_wrapper is not None:
                return self.aec_wrapper(env, *args, **kwargs)
            else:
                return parallel_to_aec(
                    self.parallel_wrapper(aec_to_parallel(env), *args, **kwargs)
                )
        elif isinstance(env, ParallelEnv):
            if self.parallel_wrapper is not None:
                return self.parallel_wrapper(env, *args, **kwargs)
            else:
                return aec_to_parallel(
                    self.aec_wrapper(parallel_to_aec(env), *args, **kwargs)
                )
        else:
            raise ValueError(
                "environment passed to supersuit wrapper must either be a gymnasium environment or a pettingzoo environment"
            )
40 |
--------------------------------------------------------------------------------
/test/test_utils/test_basic_transforms/test_resize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box
4 |
5 | from supersuit.utils.basic_transforms.resize import change_observation, check_param
6 |
7 |
8 | test_shape = (6, 4, 3)
9 | high_val = np.ones(test_shape) + np.arange(4).reshape(1, 4, 1)
10 | test_obs_space = Box(low=high_val - 1, high=high_val, dtype=np.uint8)
11 | test_val = high_val - 0.5
12 |
13 |
def test_param_check():
    """A positive integer size passes; a negative or float width must assert."""
    check_param(test_obs_space, (2, 2, False))
    for bad_param in ((-2, 2, 2), (2.0, 2, 2)):
        with pytest.raises(AssertionError):
            check_param(test_obs_space, bad_param)
20 |
21 |
def test_change_observation():
    """Resizing must give the requested shape and keep the dtype for 3-d uint8,
    2-d float64 and 4-channel uint8 observations."""
    uint8_obs = test_val.astype(np.uint8)
    # The first call only exercises this code path; its result is not checked.
    change_observation(uint8_obs, test_obs_space, (3, 2, False))
    new_obs = change_observation(uint8_obs, test_obs_space, (3, 2, True))
    expected_row = [
        [0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
    ]
    test_obs = np.array([expected_row, expected_row]).astype(np.uint8)

    assert new_obs.dtype == np.uint8
    assert np.all(np.equal(new_obs, test_obs))

    # 2-d float64 observation.
    float_high = np.ones((6, 4)).astype(np.float64)
    float_space = Box(low=float_high - 1, high=float_high)
    new_obs = change_observation(float_high - 0.5, float_space, (3, 2, False))
    assert new_obs.shape == (2, 3)
    assert new_obs.dtype == np.float64

    # 3-d uint8 observation with 4 channels.
    uint8_high = np.ones((6, 5, 4)).astype(np.uint8)
    uint8_space = Box(low=uint8_high - 1, high=uint8_high, dtype=np.uint8)
    new_obs = change_observation(uint8_high, uint8_space, (5, 2, False))
    assert new_obs.shape == (2, 5, 4)
    assert new_obs.dtype == np.uint8
58 |
--------------------------------------------------------------------------------
/.github/workflows/build-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build and (if release) publish Python distributions to PyPI
2 | # For more information see:
3 | # - https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
4 | # - https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
5 | #
6 |
7 | name: build-publish
8 |
9 | on:
10 | workflow_dispatch:
11 | push:
12 | branches: [master]
13 | pull_request:
14 | paths:
15 | - .github/workflows/build-publish.yml
16 | release:
17 | types: [published]
18 |
19 | jobs:
20 | build:
21 | name: Build sdist and wheel
22 | runs-on: ubuntu-latest
23 | permissions:
24 | contents: read
25 | steps:
26 | - uses: actions/checkout@v4
27 | - name: Set up Python
28 | uses: actions/setup-python@v4
29 | with:
30 | python-version: 3.x
31 | - name: Install dependencies
32 | run: python -m pip install --upgrade setuptools wheel build
33 | - name: Build sdist and wheel
34 | run: python -m build --sdist --wheel
35 | - name: Store sdist and wheel
36 | uses: actions/upload-artifact@v4
37 | with:
38 | name: artifact
39 | path: dist
40 | publish:
41 | name: Publish to PyPI
42 | needs: [build]
43 | runs-on: ubuntu-latest
44 | environment: pypi
45 | if: github.event_name == 'release' && github.event.action == 'published'
46 | steps:
47 | - name: Download dist
48 | uses: actions/download-artifact@v4
49 | with:
50 | name: artifact
51 | path: dist
52 | - name: Publish
53 | uses: pypa/gh-action-pypi-publish@release/v1
54 | with:
55 | password: ${{ secrets.PYPI_API_TOKEN }}
56 |
--------------------------------------------------------------------------------
/supersuit/aec_vector/base_aec_vec_env.py:
--------------------------------------------------------------------------------
class VectorAECEnv:
    """Interface of a vectorized AEC environment.

    ``reset``, ``observe``, ``last`` and ``step`` carry documentation only
    and return ``None`` here.
    """

    def reset(self, seed=None, options=None):
        """
        Reset every environment.
        """

    def observe(self, agent):
        """
        Return the observation of `agent` from every environment
        (all zeros where the agent is not alive).
        """

    def last(self, observe=True):
        """
        Return the lists: observations, rewards, dones, env_dones, passes, infos.

        Each is a list over environments holding the value for the current
        agent (env.agent_selection):

        dones: True when the current agent is done
        env_dones: True when all agents are done and the environment will reset
        passes: True when the agent is not stepping this turn (because it is
            dead or not currently stepping for some other reason)
        infos: list of infos for the agent
        """

    def step(self, actions, observe=True):
        """
        Step the current agent with the given actions.
        Unlike a regular AECEnv, the actions cannot be None.
        """

    def agent_iter(self, max_iter):
        """
        Return an iterable over the selected agent.

        Unlike the AEC agent_iter, iteration does not stop when an environment
        is done, because the vector environment resets individual envs itself;
        it simply continues until max_iter is reached.
        """
        iterable = AECIterable(self, max_iter)
        return iterable
38 |
39 |
class AECIterable:
    """Iterable that creates a new ``AECIterator`` over ``env`` each time it is iterated."""

    def __init__(self, env, max_iter):
        self.env, self.max_iter = env, max_iter

    def __iter__(self):
        iterator = AECIterator(self.env, self.max_iter)
        return iterator
47 |
48 |
class AECIterator:
    """Iterator yielding ``env.agent_selection`` at most ``max_iter`` times.

    Creating the iterator sets ``env._is_iterating`` to True.
    """

    def __init__(self, env, max_iter):
        self.env = env
        self.iters_til_term = max_iter
        self.env._is_iterating = True

    def __iter__(self):
        # Required by the iterator protocol so the object can be passed
        # directly to ``for``, ``list`` and ``iter``.
        return self

    def __next__(self):
        if self.iters_til_term <= 0:
            raise StopIteration
        self.iters_til_term -= 1
        return self.env.agent_selection
60 |
--------------------------------------------------------------------------------
/test/dummy_aec_env.py:
--------------------------------------------------------------------------------
1 | from pettingzoo import AECEnv
2 | from pettingzoo.utils.agent_selector import agent_selector
3 |
4 |
class DummyEnv(AECEnv):
    """Minimal AEC environment for tests.

    Observations are fixed per agent, and every agent is truncated once
    ``5 * len(agents)`` steps have been taken.
    """

    metadata = {"render_modes": ["human"], "is_parallelizable": True}

    def __init__(self, observations, observation_spaces, action_spaces):
        """Store the fixed observations and per-agent spaces; agents are the sorted space keys."""
        super().__init__()
        self._observations = observations
        self._observation_spaces = observation_spaces
        self.render_mode = None
        self.agents = sorted(observation_spaces.keys())
        self.possible_agents = list(self.agents)
        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()
        self._action_spaces = action_spaces

        self.steps = 0

    def observation_space(self, agent):
        """Return the observation space registered for ``agent``."""
        return self._observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space registered for ``agent``."""
        return self._action_spaces[agent]

    def observe(self, agent):
        """Return the fixed observation registered for ``agent``."""
        return self._observations[agent]

    def step(self, action, observe=True):
        """Advance to the next agent; truncate all agents after ``5 * len(agents)`` steps."""
        current = self.agent_selection
        if self.terminations[current] or self.truncations[current]:
            return self._was_dead_step(action)

        self._cumulative_rewards[current] = 0
        self.agent_selection = self._agent_selector.next()
        self.steps += 1
        step_limit = 5 * len(self.agents)
        if self.steps >= step_limit:
            self.truncations = dict.fromkeys(self.agents, True)

        self._accumulate_rewards()
        self._deads_step_first()

    def reset(self, seed=None, options=None):
        """Restore all agents and reinitialise rewards, flags, infos and the step counter."""
        self.agents = list(self.possible_agents)
        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()
        self.rewards = dict.fromkeys(self.agents, 1)
        self._cumulative_rewards = dict.fromkeys(self.agents, 0)
        self.terminations = dict.fromkeys(self.agents, False)
        self.truncations = dict.fromkeys(self.agents, False)
        # Each agent needs its own info dict, so no fromkeys here.
        self.infos = {agent: {} for agent in self.agents}
        self.steps = 0

    def close(self):
        """Nothing to release."""
        pass
58 |
--------------------------------------------------------------------------------
/supersuit/multiagent_wrappers/black_death.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import numpy as np
3 | from pettingzoo.utils.wrappers import BaseParallelWrapper
4 |
5 | from supersuit.utils.wrapper_chooser import WrapperChooser
6 |
7 |
class black_death_par(BaseParallelWrapper):
    """Parallel wrapper that keeps finished agents in the output with blank data.

    Agents missing from the wrapped env's observations receive an all-zero
    observation, zero reward and an empty info dict. Termination and
    truncation are reported for every agent at once, and only when the
    whole environment is done.
    """

    def __init__(self, env):
        super().__init__(env)

    def _check_valid_for_black_death(self):
        """Assert that every current agent has a ``Box`` observation space."""
        for agent in self.agents:
            space = self.observation_space(agent)
            assert isinstance(
                space, gymnasium.spaces.Box
            ), f"observation spaces for black death must be Box spaces, is {space}"

    def _black_observations(self, present):
        """Return zero observations for every tracked agent not in ``present``."""
        return {
            agent: np.zeros_like(self.observation_space(agent).low)
            for agent in self.agents
            if agent not in present
        }

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return ``(observations, infos)``.

        Agents without an observation in the wrapped env's result get zeros.
        """
        obss, infos = self.env.reset(seed=seed, options=options)

        self.agents = self.env.agents[:]
        self._check_valid_for_black_death()
        black_obs = self._black_observations(obss)
        return {**obss, **black_obs}, infos

    def step(self, actions):
        """Step the wrapped env with the actions of its currently live agents.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``
            keyed by every agent tracked since the last reset. The same
            per-agent done flag is returned for terminations and truncations.
        """
        active_actions = {agent: actions[agent] for agent in self.env.agents}
        obss, rews, terms, truncs, infos = self.env.step(active_actions)
        black_obs = self._black_observations(obss)
        black_rews = {agent: 0.0 for agent in self.agents if agent not in obss}
        black_infos = {agent: {} for agent in self.agents if agent not in obss}
        terminations = np.fromiter(terms.values(), dtype=bool)
        truncations = np.fromiter(truncs.values(), dtype=bool)
        # An agent is finished when it is terminated OR truncated; requiring
        # both (the previous `&`) meant a fully truncated or fully terminated
        # episode was never reported as done.
        env_is_done = (terminations | truncations).all()
        total_obs = {**black_obs, **obss}
        total_rews = {**black_rews, **rews}
        total_infos = {**black_infos, **infos}
        total_dones = {agent: env_is_done for agent in self.agents}
        if env_is_done:
            self.agents.clear()
        return total_obs, total_rews, total_dones, total_dones, total_infos
51 |
52 |
53 | black_death_v3 = WrapperChooser(parallel_wrapper=black_death_par)
54 |
--------------------------------------------------------------------------------
/test/parallel_env_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box, Discrete
3 | from pettingzoo.utils import ParallelEnv
4 |
5 | import supersuit
6 |
7 |
class DummyParEnv(ParallelEnv):
    """Minimal ParallelEnv stub that always returns the same fixed data."""

    metadata = {"render_modes": ["human"]}

    def __init__(self, observations, observation_spaces, action_spaces):
        super().__init__()
        self._observations = observations
        self._observation_spaces = observation_spaces

        agent_names = list(observation_spaces.keys())
        self.agents = agent_names
        # possible_agents refers to the very same list object as agents.
        self.possible_agents = agent_names
        self.agent_selection = agent_names[0]
        self._action_spaces = action_spaces

        self.rewards = {name: 1 for name in agent_names}
        self.terminations = {name: False for name in agent_names}
        self.truncations = {name: False for name in agent_names}
        self.infos = {name: {} for name in agent_names}

    def observation_space(self, agent):
        """Return the observation space registered for *agent*."""
        return self._observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space registered for *agent*."""
        return self._action_spaces[agent]

    def step(self, actions):
        """Validate each action against its space and return the fixed state."""
        for agent in actions:
            assert actions[agent] in self.action_space(agent)
        result = (
            self._observations,
            self.rewards,
            self.terminations,
            self.truncations,
            self.infos,
        )
        return result

    def reset(self, seed=None, options=None):
        """Return the fixed observations and infos; seed and options are unused."""
        return self._observations, self.infos

    def close(self):
        """Nothing to release."""
        pass
48 |
49 |
# Shared fixtures for two agents named "a0" and "a1".
# Each observation is an 8x8x3 array whose last axis holds
# idx, idx + 1, idx + 2 for the agent with index idx.
base_obs = {
    f"a{idx}": np.zeros([8, 8, 3], dtype=np.float32) + np.arange(3) + idx
    for idx in range(2)
}
# Every agent observes a Box bounded by 0.0 and 10.0 with shape [8, 8, 3].
base_obs_space = {
    f"a{idx}": Box(low=np.float32(0.0), high=np.float32(10.0), shape=[8, 8, 3])
    for idx in range(2)
}
# Every agent acts in Discrete(5).
base_act_spaces = {f"a{idx}": Discrete(5) for idx in range(2)}
59 |
60 |
def test_basic():
    """Smoke test: delay_observations_v0 + dtype_v0 on the dummy parallel env."""
    wrapped = DummyParEnv(base_obs, base_obs_space, base_act_spaces)
    wrapped = supersuit.delay_observations_v0(wrapped, 4)
    wrapped = supersuit.dtype_v0(wrapped, np.uint8)
    wrapped.reset()
    for _ in range(10):
        sampled = {}
        for agent in wrapped.agents:
            sampled[agent] = wrapped.action_space(agent).sample()
        wrapped.step(sampled)
69 |
--------------------------------------------------------------------------------
/supersuit/vector/sb3_vector_wrapper.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Any, List, Optional
3 |
4 | import numpy as np
5 | from stable_baselines3.common.vec_env import VecEnvWrapper
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices
7 |
8 |
class SB3VecEnvWrapper(VecEnvWrapper):
    """Expose a vector env through the Stable-Baselines3 VecEnvWrapper API."""

    def __init__(self, venv):
        # Attributes are copied straight from the wrapped env; the parent
        # constructor is not called here.
        self.venv = venv
        self.num_envs = venv.num_envs
        self.observation_space = venv.observation_space
        self.action_space = venv.action_space
        self.render_mode = venv.render_mode
        self.reset_infos = []

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return only the observations.

        The infos returned by the wrapped env are kept in ``self.reset_infos``.
        ``options`` is accepted but not forwarded.
        """
        if seed is not None:
            self.seed(seed=seed)
        # Note: SB3's vector envs return only observations on reset, and store infos in `self.reset_infos`
        reset_result = self.venv.reset()
        observations, self.reset_infos = reset_result
        return observations

    def step_wait(self):
        """Return ``(observations, rewards, dones, infos)`` in SB3 layout."""
        step_result = self.venv.step_wait()
        observations, rewards, terminations, truncations, infos = step_result
        # Note: SB3 expects dones to be an np.array
        merged = [
            terminated or truncated
            for terminated, truncated in zip(terminations, truncations)
        ]
        return observations, rewards, np.array(merged), infos

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Delegate to the wrapped env; ``indices`` is ignored."""
        return self.venv.env_is_wrapped(wrapper_class)

    def getattr_recursive(self, name):
        """Always raise AttributeError for *name*."""
        raise AttributeError(name)

    def getattr_depth_check(self, name, already_found):
        """Always return None."""
        return None

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Fetch *attr_name* from the wrapped env; ``indices`` is ignored."""
        attr = self.venv.get_attr(attr_name)
        if attr_name != "render_mode":
            return attr
        # Note: SB3 expects render_mode to be returned as an array, with values for each env
        return [attr] * self.num_envs

    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
        """Warn that ``mode`` is ignored, then render the wrapped env."""
        warnings.warn(
            "PettingZoo environments do not take the `render(mode)` argument, to change rendering mode, re-initialize the environment using the `render_mode` argument."
        )
        return self.venv.render()
56 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | **Aug 11, 2025: This project is semi-deprecated, and is unmaintained except for being kept operational with new versions of PettingZoo until the relevant functionality can be merged into pettingzoo.wrappers.**
6 |
7 | SuperSuit introduces a collection of small functions which can wrap reinforcement learning environments to do preprocessing ('microwrappers').
8 | We support Gymnasium for single agent environments and PettingZoo for multi-agent environments (both AECEnv and ParallelEnv environments).
9 |
10 |
11 | Using it with Gymnasium to convert space invaders to have a grey scale observation space and stack the last 4 frames looks like:
12 |
13 | ```
14 | import gymnasium
15 | from supersuit import color_reduction_v0, frame_stack_v1
16 |
17 | env = gymnasium.make('SpaceInvaders-v0')
18 |
19 | env = frame_stack_v1(color_reduction_v0(env, 'full'), 4)
20 | ```
21 |
22 | Similarly, using SuperSuit with PettingZoo environments looks like
23 |
24 | ```
25 | from pettingzoo.butterfly import pistonball_v0
26 | env = pistonball_v0.env()
27 |
28 | env = frame_stack_v1(color_reduction_v0(env, 'full'), 4)
29 | ```
30 |
31 |
32 | **Please note**: Once the planned wrapper rewrite of Gymnasium is complete and the vector API is stabilized, this project will be deprecated and rewritten as part of a new wrappers package in PettingZoo and the vectorized API will be redone, taking inspiration from the functionality currently in Gymnasium.
33 |
34 | ## Installing SuperSuit
35 | To install SuperSuit from pypi:
36 |
37 | ```
38 | python3 -m venv env
39 | source env/bin/activate
40 | pip install --upgrade pip
41 | pip install supersuit
42 | ```
43 |
44 | Alternatively, to install SuperSuit from source, clone this repo, `cd` to it, and then:
45 |
46 | ```
47 | python3 -m venv env
48 | source env/bin/activate
49 | pip install --upgrade pip
50 | pip install -e .
51 | ```
52 |
53 | ## Citation
54 |
55 | If you use this in your research, please cite:
56 |
57 | ```
58 | @article{SuperSuit,
59 | Title = {SuperSuit: Simple Microwrappers for Reinforcement Learning Environments},
60 | Author = {Terry, J. K and Black, Benjamin and Hari, Ananth},
61 | journal={arXiv preprint arXiv:2008.08932},
62 | year={2020}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
/supersuit/lambda_wrappers/reward_lambda.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | from pettingzoo.utils import BaseWrapper as PettingzooWrap
3 |
4 | from supersuit.utils.make_defaultdict import make_defaultdict
5 | from supersuit.utils.wrapper_chooser import WrapperChooser
6 |
7 |
class aec_reward_lambda(PettingzooWrap):
    """AEC wrapper that rewrites every agent's reward with ``change_reward_fn``.

    ``change_reward_fn`` is called with a single reward value taken from the
    wrapped env's ``rewards`` dict; its return value is stored under the same
    agent in ``self.rewards``.
    """

    def __init__(self, env, change_reward_fn):
        assert callable(
            change_reward_fn
        ), f"change_reward_fn needs to be a function. It is {change_reward_fn}"
        # Stored before the parent constructor runs.
        self._change_reward_fn = change_reward_fn

        super().__init__(env)

    def _check_wrapper_params(self):
        # Nothing to validate for this wrapper.
        pass

    def _modify_spaces(self):
        # Spaces are left untouched; only rewards change.
        pass

    def reset(self, seed=None, options=None):
        """Reset the wrapped env, then rebuild the transformed rewards."""
        super().reset(seed=seed, options=options)
        self.rewards = {
            agent: self._change_reward_fn(reward)
            for agent, reward in self.env.rewards.items()  # you don't want to unwrap here, because another reward wrapper might have been applied
        }
        # Private (name-mangled) running totals, starting at 0 for each agent.
        self.__cumulative_rewards = make_defaultdict({a: 0 for a in self.agents})
        # NOTE(review): unlike step(), _cumulative_rewards is not pointed at
        # the private totals before this call - confirm that is intended.
        self._accumulate_rewards()

    def step(self, action):
        """Step the wrapped env, then rebuild rewards and cumulative totals."""
        # Captured before stepping; the comprehension below has its own
        # `agent` variable and does not overwrite this one.
        agent = self.env.agent_selection
        super().step(action)
        self.rewards = {
            agent: self._change_reward_fn(reward)
            for agent, reward in self.env.rewards.items()  # you don't want to unwrap here, because another reward wrapper might have been applied
        }
        # Zero the total of the agent that just acted.
        self.__cumulative_rewards[agent] = 0
        # Publish the private totals; presumably _accumulate_rewards then adds
        # self.rewards onto them - verify against the base wrapper.
        self._cumulative_rewards = self.__cumulative_rewards
        self._accumulate_rewards()
42 |
43 |
class gym_reward_lambda(gymnasium.Wrapper):
    """Gymnasium wrapper that passes each step reward through a function."""

    def __init__(self, env, change_reward_fn):
        assert callable(
            change_reward_fn
        ), f"change_reward_fn needs to be a function. It is {change_reward_fn}"
        self._change_reward_fn = change_reward_fn

        super().__init__(env)

    def step(self, action):
        """Step the wrapped env and replace the reward with its transformed value."""
        transition = super().step(action)
        obs, raw_reward, termination, truncation, info = transition
        new_reward = self._change_reward_fn(raw_reward)
        return obs, new_reward, termination, truncation, info
56 |
57 |
# Public entry point: WrapperChooser is given the AEC implementation and the
# Gymnasium implementation of the reward transform.
reward_lambda_v0 = WrapperChooser(
    aec_wrapper=aec_reward_lambda, gym_wrapper=gym_reward_lambda
)
61 |
--------------------------------------------------------------------------------
/supersuit/__init__.py:
--------------------------------------------------------------------------------
1 | from supersuit.generic_wrappers import clip_actions_v0 # NOQA
2 | from supersuit.generic_wrappers import (
3 | clip_reward_v0,
4 | color_reduction_v0,
5 | delay_observations_v0,
6 | dtype_v0,
7 | flatten_v0,
8 | frame_skip_v0,
9 | frame_stack_v1,
10 | max_observation_v0,
11 | normalize_obs_v0,
12 | reshape_v0,
13 | resize_v1,
14 | sticky_actions_v0,
15 | )
16 |
17 | from .aec_vector import vectorize_aec_env_v0
18 | from .generic_wrappers import * # NOQA
19 | from .lambda_wrappers import observation_lambda_v0 # NOQA
20 | from .lambda_wrappers import action_lambda_v1, reward_lambda_v0
21 | from .multiagent_wrappers import black_death_v3 # NOQA
22 | from .multiagent_wrappers import (
23 | agent_indicator_v0,
24 | pad_action_space_v0,
25 | pad_observations_v0,
26 | )
27 | from .vector.vector_constructors import concat_vec_envs_v1 # NOQA
28 | from .vector.vector_constructors import (
29 | gym_vec_env_v0,
30 | pettingzoo_env_to_vec_env_v1,
31 | stable_baselines3_vec_env_v0,
32 | stable_baselines_vec_env_v0,
33 | )
34 |
35 |
class DeprecatedWrapper(ImportError):
    """ImportError raised when an outdated version of a wrapper is requested."""
38 |
39 |
def __getattr__(wrapper_name):
    """Explain why a missing name cannot be obtained from this module.

    Python calls this hook for module attributes that are not defined
    (PEP 562).  When the name has the form ``<base>_v<N>`` and a name
    ``<base>_v<M>`` with ``M > N`` exists in this module, the caller is
    pointed at the newer version, e.g.::

        supersuit.DeprecatedWrapper: concat_vec_envs_v0 is now deprecated, use concat_vec_envs_v1 instead

    Raises:
        AttributeError: for dunder names such as ``__all__`` or ``__path__``,
            so that ``hasattr``, ``getattr`` with a default and
            ``from supersuit import *`` behave as for any other module.
        DeprecatedWrapper: when a newer version of the requested wrapper
            is defined in this module.
        ImportError: for every other unknown name.
    """
    # Dunder lookups are interpreter/introspection probes, not wrapper
    # imports; they must fail with AttributeError as PEP 562 expects.
    if wrapper_name.startswith("__") and wrapper_name.endswith("__"):
        raise AttributeError(
            f"module {__name__!r} has no attribute {wrapper_name!r}"
        )

    marker = wrapper_name.rfind("_v")
    # Only names that actually contain "_v" carry a version suffix.
    if marker != -1:
        start_v = marker + 2
        base = wrapper_name[:start_v]
        try:
            version_num = int(wrapper_name[start_v:])
        except ValueError:
            version_num = None

        if version_num is not None:
            globs = globals()
            # Look for the lowest defined version above the requested one.
            for act_version_num in range(max(version_num + 1, 0), 1000):
                if f"{base}{act_version_num}" in globs:
                    raise DeprecatedWrapper(
                        f"{base}{version_num} is now deprecated, use {base}{act_version_num} instead"
                    )

    raise ImportError(f"cannot import name '{wrapper_name}' from 'supersuit'")
66 |
67 |
# Version string exposed by the package.
__version__ = "3.10.0"
69 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/basic_wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box
3 |
4 | from supersuit.lambda_wrappers import (
5 | action_lambda_v1,
6 | observation_lambda_v0,
7 | reward_lambda_v0,
8 | )
9 | from supersuit.utils import basic_transforms
10 |
11 |
def basic_obs_wrapper(env, module, param):
    """Wrap *env* with the observation transform implemented by *module*.

    ``module`` must provide ``check_param``, ``change_obs_space`` and
    ``change_observation``; ``param`` is passed through to each of them.
    """

    def convert_space(space):
        # Validate the parameter against the space before converting it.
        module.check_param(space, param)
        return module.change_obs_space(space, param)

    def convert_obs(obs, obs_space):
        return module.change_observation(obs, obs_space, param)

    return observation_lambda_v0(env, convert_obs, convert_space)
22 |
23 |
def color_reduction_v0(env, mode="full"):
    """Apply the ``color_reduction`` basic transform to *env* with *mode*."""
    transform = basic_transforms.color_reduction
    return basic_obs_wrapper(env, transform, mode)
26 |
27 |
def resize_v1(env, x_size, y_size, linear_interp=True):
    """Apply the ``resize`` basic transform to *env*.

    The three arguments are packed as ``(x_size, y_size, linear_interp)``.
    """
    return basic_obs_wrapper(
        env, basic_transforms.resize, (x_size, y_size, linear_interp)
    )
31 |
32 |
def dtype_v0(env, dtype):
    """Apply the ``dtype`` basic transform to *env* with the given *dtype*."""
    transform = basic_transforms.dtype
    return basic_obs_wrapper(env, transform, dtype)
35 |
36 |
def flatten_v0(env):
    """Apply the ``flatten`` basic transform to *env* (parameter is ``True``)."""
    transform = basic_transforms.flatten
    return basic_obs_wrapper(env, transform, True)
39 |
40 |
def reshape_v0(env, shape):
    """Apply the ``reshape`` basic transform to *env* with the given *shape*."""
    transform = basic_transforms.reshape
    return basic_obs_wrapper(env, transform, shape)
43 |
44 |
def normalize_obs_v0(env, env_min=0.0, env_max=1.0):
    """Apply the ``normalize_obs`` basic transform with ``(env_min, env_max)``."""
    bounds = (env_min, env_max)
    return basic_obs_wrapper(env, basic_transforms.normalize_obs, bounds)
47 |
48 |
def clip_actions_v0(env):
    """Clip every action to the ``low``/``high`` bounds of its action space.

    ``None`` actions are passed through unchanged; the action space itself
    is not modified.
    """

    def clip_to_bounds(action, act_space):
        if action is None:
            return None
        return np.clip(action, act_space.low, act_space.high)

    def keep_space(act_space):
        return act_space

    return action_lambda_v1(env, clip_to_bounds, keep_space)
57 |
58 |
def scale_actions_v0(env, scale):
    """Multiply actions and Box action-space bounds by *scale*.

    ``None`` actions are passed through unchanged.  Only Box action spaces
    are accepted.
    """

    def change_act_space(act_space):
        assert isinstance(
            act_space, Box
        ), "scale_actions_v0 only works with a Box action space"
        scaled_low = act_space.low * scale
        scaled_high = act_space.high * scale
        return Box(low=scaled_low, high=scaled_high)

    def scale_action(action, act_space):
        if action is None:
            return None
        return np.asarray(action) * scale

    return action_lambda_v1(env, scale_action, change_act_space)
73 |
74 |
def clip_reward_v0(env, lower_bound=-1, upper_bound=1):
    """Clamp every reward into ``[lower_bound, upper_bound]``."""

    def clamp(rew):
        # Cap from above first, then from below.
        capped = min(rew, upper_bound)
        return max(capped, lower_bound)

    return reward_lambda_v0(env, clamp)
77 |
--------------------------------------------------------------------------------
/test/test_vector/test_aec_vector_identity_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pettingzoo.butterfly import knights_archers_zombies_v10
3 |
4 | from supersuit import vectorize_aec_env_v0
5 |
6 |
def recursive_equal(info1, info2):
    """Return True when ``info1 == info2`` is truthy, otherwise False."""
    return bool(info1 == info2)
11 |
12 |
def test_identical():
    """Two identically seeded vectorized KAZ envs must stay in lockstep.

    One env is built with the default settings, the other with
    ``num_cpus=1``; both receive the same actions at every step.
    """
    n_envs = 2
    first = vectorize_aec_env_v0(knights_archers_zombies_v10.env(), n_envs)
    second = vectorize_aec_env_v0(
        knights_archers_zombies_v10.env(), n_envs, num_cpus=1
    )
    first.reset(seed=42)
    second.reset(seed=42)

    def policy(obs, agent):
        # One random action per sub-environment; obs is not consulted.
        return [first.action_space(agent).sample() for i in range(first.num_envs)]

    envs_done = 0
    for _ in first.agent_iter(200000):
        assert first.agent_selection == second.agent_selection
        agent = first.agent_selection
        last1 = first.last()
        last2 = second.last()
        (
            obs1,
            rew1,
            agent_term1,
            agent_trunc1,
            env_term1,
            env_trunc1,
            agent_passes1,
            infos1,
        ) = last1
        (
            obs2,
            rew2,
            agent_term2,
            agent_trunc2,
            env_term2,
            env_trunc2,
            agent_passes2,
            infos2,
        ) = last2

        compared = [
            (obs1, obs2),
            (agent_term1, agent_term2),
            (agent_trunc1, agent_trunc2),
            (agent_passes1, agent_passes2),
            (env_term1, env_term2),
            (env_trunc1, env_trunc2),
        ]
        for left, right in compared:
            assert np.all(np.equal(left, right))
        for r1, r2 in zip(first.rewards.values(), second.rewards.values()):
            assert np.all(np.equal(r1, r2))
        assert recursive_equal(infos1, infos2)

        actions = policy(obs1, agent)
        first.step(actions)
        second.step(actions)

        for j in range(2):
            if rew1[j] != 0:
                print(j, agent, rew1, agent_term1[j], agent_trunc1[j])
            if env_term1[j] or env_trunc1[j]:
                print(j, "done")
                envs_done += 1
                if envs_done == n_envs + 1:
                    print("test passed")
                    return
78 |
--------------------------------------------------------------------------------
/test/gym_unwrapped_test.py:
--------------------------------------------------------------------------------
1 | from test.dummy_gym_env import DummyEnv
2 |
3 | import numpy as np
4 | from gymnasium import spaces
5 |
6 | from supersuit import (
7 | clip_actions_v0,
8 | clip_reward_v0,
9 | color_reduction_v0,
10 | delay_observations_v0,
11 | dtype_v0,
12 | flatten_v0,
13 | frame_skip_v0,
14 | frame_stack_v1,
15 | max_observation_v0,
16 | nan_random_v0,
17 | nan_zeros_v0,
18 | normalize_obs_v0,
19 | scale_actions_v0,
20 | sticky_actions_v0,
21 | )
22 |
23 |
def unwrapped_check(env):
    """Stack the applicable wrappers on *env* and check ``unwrapped``.

    Wrappers are chosen from the types of the env's observation and action
    spaces.  The final assertion requires that ``env.unwrapped`` is still an
    instance of exactly ``DummyEnv``.
    """
    # image observations
    if isinstance(env.observation_space, spaces.Box):
        # NOTE(review): `low.shape == 3` compares a tuple with an int, so it
        # is always False and this branch is never taken.  If it were
        # reached, `len(env.observation_space.shape[2])` would raise
        # TypeError (len of an int).  The intent was presumably
        # `len(low.shape) == 3` and `shape[2] == 3` - confirm before fixing,
        # since enabling the branch runs three more wrappers.
        if (
            (env.observation_space.low.shape == 3)
            and (env.observation_space.low == 0).all()
            and (len(env.observation_space.shape[2]) == 3)
            and (env.observation_space.high == 255).all()
        ):
            env = max_observation_v0(env, 2)
            env = color_reduction_v0(env, mode="full")
            env = normalize_obs_v0(env)

    # box action spaces
    if isinstance(env.action_space, spaces.Box):
        env = clip_actions_v0(env)
        env = scale_actions_v0(env, 0.5)

    # stackable observations
    if isinstance(env.observation_space, spaces.Box) or isinstance(
        env.observation_space, spaces.Discrete
    ):
        env = frame_stack_v1(env, 2)

    # not discrete and not multibinary observations
    if not isinstance(env.observation_space, spaces.Discrete) and not isinstance(
        env.observation_space, spaces.MultiBinary
    ):
        env = dtype_v0(env, np.float16)
        env = flatten_v0(env)
        env = frame_skip_v0(env, 2)

    # everything else
    env = clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = delay_observations_v0(env, 2)
    env = sticky_actions_v0(env, 0.5)
    env = nan_random_v0(env)
    env = nan_zeros_v0(env)

    # Exact class comparison: a subclass of DummyEnv would not pass.
    assert env.unwrapped.__class__ == DummyEnv, f"Failed to unwrap {env}"
64 |
65 |
def test_unwrapped():
    """Run unwrapped_check over every observation/action space combination."""
    observation_spaces = [
        spaces.Box(low=-1.0, high=1.0, shape=[2], dtype=np.float32),
        spaces.Box(low=0, high=255, shape=[64, 64, 3], dtype=np.int16),
        spaces.Discrete(5),
        spaces.MultiBinary([3, 4]),
    ]
    action_spaces = [
        spaces.Box(-3.0, 3.0, [3], np.float32),
        spaces.Discrete(5),
        spaces.MultiDiscrete([3, 5]),
    ]

    combinations = [
        (obs_space, act_space)
        for obs_space in observation_spaces
        for act_space in action_spaces
    ]
    for obs_space, act_space in combinations:
        # The dummy env's fixed observation is a sample from its own space.
        env = DummyEnv(obs_space.sample(), obs_space, act_space)
        unwrapped_check(env)
86 |
--------------------------------------------------------------------------------
/supersuit/lambda_wrappers/action_lambda.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import gymnasium
4 | from gymnasium.spaces import Space
5 |
6 | from supersuit.utils.base_aec_wrapper import BaseWrapper
7 | from supersuit.utils.wrapper_chooser import WrapperChooser
8 |
9 |
class aec_action_lambda(BaseWrapper):
    """AEC wrapper that rewrites actions and action spaces with user functions.

    Both functions are first called with the agent name as an extra trailing
    argument; if that call raises TypeError they are called again without it.
    """

    def __init__(self, env, change_action_fn, change_space_fn):
        assert callable(
            change_action_fn
        ), f"change_action_fn needs to be a function. It is {change_action_fn}"
        assert callable(
            change_space_fn
        ), f"change_space_fn needs to be a function. It is {change_space_fn}"

        self.change_action_fn = change_action_fn
        self.change_space_fn = change_space_fn

        super().__init__(env)
        if hasattr(self, "possible_agents"):
            # Building each space once runs any validation inside
            # change_space_fn at construction time.
            for name in self.possible_agents:
                self.action_space(name)

    def _modify_observation(self, agent, observation):
        """Observations pass through unchanged."""
        return observation

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        """Return the transformed action space for *agent* (cached per agent)."""
        base_space = self.env.action_space(agent)
        space_fn = self.change_space_fn
        try:
            new_space = space_fn(base_space, agent)
        except TypeError:
            new_space = space_fn(base_space)
        return new_space

    def _modify_action(self, agent, action):
        """Return *action* transformed for the wrapped env's action space."""
        base_space = self.env.action_space(agent)
        action_fn = self.change_action_fn
        try:
            new_action = action_fn(action, base_space, agent)
        except TypeError:
            new_action = action_fn(action, base_space)
        return new_action
45 |
46 |
class gym_action_lambda(gymnasium.Wrapper):
    """Gymnasium wrapper that rewrites actions and the action space."""

    def __init__(self, env, change_action_fn, change_space_fn):
        assert callable(
            change_action_fn
        ), f"change_action_fn needs to be a function. It is {change_action_fn}"
        assert callable(
            change_space_fn
        ), f"change_space_fn needs to be a function. It is {change_space_fn}"
        self.change_action_fn = change_action_fn
        self.change_space_fn = change_space_fn

        super().__init__(env)
        self._modify_spaces()

    def _modify_spaces(self):
        """Replace ``self.action_space`` with the transformed space."""
        candidate = self.change_space_fn(self.action_space)
        assert isinstance(
            candidate, Space
        ), "output of change_space_fn argument to action_lambda_wrapper must be a gymnasium space"
        self.action_space = candidate

    def _modify_action(self, action):
        """Transform *action* using the wrapped env's own action space."""
        inner_space = self.env.action_space
        return self.change_action_fn(action, inner_space)

    def step(self, action):
        """Transform *action*, then step the wrapped env with it."""
        converted = self._modify_action(action)
        return super().step(converted)
73 |
74 |
# Public entry point: WrapperChooser is given the AEC implementation and the
# Gymnasium implementation of the action transform.
action_lambda_v1 = WrapperChooser(
    aec_wrapper=aec_action_lambda, gym_wrapper=gym_action_lambda
)
78 |
--------------------------------------------------------------------------------
/test/test_utils/test_frame_stack.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box, Discrete
3 |
4 | from supersuit.utils.frame_stack import stack_init, stack_obs, stack_obs_space
5 |
6 |
# Box fixtures bounded by 0.0 and 1.0 with three, two and one dimensions.
stack_obs_space_3d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3))
stack_obs_space_2d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 3))
stack_obs_space_1d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(3,))

# Discrete fixture with three values.
stack_discrete = Discrete(3)

# Number of frames stacked in most of the checks below.
STACK_SIZE = 11
14 |
15 |
def test_obs_space():
    """Stacked spaces must have the expected shape (Box) or size (Discrete)."""
    expected_shapes = [
        (stack_obs_space_1d, (3 * STACK_SIZE,)),
        (stack_obs_space_2d, (4, 3, STACK_SIZE)),
        (stack_obs_space_3d, (4, 4, 3 * STACK_SIZE)),
    ]
    for space, shape in expected_shapes:
        assert stack_obs_space(space, STACK_SIZE).shape == shape
    assert stack_obs_space(stack_discrete, STACK_SIZE).n == 3**STACK_SIZE
25 |
26 |
def stack_obs_helper(frame_stack_list, obs_space, stack_size):
    """Push each frame, in order, onto a freshly initialised stack.

    Returns the stack after the last frame has been added.
    """
    current = stack_init(obs_space, stack_size)
    for frame in frame_stack_list:
        current = stack_obs(current, frame, obs_space, stack_size)
    return current
34 |
35 |
def test_change_observation():
    """Check shapes and contents of stacks built by stack_obs_helper."""
    shape_cases = [
        ([stack_obs_space_1d.low], stack_obs_space_1d, (3 * STACK_SIZE,)),
        (
            [stack_obs_space_1d.low, stack_obs_space_1d.high],
            stack_obs_space_1d,
            (3 * STACK_SIZE,),
        ),
        ([stack_obs_space_2d.low], stack_obs_space_2d, (4, 3, STACK_SIZE)),
        (
            [stack_obs_space_2d.low, stack_obs_space_2d.high],
            stack_obs_space_2d,
            (4, 3, STACK_SIZE),
        ),
        ([stack_obs_space_3d.low], stack_obs_space_3d, (4, 4, 3 * STACK_SIZE)),
    ]
    for frames, space, shape in shape_cases:
        assert stack_obs_helper(frames, space, STACK_SIZE).shape == shape

    # Discrete frames 1 then 2 with three values each.
    assert stack_obs_helper([1, 2], stack_discrete, STACK_SIZE) == 2 + 1 * 3

    # 2D frames gain a new trailing axis; the unfilled slot stays zero.
    low_2d = stack_obs_space_2d.low
    high_2d = stack_obs_space_2d.high
    stacked = stack_obs_helper([low_2d, high_2d], stack_obs_space_2d, 3)
    raw = np.stack([np.zeros_like(high_2d), low_2d, high_2d], axis=2)
    assert np.all(np.equal(stacked, raw))

    # 3D frames are joined along the existing last axis.
    low_3d = stack_obs_space_3d.low
    high_3d = stack_obs_space_3d.high
    stacked = stack_obs_helper([low_3d, high_3d], stack_obs_space_3d, 3)
    raw = np.concatenate([np.zeros_like(high_3d), low_3d, high_3d], axis=2)
    assert np.all(np.equal(stacked, raw))
84 |
--------------------------------------------------------------------------------
/test/test_vector/test_aec_vector_values.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from pettingzoo.butterfly import knights_archers_zombies_v10
5 | from pettingzoo.classic import rps_v2
6 | from pettingzoo.mpe import simple_world_comm_v3
7 |
8 | from supersuit import vectorize_aec_env_v0
9 |
10 |
def test_all():
    """Run the basic vector-env check on three AEC envs, for num_cpus 0 and 1.

    Only the nested ``test_vec_env`` helper is called by the loop at the
    bottom; ``test_infos``, ``test_seed``, ``test_some_done`` and
    ``select_action`` are defined but never invoked in this function.
    """
    NUM_ENVS = 5

    def test_vec_env(vec_env):
        # After reset, last() and observe() must agree and carry one entry
        # per sub-environment.
        vec_env.reset()
        (
            obs,
            rew,
            agent_term,
            agent_trunc,
            env_term,
            env_trunc,
            agent_passes,
            infos,
        ) = vec_env.last()
        print(np.asarray(obs).shape)
        assert len(obs) == NUM_ENVS
        act_space = vec_env.action_space(vec_env.agent_selection)
        assert np.all(np.equal(obs, vec_env.observe(vec_env.agent_selection)))
        assert len(vec_env.observe(vec_env.agent_selection)) == NUM_ENVS
        vec_env.step([act_space.sample() for _ in range(NUM_ENVS)])
        # With observe=False no observation is returned.
        (
            obs,
            rew,
            agent_term,
            agent_trunc,
            env_term,
            env_trunc,
            agent_passes,
            infos,
        ) = vec_env.last(observe=False)
        assert obs is None

    # NOTE(review): never called below.
    def test_infos(vec_env):
        vec_env.reset()
        infos = vec_env.infos[vec_env.agent_selection]
        assert infos[1]["legal_moves"]

    # NOTE(review): never called below.
    def test_seed(vec_env):
        vec_env.reset(seed=4)

    # NOTE(review): never called below.
    def test_some_done(vec_env):
        vec_env.reset()
        act_space = vec_env.action_space(vec_env.agent_selection)
        assert not any(done for dones in vec_env.dones.values() for done in dones)
        vec_env.step([act_space.sample() for _ in range(NUM_ENVS)])
        assert any(rew != 0 for rews in vec_env.rewards.values() for rew in rews)
        any_done_first = any(done for dones in vec_env.dones.values() for done in dones)
        vec_env.step([act_space.sample() for _ in range(NUM_ENVS)])
        any_done_second = any(
            done for dones in vec_env.dones.values() for done in dones
        )
        assert any_done_first and any_done_second

    # NOTE(review): never called below.
    def select_action(vec_env, passes, i):
        my_info = vec_env.infos[vec_env.agent_selection][i]
        # The leading `False` disables this branch, so a random sample from
        # the action space is always returned.
        if False and not passes[i] and "legal_moves" in my_info:
            return random.choice(my_info["legal_moves"])
        else:
            act_space = vec_env.action_space(vec_env.agent_selection)
            return act_space.sample()

    for num_cpus in [0, 1]:
        test_vec_env(vectorize_aec_env_v0(rps_v2.env(), NUM_ENVS, num_cpus=num_cpus))
        test_vec_env(
            vectorize_aec_env_v0(
                knights_archers_zombies_v10.env(), NUM_ENVS, num_cpus=num_cpus
            )
        )
        test_vec_env(
            vectorize_aec_env_v0(
                simple_world_comm_v3.env(), NUM_ENVS, num_cpus=num_cpus
            )
        )
85 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/nan_wrappers.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import gymnasium
4 | import numpy as np
5 |
6 | from supersuit.lambda_wrappers import action_lambda_v1
7 |
8 | from .utils.base_modifier import BaseModifier
9 | from .utils.shared_wrapper_util import shared_wrapper
10 |
11 |
def nan_random_v0(env):
    """Replace NaN actions with a random action.

    When the current observation is a dict containing ``"action_mask"``, the
    replacement is drawn from the indices where the mask is non-zero;
    otherwise it is sampled from the action space.  ``None`` actions and
    actions without NaN are passed through unchanged.
    """

    class NanRandomModifier(BaseModifier):
        def __init__(self):
            super().__init__()

        def reset(self, seed=None, options=None):
            # Generator used for the masked random choice below.
            self.np_random, seed = gymnasium.utils.seeding.np_random(seed)

            return super().reset(seed, options=options)

        def modify_action(self, action):
            if action is not None and np.isnan(action).any():
                obs = self.cur_obs
                # The key tested here must be the key read below.
                if isinstance(obs, dict) and "action_mask" in obs:
                    warnings.warn(
                        "[WARNING]: Step received an NaN action {}. Environment is {}. Taking a random action from 'action mask'.".format(
                            action, self
                        )
                    )
                    action = self.np_random.choice(np.flatnonzero(obs["action_mask"]))
                else:
                    warnings.warn(
                        "[WARNING]: Step received an NaN action {}. Environment is {}. Taking a random action.".format(
                            action, self
                        )
                    )
                    action = self.action_space.sample()

            return action

    return shared_wrapper(env, NanRandomModifier)
43 |
44 |
def nan_noop_v0(env, no_op_action):
    """Replace NaN actions with *no_op_action*.

    A ``None`` action triggers a warning and is returned as ``None``; any
    other action without NaN is returned unchanged.
    """

    def on_action(action, action_space):
        if action is None:
            message = "[WARNING]: Step received an None action {}. Environment is {}. Taking no operation action.".format(
                action, env
            )
            warnings.warn(message)
            return None
        if not np.isnan(action).any():
            return action
        message = "[WARNING]: Step received an NaN action {}. Environment is {}. Taking no operation action.".format(
            action, env
        )
        warnings.warn(message)
        return no_op_action

    def keep_space(act_space):
        return act_space

    return action_lambda_v1(env, on_action, keep_space)
64 |
65 |
def nan_zeros_v0(env):
    """Replace NaN actions with an all-zero action of the same shape.

    A ``None`` action triggers a warning and is returned as ``None``; any
    other action without NaN is returned unchanged.
    """

    def on_action(action, action_space):
        if action is None:
            message = "[WARNING]: Step received an None action {}. Environment is {}. Taking the all zeroes action.".format(
                action, env
            )
            warnings.warn(message)
            return None
        if not np.isnan(action).any():
            return action
        message = "[WARNING]: Step received an NaN action {}. Environment is {}. Taking the all zeroes action.".format(
            action, env
        )
        warnings.warn(message)
        return np.zeros_like(action)

    def keep_space(act_space):
        return act_space

    return action_lambda_v1(env, on_action, keep_space)
85 |
--------------------------------------------------------------------------------
/test/test_utils/test_agent_indicator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from gymnasium.spaces import Box, Discrete
4 |
5 | from supersuit.utils.agent_indicator import (
6 | change_obs_space,
7 | change_observation,
8 | get_indicator_map,
9 | )
10 |
11 |
# Box fixtures bounded by 0.0 and 1.0 with three, two and one dimensions.
obs_space_3d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3))
obs_space_2d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 3))
obs_space_1d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(3,))

# Discrete fixture with three values.
discrete_space = Discrete(3)

# Number of indicator slots used throughout the checks below.
NUM_INDICATORS = 11
19 |
20 |
def test_obs_space():
    """Indicator channels must extend each space by the expected amount."""
    expected_shapes = [
        (obs_space_1d, (3 + NUM_INDICATORS,)),
        (obs_space_2d, (4, 3, 1 + NUM_INDICATORS)),
        (obs_space_3d, (4, 4, 3 + NUM_INDICATORS)),
    ]
    for space, shape in expected_shapes:
        assert change_obs_space(space, NUM_INDICATORS).shape == shape
    assert change_obs_space(discrete_space, NUM_INDICATORS).n == 3 * NUM_INDICATORS
34 |
35 |
def test_change_observation():
    """Check shapes and selected values of indicator-augmented observations."""
    shape_cases = [
        (np.ones((4, 4, 3)), obs_space_3d, (4, 4, 3 + NUM_INDICATORS)),
        (np.ones((4, 3)), obs_space_2d, (4, 3, 1 + NUM_INDICATORS)),
        (np.ones(41), obs_space_1d, (41 + NUM_INDICATORS,)),
    ]
    for observation, space, shape in shape_cases:
        result = change_observation(observation, space, (4, NUM_INDICATORS))
        assert result.shape == shape

    # Spot checks on a 3D observation of ones.
    for index, expected in [
        ((0, 0, 0), 1.0),
        ((0, 0, 4), 0.0),
        ((0, 1, 7), 1.0),
        ((0, 0, 8), 0.0),
    ]:
        result = change_observation(
            np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS)
        )
        assert result[index] == expected

    # Spot checks on a 1D observation of ones.
    for index, expected in [(2, 1.0), (6, 0.0), (7, 1.0), (8, 0.0)]:
        result = change_observation(np.ones((3,)), obs_space_1d, (4, NUM_INDICATORS))
        assert result[index] == expected

    # Discrete observation 2 with indicator 4.
    discrete_result = change_observation(2, discrete_space, (4, NUM_INDICATORS))
    assert discrete_result == 2 * NUM_INDICATORS + 4
87 |
88 |
def test_get_indicator_map():
    """get_indicator_map indexes per agent, or per type when type_only is set."""
    plain_names = ["bob", "joe", "fren"]
    per_agent = get_indicator_map(plain_names, False)
    assert len(per_agent) == 3

    # Names without a numeric suffix are rejected in type-only mode.
    with pytest.raises(AssertionError):
        get_indicator_map(plain_names, True)

    # Two "joe" agents share one type index, leaving three distinct indices.
    typed_names = ["bob_0", "joe_1", "fren_2", "joe_3"]
    per_type = get_indicator_map(typed_names, True)
    assert len(set(per_type.values())) == 3
101 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Package ######################################################################
2 |
3 | [build-system]
4 | requires = ["setuptools >= 61.0.0"]
5 | build-backend = "setuptools.build_meta"
6 |
7 | [project]
8 | name="SuperSuit"
9 | description="Wrappers for Gymnasium and PettingZoo"
10 | readme="README.md"
11 | requires-python = ">= 3.9"
12 | authors = [{ name = "Farama Foundation", email = "contact@farama.org" }]
13 | license = { text = "MIT License" }
14 | keywords=["Reinforcement Learning", "game", "RL", "AI", "gymnasium"]
15 | classifiers=[
16 | "Programming Language :: Python :: 3.12",
17 | "Programming Language :: Python :: 3.11",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.9",
20 | "License :: OSI Approved :: MIT License",
21 | "Operating System :: OS Independent",
22 | ]
23 | dependencies = ["numpy>=1.19.0", "gymnasium>=1.0.0", "tinyscaler>=1.2.6"]
24 | dynamic = ["version"]
25 |
26 | [project.optional-dependencies]
27 | # Update dependencies in `all` if any are added or removed
28 |
29 | # Remove "pettingzoo[all,atari]" from dependencies for now
30 | # as there's a pygame version mismatch issue for pettingzoo version
31 | # 1.24.3. Once pettingzoo gets an updated version, add it back.
32 | testing = ["AutoROM", "pytest", "pytest-xdist", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"]
33 |
34 | [project.urls]
35 | Homepage = "https://farama.org"
36 | Repository = "https://github.com/Farama-Foundation/SuperSuit"
37 | "Bug Report" = "https://github.com/Farama-Foundation/SuperSuit/issues"
38 |
39 |
40 | [tool.setuptools]
41 | include-package-data = true
42 |
43 | [tool.setuptools.packages.find]
44 | include = ["supersuit", "supersuit.*"]
45 |
46 | # Linters and Test tools #######################################################
47 |
48 | [tool.black]
49 | safe = true
50 |
51 | [tool.isort]
52 | atomic = true
53 | profile = "black"
54 | src_paths = ["supersuit", "test"]
55 | extra_standard_library = ["typing_extensions"]
56 | indent = 4
57 | lines_after_imports = 2
58 | multi_line_output = 3
59 |
[tool.pyright]
# The test suite lives in `test/` (see also `tool.isort.src_paths`), not `tests/`.
include = ["supersuit/**", "test/**"]
exclude = ["**/node_modules", "**/__pycache__"]
strict = []

typeCheckingMode = "basic"
pythonVersion = "3.9"
pythonPlatform = "All"
typeshedPath = "typeshed"
enableTypeIgnoreComments = true

# This is required as the CI pre-commit does not download the module (i.e. numpy, pygame, box2d)
#   Therefore, we have to ignore missing imports
reportMissingImports = "none"
# Some modules are missing type stubs, which is an issue when running pyright locally
reportMissingTypeStubs = false
# For warning and error, will raise an error when
reportInvalidTypeVarUse = "none"

# reportUnknownMemberType = "warning"  # -> raises 6035 warnings
# reportUnknownParameterType = "warning"  # -> raises 1327 warnings
# reportUnknownVariableType = "warning"  # -> raises 2585 warnings
# reportUnknownArgumentType = "warning"  # -> raises 2104 warnings
reportGeneralTypeIssues = "none"  # -> commented out raises 489 errors
# reportUntypedFunctionDecorator = "none"  # -> pytest.mark.parameterize issues

reportPrivateUsage = "warning"
reportUnboundVariable = "warning"
88 |
89 | [tool.pytest.ini_options]
90 | filterwarnings = ["ignore::DeprecationWarning:gymnasium.*:"]
91 | addopts = [ "-n=auto" ]
92 |
--------------------------------------------------------------------------------
/supersuit/vector/vector_constructors.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import cloudpickle
4 | import gymnasium
5 | from pettingzoo.utils.env import ParallelEnv
6 |
7 | from . import MakeCPUAsyncConstructor, MarkovVectorEnv
8 |
9 |
def vec_env_args(env, num_envs):
    """Build the argument triple used by the vector-env constructors.

    Returns ``(env_fns, observation_space, action_space)``. ``env_fns`` is a
    list holding ``num_envs`` references to one factory; each call of that
    factory returns a fresh copy of ``env`` made by a cloudpickle round trip.
    The two spaces are read directly from ``env``.
    """

    def env_fn():
        serialized = cloudpickle.dumps(env)
        return cloudpickle.loads(serialized)

    factories = [env_fn for _ in range(num_envs)]
    return factories, env.observation_space, env.action_space
16 |
17 |
def warn_not_gym_env(env, fn_name):
    """Emit a warning when ``env`` is not a ``gymnasium.Env`` instance.

    ``fn_name`` is the public function name quoted at the start of the warning.
    """
    if isinstance(env, gymnasium.Env):
        return
    message = f"{fn_name} took in an environment which does not inherit from gymnasium.Env. Note that gym_vec_env only takes in gymnasium-style environments, not pettingzoo environments."
    warnings.warn(message)
23 |
24 |
def gym_vec_env_v0(env, num_envs, multiprocessing=False):
    """Vectorize ``env`` into ``num_envs`` copies with a gymnasium vector env.

    Uses ``gymnasium.vector.AsyncVectorEnv`` when ``multiprocessing`` is true
    and ``gymnasium.vector.SyncVectorEnv`` otherwise. A warning is issued if
    ``env`` is not a ``gymnasium.Env``.
    """
    warn_not_gym_env(env, "gym_vec_env")
    # Only the factory list is needed; the spaces are dropped.
    env_fns = vec_env_args(env, num_envs)[0]
    if multiprocessing:
        constructor = gymnasium.vector.AsyncVectorEnv
    else:
        constructor = gymnasium.vector.SyncVectorEnv
    return constructor(env_fns)
34 |
35 |
def stable_baselines_vec_env_v0(env, num_envs, multiprocessing=False):
    """Vectorize ``env`` into ``num_envs`` copies with a stable_baselines vec env.

    Uses ``SubprocVecEnv`` when ``multiprocessing`` is true and ``DummyVecEnv``
    otherwise. ``stable_baselines`` is imported lazily so it stays optional.
    """
    import stable_baselines

    warn_not_gym_env(env, "stable_baselines_vec_env")
    # Only the factory list is needed; the spaces are dropped.
    env_fns = vec_env_args(env, num_envs)[0]
    vec_env_module = stable_baselines.common.vec_env
    if multiprocessing:
        constructor = vec_env_module.SubprocVecEnv
    else:
        constructor = vec_env_module.DummyVecEnv
    return constructor(env_fns)
47 |
48 |
def stable_baselines3_vec_env_v0(env, num_envs, multiprocessing=False):
    """Vectorize ``env`` into ``num_envs`` copies with a stable_baselines3 vec env.

    Uses ``SubprocVecEnv`` when ``multiprocessing`` is true and ``DummyVecEnv``
    otherwise. ``stable_baselines3`` is imported lazily so it stays optional.
    """
    import stable_baselines3

    warn_not_gym_env(env, "stable_baselines3_vec_env")
    # Only the factory list is needed; the spaces are dropped.
    env_fns = vec_env_args(env, num_envs)[0]
    vec_env_module = stable_baselines3.common.vec_env
    if multiprocessing:
        constructor = vec_env_module.SubprocVecEnv
    else:
        constructor = vec_env_module.DummyVecEnv
    return constructor(env_fns)
60 |
61 |
def concat_vec_envs_v1(vec_env, num_vec_envs, num_cpus=0, base_class="gymnasium"):
    """Concatenate ``num_vec_envs`` copies of ``vec_env`` into one vector env.

    Parameters:
        vec_env: environment to copy; must expose ``observation_space`` and
            ``action_space``.
        num_vec_envs: number of copies to create.
        num_cpus: requested worker count, capped at ``num_vec_envs`` before it
            is handed to ``MakeCPUAsyncConstructor``.
        base_class: ``"gymnasium"`` returns the concatenated env as is;
            ``"stable_baselines"`` / ``"stable_baselines3"`` wrap it in the
            matching adapter.

    Raises:
        ValueError: if ``base_class`` is not one of the supported names.
    """
    # Resolve the adapter before any environment is built, so that a bad
    # ``base_class`` (or a missing optional dependency) cannot leave freshly
    # created sub-environments behind with nobody to close them.
    if base_class == "gymnasium":
        wrapper = None
    elif base_class == "stable_baselines":
        from .sb_vector_wrapper import SBVecEnvWrapper as wrapper
    elif base_class == "stable_baselines3":
        from .sb3_vector_wrapper import SB3VecEnvWrapper as wrapper
    else:
        raise ValueError(
            "supersuit_vec_env only supports 'gymnasium', 'stable_baselines', and 'stable_baselines3' for its base_class"
        )

    num_cpus = min(num_cpus, num_vec_envs)
    vec_env = MakeCPUAsyncConstructor(num_cpus)(*vec_env_args(vec_env, num_vec_envs))

    return vec_env if wrapper is None else wrapper(vec_env)
80 |
81 |
def pettingzoo_env_to_vec_env_v1(parallel_env):
    """Wrap a PettingZoo ``ParallelEnv`` in a ``MarkovVectorEnv``.

    Asserts that ``parallel_env`` is a ``ParallelEnv`` instance and that it
    has a ``possible_agents`` attribute before wrapping it.
    """
    is_parallel = isinstance(parallel_env, ParallelEnv)
    assert (
        is_parallel
    ), "pettingzoo_env_to_vec_env takes in a pettingzoo ParallelEnv. Can create a parallel_env with pistonball.parallel_env() or convert it from an AEC env with `from pettingzoo.utils.conversions import aec_to_parallel; aec_to_parallel(env)``"

    has_possible_agents = hasattr(parallel_env, "possible_agents")
    assert (
        has_possible_agents
    ), "environment passed to pettingzoo_env_to_vec_env must have possible_agents attribute."

    return MarkovVectorEnv(parallel_env)
90 |
--------------------------------------------------------------------------------
/test/test_vector/test_vector_dict.py:
--------------------------------------------------------------------------------
1 | from test.dummy_aec_env import DummyEnv
2 |
3 | import numpy as np
4 | import pytest
5 | from gymnasium.spaces import Box, Dict, Discrete, Tuple
6 | from gymnasium.vector.utils import concatenate, create_empty_array
7 | from pettingzoo.utils.conversions import aec_to_parallel
8 |
9 | from supersuit import concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1
10 |
11 |
n_agents = 5


def make_env():
    """Create a parallel dummy env with nested observations for ``n_agents`` agents.

    Agent ``i`` (named ``str(i)``) always observes arrays filled with ``i``.
    Observation spaces are a Dict holding a Box and a Tuple of two Boxes;
    action spaces are a Dict holding a Box and a Discrete.
    """
    observations = {}
    observation_spaces = {}
    action_spaces = {}
    for idx in range(n_agents):
        name = str(idx)
        observations[name] = {
            "feature": idx * np.ones((5,), dtype=np.float32),
            "id": (
                idx * np.ones((7,), dtype=np.float32),
                idx * np.ones((8,), dtype=np.float32),
            ),
        }
        observation_spaces[name] = Dict(
            {
                "feature": Box(low=0, high=10, shape=(5,)),
                "id": Tuple(
                    [
                        Box(low=0, high=10, shape=(7,)),
                        Box(low=0, high=10, shape=(8,)),
                    ]
                ),
            }
        )
        action_spaces[name] = Dict(
            {
                "obs": Box(low=0, high=10, shape=(5,)),
                "mask": Discrete(10),
            }
        )

    test_env = DummyEnv(observations, observation_spaces, action_spaces)
    # Set the flag before converting the AEC env to a parallel env.
    test_env.metadata["is_parallelizable"] = True
    return aec_to_parallel(test_env)
53 |
54 |
def dict_vec_env_test(env):
    """Check that ``env`` really is a vectorized version of the env from ``make_env``.

    Steps the env 55 times with sampled actions and verifies the observation
    of sub-env 1 as well as the all-or-nothing termination/truncation flags.
    """
    obss, infos = env.reset()
    for _step in range(55):
        sampled = [env.action_space.sample() for _ in range(env.num_envs)]
        batch_buffer = create_empty_array(env.action_space, env.num_envs)
        actions = concatenate(env.action_space, sampled, batch_buffer)
        obss, rews, terms, truncs, infos = env.step(actions)

        assert obss["feature"][1][0] == 1
        second_env_obs = {
            "feature": obss["feature"][1][:],
            "id": [o[1] for o in obss["id"]],
        }
        assert second_env_obs in env.observation_space

        # no agent death, only env death
        assert all(terms) or not any(terms)
        assert all(truncs) or not any(truncs)
78 |
79 |
def test_pettingzoo_vec_env():
    """The plain PettingZoo-to-vector conversion passes the dict env checks."""
    vec_env = pettingzoo_env_to_vec_env_v1(make_env())
    dict_vec_env_test(vec_env)
84 |
85 |
def test_single_threaded_concatenate():
    """Two concatenated copies with ``num_cpus=1`` pass the dict env checks."""
    vec_env = pettingzoo_env_to_vec_env_v1(make_env())
    concatenated = concat_vec_envs_v1(vec_env, 2, num_cpus=1)
    dict_vec_env_test(concatenated)
91 |
92 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_multi_threaded_concatenate():
    """Two concatenated copies with ``num_cpus=2`` pass the dict env checks."""
    vec_env = pettingzoo_env_to_vec_env_v1(make_env())
    concatenated = concat_vec_envs_v1(vec_env, 2, num_cpus=2)
    dict_vec_env_test(concatenated)
101 |
102 |
# Manual entry point: calls the multi-CPU test function directly when this
# file is executed as a script, then exits with status 0.
if __name__ == "__main__":
    test_multi_threaded_concatenate()
    exit(0)
106 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/frame_stack.py:
--------------------------------------------------------------------------------
1 | from gymnasium.spaces import Box, Discrete
2 |
3 | from supersuit.utils.frame_stack import stack_init, stack_obs, stack_obs_space
4 |
5 | from .utils.base_modifier import BaseModifier
6 | from .utils.shared_wrapper_util import shared_wrapper
7 |
8 |
def frame_stack_v1(env, stack_size=4, stack_dim=-1):
    """Wrap ``env`` so each observation is a stack of the last ``stack_size`` frames.

    The stack is re-created by ``stack_init`` on every reset, so the first
    observation of an episode sits next to that initial content. ``stack_dim``
    is forwarded unchanged to the stacking helpers.
    """
    assert isinstance(stack_size, int), "stack size of frame_stack must be an int"

    class FrameStackModifier(BaseModifier):
        def modify_obs_space(self, obs_space):
            # Only Box (1-3 dimensions) and Discrete spaces can be stacked.
            if isinstance(obs_space, Box):
                assert (
                    1 <= len(obs_space.shape) <= 3
                ), "frame_stack only works for 1, 2 or 3 dimensional observations"
            elif not isinstance(obs_space, Discrete):
                assert (
                    False
                ), "Stacking is currently only allowed for Box and Discrete observation spaces. The given observation space is {}".format(
                    obs_space
                )

            # Keep the unstacked space; the stacking helpers need it later.
            self.old_obs_space = obs_space
            stacked_space = stack_obs_space(obs_space, stack_size, stack_dim)
            self.observation_space = stacked_space
            return stacked_space

        def reset(self, seed=None, options=None):
            self.stack = stack_init(self.old_obs_space, stack_size, stack_dim)

        def modify_obs(self, obs):
            updated = stack_obs(
                self.stack, obs, self.old_obs_space, stack_size, stack_dim
            )
            self.stack = updated
            return updated

        def get_last_obs(self):
            return self.stack

    return shared_wrapper(env, FrameStackModifier)
45 |
46 |
def frame_stack_v2(env, stack_size=4, stack_dim=-1):
    """Wrap ``env`` so each observation is a stack of the last ``stack_size`` frames.

    Unlike ``frame_stack_v1``, the first observation after a reset is pushed
    ``stack_size`` times, so the initial stack consists of copies of that
    observation rather than the content produced by ``stack_init``.

    Parameters:
        env: environment to wrap.
        stack_size: number of frames to keep; must be an ``int``.
        stack_dim: axis to stack along, either ``-1`` (last) or ``0`` (first).

    Raises:
        AssertionError: if ``stack_size`` is not an int or ``stack_dim`` is
            neither ``0`` nor ``-1``.
    """
    assert isinstance(stack_size, int), "stack size of frame_stack must be an int"
    # The previous check asserted on the message string itself, which is
    # always truthy, so invalid values were never rejected here.
    assert stack_dim in (0, -1), f"stack_dim should be 0 or -1, not {stack_dim}"

    class FrameStackModifier(BaseModifier):
        def modify_obs_space(self, obs_space):
            # Only Box (1-3 dimensions) and Discrete spaces can be stacked.
            if isinstance(obs_space, Box):
                assert (
                    1 <= len(obs_space.shape) <= 3
                ), "frame_stack only works for 1, 2 or 3 dimensional observations"
            elif isinstance(obs_space, Discrete):
                pass
            else:
                assert (
                    False
                ), "Stacking is currently only allowed for Box and Discrete observation spaces. The given observation space is {}".format(
                    obs_space
                )

            # Keep the unstacked space; the stacking helpers need it later.
            self.old_obs_space = obs_space
            self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim)
            return self.observation_space

        def reset(self, seed=None, options=None):
            self.stack = stack_init(self.old_obs_space, stack_size, stack_dim)
            # Tells modify_obs that the next observation starts a new episode.
            self.reset_flag = True

        def modify_obs(self, obs):
            # Right after a reset, fill every slot with the first observation.
            repeats = stack_size if self.reset_flag else 1
            for _ in range(repeats):
                self.stack = stack_obs(
                    self.stack, obs, self.old_obs_space, stack_size, stack_dim
                )
            self.reset_flag = False

            return self.stack

        def get_last_obs(self):
            return self.stack

    return shared_wrapper(env, FrameStackModifier)
92 |
--------------------------------------------------------------------------------
/supersuit/utils/action_transforms/homogenize_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium import spaces
3 | from gymnasium.spaces import Box, Discrete
4 |
5 |
def check_homogenize_spaces(all_spaces):
    """Assert that ``all_spaces`` can be homogenized together.

    The sequence must be non-empty and every element must be an instance of
    the first element's class. Box spaces must also share rank and dtype.
    Only Box and Discrete spaces are accepted.
    """
    assert len(all_spaces) > 0
    reference = all_spaces[0]
    same_kind = all(isinstance(space, reference.__class__) for space in all_spaces)
    assert same_kind, "all spaces to homogenize must be of same general shape"

    if isinstance(reference, spaces.Box):
        for candidate in all_spaces:
            assert isinstance(
                candidate, spaces.Box
            ), "all spaces for homogenize must be either Box or Discrete, not a mix"
            # Sizes may differ (they get padded); the number of axes may not.
            assert len(reference.shape) == len(
                candidate.shape
            ), "all spaces to homogenize must be of same shape"
            assert (
                reference.dtype == candidate.dtype
            ), "all spaces to homogenize must be of same dtype"
        return

    if isinstance(reference, spaces.Discrete):
        for candidate in all_spaces:
            assert isinstance(
                candidate, spaces.Discrete
            ), "all spaces for homogenize must be either Box or Discrete, not a mix"
        return

    assert False, "homogenization only supports Discrete and Box spaces"
31 |
32 |
def pad_to(arr, new_shape, pad_value):
    """Pad ``arr`` at the end of each axis with ``pad_value`` up to ``new_shape``.

    Returns ``arr`` itself (no copy) when its shape already equals
    ``new_shape``; otherwise returns the padded array from ``np.pad``.
    """
    current_shape = arr.shape
    if current_shape == new_shape:
        return arr
    # Nothing is added before the data, only the missing amount after it.
    trailing_padding = [
        (0, target - current) for target, current in zip(new_shape, current_shape)
    ]
    return np.pad(arr, trailing_padding, constant_values=pad_value)
40 |
41 |
def homogenize_spaces(all_spaces):
    """Return one space that is large enough to contain every space given.

    Box: the result has the per-axis maximum shape. Each space's bounds are
    padded to that shape, lows with ``min(0, min(low))`` and highs with
    ``max(1e-5, max(high))``, then combined element-wise (minimum of lows,
    maximum of highs). The dtype is taken from the first space.
    Discrete: the result is ``Discrete(max n)``.
    """
    reference = all_spaces[0]
    if isinstance(reference, spaces.Box):
        shape_table = np.array([space.shape for space in all_spaces], dtype=np.int32)
        new_shape = tuple(np.max(shape_table, axis=0))

        padded_lows = [
            pad_to(space.low, new_shape, np.minimum(0, np.min(space.low)))
            for space in all_spaces
        ]
        padded_highs = [
            pad_to(space.high, new_shape, np.maximum(1e-5, np.max(space.high)))
            for space in all_spaces
        ]
        new_low = np.min(np.stack(padded_lows), axis=0)
        new_high = np.max(np.stack(padded_highs), axis=0)
        assert new_shape == new_low.shape
        return Box(low=new_low, high=new_high, dtype=reference.dtype)

    if isinstance(reference, spaces.Discrete):
        return Discrete(max(space.n for space in all_spaces))

    assert False
69 |
70 |
def dehomogenize_actions(orig_action_space, action):
    """Map an action from the homogenized space back to ``orig_action_space``.

    Box: the action is cropped to the leading ``orig_action_space.shape``
    entries of every axis (returned unchanged if the shapes already match).
    Discrete: ``None`` passes through, and any action above ``n - 1`` is
    replaced by action 0.
    """
    if isinstance(orig_action_space, spaces.Box):
        # choose only the relevant action values
        given_shape = action.shape
        target_shape = orig_action_space.shape
        if given_shape == target_shape:
            return action
        assert len(given_shape) == len(target_shape)
        window = tuple(slice(0, extent) for extent in target_shape)
        return action[window]

    if isinstance(orig_action_space, spaces.Discrete):
        # extra action values refer to action value 0
        largest_valid = orig_action_space.n - 1
        if action is None:
            return None
        return 0 if action > largest_valid else action

    assert False
95 |
96 |
def homogenize_observations(obs_space, obs):
    """Convert ``obs`` into an element of the homogenized space ``obs_space``.

    Box: the observation is zero-padded at the end of each axis up to
    ``obs_space.shape``.
    Discrete: the observation is returned unchanged. The homogenized space is
    ``Discrete(max n)`` (see ``homogenize_spaces``), so every original value
    is already a member.

    Raises:
        AssertionError: for any other kind of space.
    """
    if isinstance(obs_space, spaces.Box):
        return pad_to(obs, obs_space.shape, 0)
    elif isinstance(obs_space, spaces.Discrete):
        # Previously this returned the space object instead of the observation.
        return obs
    else:
        assert False
104 |
--------------------------------------------------------------------------------
/supersuit/vector/concat_vec_env.py:
--------------------------------------------------------------------------------
1 | import gymnasium.vector
2 | import numpy as np
3 | from gymnasium.spaces import Discrete
4 | from gymnasium.vector.utils import concatenate, create_empty_array, iterate
5 |
6 | from .single_vec_env import SingleVecEnv
7 |
8 |
def transpose(ll):
    """Transpose a rectangular list of lists (rows become columns).

    The number of columns is taken from the first row, so ``ll`` must be
    non-empty.
    """
    row_count = len(ll)
    column_count = len(ll[0])
    columns = []
    for col in range(column_count):
        columns.append([ll[row][col] for row in range(row_count)])
    return columns
11 |
12 |
@iterate.register(Discrete)
def iterate_discrete(space, items):
    """``iterate`` implementation registered for ``Discrete`` spaces.

    Returns a plain iterator over ``items``; ``space`` is not inspected.
    Raises ``TypeError`` with a descriptive message if ``items`` is not
    iterable.
    """
    try:
        iterator = iter(items)
    except TypeError:
        raise TypeError(f"Unable to iterate over the following elements: {items}")
    return iterator
19 |
20 |
class ConcatVecEnv(gymnasium.vector.VectorEnv):
    """Vector env that steps several sub vector envs in sequence.

    The sub-envs' batches are concatenated, so ``num_envs`` is the sum of the
    sub-envs' ``num_envs``. Metadata, render mode and both spaces are taken
    from the first sub-env.
    """

    def __init__(self, vec_env_fns, obs_space=None, act_space=None):
        """Instantiate every factory in ``vec_env_fns`` and record shared attributes.

        ``obs_space`` and ``act_space`` are accepted but not used here; the
        spaces are read from the first sub-env instead.
        """
        self.vec_envs = vec_envs = [vec_env_fn() for vec_env_fn in vec_env_fns]
        for i, sub_env in enumerate(vec_envs):
            if not hasattr(sub_env, "num_envs"):
                # Bind the env as a default argument: a bare closure over
                # ``vec_envs[i]`` would resolve to whatever is stored in that
                # slot when the factory is finally called.
                vec_envs[i] = SingleVecEnv([lambda sub_env=sub_env: sub_env])
        first_env = vec_envs[0]
        self.metadata = first_env.metadata
        self.render_mode = first_env.render_mode
        self.observation_space = first_env.observation_space
        self.action_space = first_env.action_space
        self.num_envs = sum(env.num_envs for env in vec_envs)

    def reset(self, seed=None, options=None):
        """Reset every sub-env and return ``(observations, infos)``.

        When ``seed`` is given, sub-env ``i`` is reset with ``seed + i``;
        otherwise no seed argument is passed at all.
        """
        _res_obs = []
        _res_infos = []

        for i, venv in enumerate(self.vec_envs):
            if seed is None:
                _obs, _info = venv.reset(options=options)
            else:
                _obs, _info = venv.reset(seed=seed + i, options=options)
            _res_obs.append(_obs)
            _res_infos.append(_info)

        # flatten infos (also done in step function)
        flattened_infos = [info for sublist in _res_infos for info in sublist]

        return self.concat_obs(_res_obs), flattened_infos

    def concat_obs(self, observations):
        """Merge the per-sub-env observation batches into one batch of ``num_envs``."""
        return concatenate(
            self.observation_space,
            [
                item
                for obs in observations
                for item in iterate(self.observation_space, obs)
            ],
            create_empty_array(self.observation_space, n=self.num_envs),
        )

    def concatenate_actions(self, actions, n_actions):
        """Batch ``n_actions`` individual actions into the action-space layout."""
        return concatenate(
            self.action_space,
            actions,
            create_empty_array(self.action_space, n=n_actions),
        )

    def step_async(self, actions):
        """Store ``actions`` for the following ``step_wait`` call."""
        self._saved_actions = actions

    def step_wait(self):
        """Run ``step`` with the actions stored by ``step_async``."""
        return self.step(self._saved_actions)

    def step(self, actions):
        """Step every sub-env with its slice of ``actions`` and merge the results.

        Returns ``(observations, rewards, terminations, truncations, infos)``,
        with ``infos`` flattened into a single list.
        """
        data = []
        idx = 0
        actions = list(iterate(self.action_space, actions))
        for venv in self.vec_envs:
            # Each sub-env consumes the next ``venv.num_envs`` actions.
            data.append(
                venv.step(
                    self.concatenate_actions(
                        actions[idx : idx + venv.num_envs], venv.num_envs
                    )
                )
            )
            idx += venv.num_envs
        observations, rewards, terminations, truncations, infos = transpose(data)
        observations = self.concat_obs(observations)
        rewards = np.concatenate(rewards, axis=0)
        terminations = np.concatenate(terminations, axis=0)
        truncations = np.concatenate(truncations, axis=0)
        infos = [
            info for sublist in infos for info in sublist
        ]  # flatten infos from nested lists
        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Render the first sub-env only."""
        return self.vec_envs[0].render()

    def close(self):
        """Close every sub-env."""
        for vec_env in self.vec_envs:
            vec_env.close()

    def env_is_wrapped(self, wrapper_class):
        """Return the sub-envs' ``env_is_wrapped`` results joined into one flat list."""
        return [
            flag
            for sub_venv in self.vec_envs
            for flag in sub_venv.env_is_wrapped(wrapper_class)
        ]
112 |
--------------------------------------------------------------------------------
/supersuit/utils/frame_stack.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box, Discrete
3 |
4 |
def get_tile_shape(shape, stack_size, stack_dim=-1):
    """Compute how to reshape and tile one frame into a stack of ``stack_size`` frames.

    Parameters:
        shape: shape tuple of a single observation (1, 2 or 3 dimensions).
        stack_size: number of frames in the stack.
        stack_dim: axis to stack along, ``-1`` (last) or ``0`` (first).

    Returns:
        ``(tile_shape, new_shape)``: reshape a frame to ``new_shape``, then
        ``np.tile`` it with ``tile_shape``. 2-D frames gain a singleton axis on
        the stacking side; 1-D and 3-D frames keep their shape.

    Raises:
        AssertionError: if ``stack_dim`` is neither ``-1`` nor ``0``, or the
            observation is not 1, 2 or 3 dimensional.
    """
    # An unsupported stack_dim used to fall through both branches and surface
    # as an UnboundLocalError at the return statement; fail clearly instead.
    if stack_dim not in (-1, 0):
        raise AssertionError(f"stack_dim should be 0 or -1, not {stack_dim}")
    stack_last = stack_dim == -1

    obs_dim = len(shape)
    if obs_dim == 1:
        # 1-D frames are concatenated end to end for either stack_dim.
        return (stack_size,), shape
    if obs_dim == 3:
        tile_shape = (1, 1, stack_size) if stack_last else (stack_size, 1, 1)
        return tile_shape, shape
    if obs_dim == 2:
        # stack 2-D frames: add the axis that the frames are stacked along
        if stack_last:
            return (1, 1, stack_size), shape + (1,)
        return (stack_size, 1, 1), (1,) + shape

    raise AssertionError("Stacking is only available for 1,2 or 3 dimensional arrays")
37 |
38 |
def stack_obs_space(obs_space, stack_size, stack_dim=-1):
    """
    obs_space: Observation space of a single frame (Box or Discrete)
    stack_size: Number of frames in the observation stack
    stack_dim: Axis the frames are stacked along (forwarded to get_tile_shape)
    Returns:
        The observation space of the stacked frames
    """
    if isinstance(obs_space, Discrete):
        # A stack of Discrete values is encoded as a single base-n number.
        return Discrete(obs_space.n**stack_size)

    if isinstance(obs_space, Box):
        # stack 1-D frames and 3-D frames
        tile_shape, new_shape = get_tile_shape(
            obs_space.low.shape, stack_size, stack_dim
        )
        stacked_low = np.tile(obs_space.low.reshape(new_shape), tile_shape)
        stacked_high = np.tile(obs_space.high.reshape(new_shape), tile_shape)
        return Box(low=stacked_low, high=stacked_high, dtype=obs_space.dtype)

    assert (
        False
    ), "Stacking is currently only allowed for Box and Discrete observation spaces. The given observation space is {}".format(
        obs_space
    )
65 |
66 |
def stack_init(obs_space, stack_size, stack_dim=-1):
    """Return the initial (empty) frame stack for ``obs_space``.

    Box spaces get an all-zero array with the stacked shape and the space's
    dtype; every other space gets the integer 0.
    """
    if not isinstance(obs_space, Box):
        return 0
    tile_shape, new_shape = get_tile_shape(obs_space.low.shape, stack_size, stack_dim)
    blank_frame = np.zeros(new_shape, dtype=obs_space.dtype)
    return np.tile(blank_frame, tile_shape)
75 |
76 |
def stack_obs(frame_stack, obs, obs_space, stack_size, stack_dim=-1):
    """
    Push a new observation onto the frame stack and drop the oldest frame.

    Parameters
    ----------
    frame_stack : current stack of frames (array for Box, int for Discrete)
    obs : new observation
    obs_space : observation space of a single, unstacked frame
    stack_size : number of frames in the stack
    stack_dim : axis the frames are stacked along (-1 or 0)

    Returns
    -------
    The updated stack. For Box spaces ``frame_stack`` is modified in place and
    returned; for Discrete spaces a new integer is returned. Any other space
    yields ``None``.
    """
    if isinstance(obs_space, Box):
        obs_shape = obs.shape
        agent_fs = frame_stack

        if len(obs_shape) == 1:
            # Frames lie end to end: shift left by one frame, write the tail.
            size = obs_shape[0]
            agent_fs[:-size] = agent_fs[size:]
            agent_fs[-size:] = obs

        elif len(obs_shape) == 2:
            if stack_dim == -1:
                agent_fs[:, :, :-1] = agent_fs[:, :, 1:]
                agent_fs[:, :, -1] = obs
            elif stack_dim == 0:
                agent_fs[:-1] = agent_fs[1:]
                # Write the newest slot only. The previous ``agent_fs[:-1] = obs``
                # overwrote every older frame and never stored the new one last.
                agent_fs[-1] = obs

        elif len(obs_shape) == 3:
            # Each frame occupies ``nchannels`` slices along the stacking axis.
            if stack_dim == -1:
                nchannels = obs_shape[-1]
                agent_fs[:, :, :-nchannels] = agent_fs[:, :, nchannels:]
                agent_fs[:, :, -nchannels:] = obs
            elif stack_dim == 0:
                nchannels = obs_shape[0]
                agent_fs[:-nchannels] = agent_fs[nchannels:]
                agent_fs[-nchannels:] = obs

        return agent_fs

    elif isinstance(obs_space, Discrete):
        # Base-n encoding: shift one digit, add the new one, drop the overflow.
        return (frame_stack * obs_space.n + obs) % (obs_space.n**stack_size)
118 |
--------------------------------------------------------------------------------
/supersuit/utils/agent_indicator.py:
--------------------------------------------------------------------------------
1 | import re
2 | import warnings
3 |
4 | import numpy as np
5 | from gymnasium.spaces import Box, Discrete
6 |
7 |
def change_obs_space(space, num_indicators):
    """Return ``space`` extended with ``num_indicators`` agent-indicator slots.

    Box (1-D): the slots are appended to the vector.
    Box (2-D / 3-D): the slots are appended as extra channels on axis 2
        (2-D spaces first gain a singleton channel axis).
    Discrete: the result is ``Discrete(space.n * num_indicators)``.

    The indicator slots are bounded by 0 below. Their upper bound is the
    value ``change_observation`` writes into an active slot: ``max(high)``
    when all upper bounds are finite, otherwise 1.0. The previous bound,
    ``min(high)``, could be smaller than that value, which put the produced
    observations outside the declared space.

    Raises:
        AssertionError: for any other space, including Box spaces with more
            than three dimensions.
    """
    if isinstance(space, Box):
        ndims = len(space.shape)
        if ndims in (1, 2, 3):
            if np.isinf(space.high).any():
                indicator_high = 1.0
            else:
                indicator_high = np.max(space.high)

            if ndims == 1:
                orig_low, orig_high = space.low, space.high
                pad_shape = (num_indicators,)
                axis = 0
            else:
                orig_low = space.low if ndims == 3 else np.expand_dims(space.low, 2)
                orig_high = space.high if ndims == 3 else np.expand_dims(space.high, 2)
                pad_shape = orig_low.shape[:2] + (num_indicators,)
                axis = 2

            pad_low = np.zeros(pad_shape, dtype=space.dtype)
            pad_high = np.full(pad_shape, indicator_high, dtype=space.dtype)
            new_low = np.concatenate([orig_low, pad_low], axis=axis)
            new_high = np.concatenate([orig_high, pad_high], axis=axis)
            return Box(low=new_low, high=new_high, dtype=space.dtype)
    elif isinstance(space, Discrete):
        return Discrete(space.n * num_indicators)

    assert (
        False
    ), "agent_indicator space must be 1d, 2d, or 3d Box or Discrete, was {}".format(
        space
    )
37 |
38 |
def get_indicator_map(agents, type_only):
    """Map every agent name to the index of its indicator slot.

    Parameters:
        agents: iterable of agent names (iterated twice when ``type_only``
            is true).
        type_only: when false, each agent gets its own index in iteration
            order. When true, agents are grouped by the part of their name
            before the first underscore and agents of one type share an
            index; indices are assigned in order of first appearance.

    Returns:
        Dict mapping agent name to integer index.

    Raises:
        AssertionError: if ``type_only`` is true and a name does not start
            with lowercase letters, an underscore and digits.

    A warning is issued when ``type_only`` is true and only one agent type
    is found.
    """
    if not type_only:
        return {agent: i for i, agent in enumerate(agents)}

    assert all(
        re.match("[a-z]+_[0-9]+", agent) for agent in agents
    ), "when the `type_only` parameter is True to agent_indicator, the agent names must follow the `<type>_<n>` format"

    agent_id_map = {}
    type_idx_map = {}
    for agent in agents:
        agent_type = agent.split("_")[0]
        if agent_type not in type_idx_map:
            # The next free index equals the number of types seen so far.
            type_idx_map[agent_type] = len(type_idx_map)
        agent_id_map[agent] = type_idx_map[agent_type]
    if len(type_idx_map) == 1:
        warnings.warn(
            "agent_indicator wrapper is degenerate, only one agent type; doing nothing"
        )
    return agent_id_map
60 |
61 |
def check_params(spaces):
    """Assert that all ``spaces`` are identical and usable with agent_indicator.

    Spaces are compared through their ``repr``. Each one is also passed to
    ``change_obs_space`` so that unsupported spaces fail its assertion.
    ``spaces`` must contain at least one element.
    """
    space_list = list(spaces)
    reference_repr = repr(space_list[0])
    for candidate in space_list:
        assert (
            repr(candidate) == reference_repr
        ), "spaces need to be the same shape to add an indicator. Try using the `pad_observations` wrapper before agent_indicator."
        change_obs_space(candidate, 1)
70 |
71 |
def change_observation(obs, space, indicator_data):
    """Append the agent indicator to ``obs``.

    ``indicator_data`` is ``(indicator_num, num_indicators)``: the slot to
    activate and the total number of slots.

    Box (1-D): ``num_indicators`` zeros are appended and the active slot set.
    Box (2-D / 3-D): ``num_indicators`` zero channels are appended on axis 2
        (2-D observations first gain a channel axis) and the active channel
        is filled.
    Discrete: returns ``obs * num_indicators + indicator_num``.

    The value written to the active slot is ``max(space.high)`` when all
    upper bounds are finite and 1.0 otherwise.
    """
    indicator_num, num_indicators = indicator_data
    assert 0 <= indicator_num < num_indicators

    if isinstance(space, Discrete):
        return obs * num_indicators + indicator_num

    if isinstance(space, Box):
        rank = len(space.shape)
        if rank in (1, 2, 3):
            if rank == 1:
                offset = len(obs)
                padded = np.pad(obs, (0, num_indicators))
                active = indicator_num + offset
            else:
                with_channels = obs if rank == 3 else np.expand_dims(obs, 2)
                offset = with_channels.shape[2]
                padded = np.pad(with_channels, [(0, 0), (0, 0), (0, num_indicators)])
                active = (slice(None), slice(None), offset + indicator_num)

            # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator
            if np.isinf(space.high).any():
                padded[active] = 1.0
            else:
                padded[active] = np.max(space.high)
            return padded

    assert (
        False
    ), "agent_indicator space must be 1d, 2d, or 3d Box or Discrete, was {}".format(
        space
    )
105 |
--------------------------------------------------------------------------------
/test/test_vector/test_pettingzoo_to_vec.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import pytest
5 | from pettingzoo.butterfly import knights_archers_zombies_v10
6 | from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3
7 |
8 | from supersuit import black_death_v3, concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1
9 |
10 |
def test_good_env():
    """A converted simple_spread env returns one entry per agent on every step."""
    env = simple_spread_v3.parallel_env()
    max_num_agents = len(env.possible_agents)
    env = pettingzoo_env_to_vec_env_v1(env)
    assert env.num_envs == max_num_agents

    obss, infos = env.reset()
    for _step in range(55):
        actions = [env.action_space.sample() for _ in range(env.num_envs)]

        # Check we're not passing a thing that gets mutated
        keep_obs = copy.deepcopy(obss)
        new_obss, rews, terms, truncs, infos = env.step(actions)
        assert hash(str(keep_obs)) == hash(str(obss))

        for returned in (new_obss, rews, terms, truncs, infos):
            assert len(returned) == max_num_agents

        # no agent death, only env death
        assert all(terms) or not any(terms)
        assert all(truncs) or not any(truncs)
        obss = new_obss
37 |
38 |
def test_good_vecenv():
    """Two concatenated simple_spread envs return one entry per agent and copy."""
    num_envs = 2
    env = simple_spread_v3.parallel_env()
    max_num_agents = len(env.possible_agents) * num_envs
    env = concat_vec_envs_v1(pettingzoo_env_to_vec_env_v1(env), num_envs)

    obss, infos = env.reset()
    for _step in range(55):
        actions = [env.action_space.sample() for _ in range(env.num_envs)]

        # Check we're not passing a thing that gets mutated
        keep_obs = copy.deepcopy(obss)
        new_obss, rews, terms, truncs, infos = env.step(actions)
        assert hash(str(keep_obs)) == hash(str(obss))

        for returned in (new_obss, rews, terms, truncs, infos):
            assert len(returned) == max_num_agents

        # no agent death, only env death
        assert all(terms) or not any(terms)
        assert all(truncs) or not any(truncs)
        obss = new_obss
66 |
67 |
def test_bad_action_spaces_env():
    """Converting simple_world_comm must fail with an AssertionError."""
    parallel_env = simple_world_comm_v3.parallel_env()
    with pytest.raises(AssertionError):
        pettingzoo_env_to_vec_env_v1(parallel_env)
72 |
73 |
def test_env_black_death_assertion():
    """Without black_death, running KAZ through the vector env must hit an assertion."""
    env = pettingzoo_env_to_vec_env_v1(
        knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=2000)
    )
    with pytest.raises(AssertionError):
        for _episode in range(100):
            env.reset()
            for _step in range(2000):
                actions = [env.action_space.sample() for _ in range(env.num_envs)]
                obss, rews, terms, truncs, infos = env.step(actions)
83 |
84 |
def test_env_black_death_wrapper():
    """With black_death_v3 applied, 300 random steps of KAZ run without error."""
    par_env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=300)
    vec_env = pettingzoo_env_to_vec_env_v1(black_death_v3(par_env))
    vec_env.reset()
    for _step in range(300):
        sampled = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
        vec_env.step(sampled)
93 |
94 |
def test_terminal_obs_are_returned():
    """
    When an episode ends (by termination or by running past max_cycles), every
    agent's info dict must carry the final observation under
    "terminal_observation".
    """
    max_cycles = 300
    par_env = knights_archers_zombies_v10.parallel_env(
        spawn_rate=50, max_cycles=max_cycles
    )
    vec_env = pettingzoo_env_to_vec_env_v1(black_death_v3(par_env))
    vec_env.reset(seed=42)

    # run past max_cycles or until terminated - causing the env to reset and continue
    for _step in range(max_cycles + 10):
        sampled = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
        _, _, terms, truncs, infos = vec_env.step(sampled)

        episode_over = (np.array(terms) | np.array(truncs)).all()
        if not episode_over:
            continue

        # check we have infos for all agents
        assert len(infos) == len(vec_env.par_env.possible_agents)
        # check infos contain terminal_observation
        for agent_info in infos:
            assert "terminal_observation" in agent_info
118 |
--------------------------------------------------------------------------------
/test/gym_mock_test.py:
--------------------------------------------------------------------------------
1 | from test.dummy_gym_env import DummyEnv
2 |
3 | import numpy as np
4 | import pytest
5 | from gymnasium.spaces import Box, Discrete
6 |
7 | import supersuit
8 | from supersuit import action_lambda_v1, dtype_v0, observation_lambda_v0, reshape_v0
9 |
10 |
# Shared fixtures for the dummy gym env: an 8x8x3 float32 observation whose
# last axis holds 0, 1, 2, a matching Box space, and a 5-way discrete action space.
_channel_ramp = np.arange(3)
base_obs = (np.zeros([8, 8, 3]) + _channel_ramp).astype(np.float32)
base_obs_space = Box(
    low=np.float32(0.0),
    high=np.float32(10.0),
    shape=[8, 8, 3],
)
base_act_spaces = Discrete(5)
14 |
15 |
def test_reshape():
    """reshape_v0 must reshape both the reset and the step observation to (64, 3)."""
    target_shape = (64, 3)
    env = reshape_v0(DummyEnv(base_obs, base_obs_space, base_act_spaces), target_shape)

    reset_obs, _info = env.reset()
    assert reset_obs.shape == target_shape

    stepped_obs = env.step(5)[0]
    expected = base_obs.reshape([64, 3])
    assert np.all(np.equal(stepped_obs, expected))
23 |
24 |
def new_continuous_dummy():
    """Build a DummyEnv with the shared observation and a 3-dim Box action space."""
    continuous_actions = Box(
        low=np.float32(0.0), high=np.float32(10.0), shape=[3]
    )
    return DummyEnv(base_obs, base_obs_space, continuous_actions)
28 |
29 |
def new_dummy():
    """Build a DummyEnv from the shared module-level observation and spaces."""
    dummy = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    return dummy
32 |
33 |
# One pre-built wrapped env per wrapper under test; each entry becomes a
# parametrized case of test_basic_wrappers.
wrappers = [
    # observation transforms
    supersuit.color_reduction_v0(new_dummy(), "R"),
    supersuit.resize_v1(dtype_v0(new_dummy(), np.uint8), x_size=5, y_size=10),
    supersuit.resize_v1(
        dtype_v0(new_dummy(), np.uint8),
        x_size=5,
        y_size=10,
        linear_interp=True,
    ),
    dtype_v0(new_dummy(), np.int32),
    supersuit.flatten_v0(new_dummy()),
    reshape_v0(new_dummy(), (64, 3)),
    supersuit.normalize_obs_v0(new_dummy(), env_min=-1, env_max=5.0),
    supersuit.frame_stack_v1(new_dummy(), 8),
    # reward / action transforms
    supersuit.reward_lambda_v0(new_dummy(), lambda x: x / 10),
    supersuit.clip_reward_v0(new_dummy()),
    supersuit.clip_actions_v0(new_continuous_dummy()),
    # temporal wrappers
    supersuit.frame_skip_v0(new_dummy(), 4),
    supersuit.frame_skip_v0(new_dummy(), (4, 6)),
    supersuit.sticky_actions_v0(new_dummy(), 0.75),
    supersuit.delay_observations_v0(new_dummy(), 1),
    supersuit.max_observation_v0(new_dummy(), 3),
    # NaN-action handling
    supersuit.nan_noop_v0(new_dummy(), 0),
    supersuit.nan_zeros_v0(new_dummy()),
    supersuit.nan_random_v0(new_dummy()),
    supersuit.scale_actions_v0(new_continuous_dummy(), 0.5),
]
58 |
59 |
@pytest.mark.parametrize("env", wrappers)
def test_basic_wrappers(env):
    """The reset observation must lie in the wrapped space, and 10 steps must run."""
    first_obs, _info = env.reset(seed=5)
    action_sampler = env.action_space
    declared_space = env.observation_space

    assert declared_space.contains(first_obs)
    assert first_obs.dtype == declared_space.dtype

    for _step in range(10):
        env.step(action_sampler.sample())
69 |
70 |
def test_lambda():
    """Exercise observation_lambda_v0 with inferred and explicit observation spaces.

    Covers stacking the wrapper twice (values accumulate), a shape-changing
    transform whose new space is inferred from the Box bounds, and the same
    transform with an explicit ``change_obs_space_fn`` callback.
    """

    def add1(obs, obs_space):
        return obs + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = observation_lambda_v0(base_env, add1)
    obs0, info0 = env.reset()
    assert int(obs0[0][0][0]) == 1
    env = observation_lambda_v0(env, add1)
    obs0, info0 = env.reset()
    assert int(obs0[0][0][0]) == 2

    def tile_obs(obs, obs_space):
        # Double the size of the first axis, leave the others untouched.
        shape_size = len(obs.shape)
        tile_shape = [1] * shape_size
        tile_shape[0] *= 2
        return np.tile(obs, tile_shape)

    env = observation_lambda_v0(env, tile_obs)
    obs0, info0 = env.reset()
    assert env.observation_space.shape == (16, 8, 3)

    def change_shape_fn(obs_space):
        return Box(low=0, high=1, shape=(32, 8, 3))

    # Pass the space callback explicitly; it used to be defined but never
    # handed to the wrapper, so that code path was not exercised.
    env = observation_lambda_v0(env, tile_obs, change_shape_fn)
    obs0, info0 = env.reset()
    assert env.observation_space.shape == (32, 8, 3)
    assert obs0.shape == (32, 8, 3)
100 |
101 |
def test_action_lambda():
    """action_lambda_v1 must apply both the action transform and the space transform."""

    def change_space_fn(space):
        return Discrete(space.n + 1)

    def inc1(x, space):
        return x + 1

    base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
    env = action_lambda_v1(base_env, inc1, change_space_fn)
    assert env.action_space.n == base_env.action_space.n + 1
    env.reset()
    env.step(5)

    def one_hot(x, n):
        encoded = np.zeros(n)
        encoded[x] = 1
        return encoded

    # A discrete action is exposed on top of a 15-dim Box by one-hot encoding it.
    box_actions = Box(low=0, high=1, shape=(15,))
    base_env = DummyEnv(base_obs, base_obs_space, box_actions)
    env = action_lambda_v1(
        base_env,
        lambda action, act_space: one_hot(action, act_space.shape[0]),
        lambda act_space: Discrete(act_space.shape[0]),
    )

    env.reset()
    env.step(2)
130 |
131 |
def test_rew_lambda():
    """reward_lambda_v0 with ``x / 10`` must turn the dummy reward into 0.1."""
    env = supersuit.reward_lambda_v0(new_dummy(), lambda x: x / 10)
    env.reset()
    step_result = env.step(0)
    scaled_reward = step_result[1]
    assert scaled_reward == 1.0 / 10
137 |
--------------------------------------------------------------------------------
/test/test_vector/test_gym_vector.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import gymnasium
4 | import numpy as np
5 | import pytest
6 | from pettingzoo.mpe import simple_spread_v3
7 |
8 | from supersuit import concat_vec_envs_v1, gym_vec_env_v0, pettingzoo_env_to_vec_env_v1
9 |
10 |
def recursive_equal(info1, info2):
    """Return True when two (possibly nested) info structures are equal.

    Plain ``==`` is tried first. When that raises ``ValueError`` (which
    happens when numpy arrays are nested inside and their element-wise
    comparison has no single truth value), dicts and lists are compared
    recursively. Arrays are compared with ``np.array_equal`` so that arrays
    of different shapes are unequal rather than being broadcast against each
    other or raising.

    Parameters:
    - info1, info2: values to compare; numpy arrays, dicts, lists and
      anything supporting ``==`` are handled.

    Returns a plain ``bool``.
    """
    # Arrays first: np.array_equal checks the shape before the contents.
    if isinstance(info1, np.ndarray) and isinstance(info2, np.ndarray):
        return bool(np.array_equal(info1, info2))

    try:
        if info1 == info2:
            return True
    except ValueError:
        if isinstance(info1, dict) and isinstance(info2, dict):
            return info1.keys() == info2.keys() and all(
                recursive_equal(info1[key], info2[key]) for key in info1
            )
        if isinstance(info1, list) and isinstance(info2, list):
            return len(info1) == len(info2) and all(
                recursive_equal(item1, item2) for item1, item2 in zip(info1, info2)
            )
    return False
29 |
30 |
def check_vec_env_equivalency(venv1, venv2, check_info=True):
    """Assert that two vector envs produce the same trajectory for the same actions.

    Both envs are reset with seed 51 and stepped 400 times with actions
    sampled from ``venv1.action_space``; observations, rewards, terminations,
    truncations and (optionally) infos must agree at every step.

    Parameters:
    - venv1, venv2: vector envs exposing ``reset(seed=...)``, ``step(actions)``,
      ``action_space`` and ``num_envs``.
    - check_info: when False, the info structures are not compared.

    Raises AssertionError on the first mismatch.
    """
    # assert venv1.observation_space == venv2.observation_space
    # assert venv1.action_space == venv2.action_space

    def _reset_observations(venv):
        # reset() returns (observations, infos); only the observations are
        # comparable with np.equal. A bare observation return is passed through.
        result = venv.reset(seed=51)
        return result[0] if isinstance(result, tuple) else result

    obs1 = _reset_observations(venv1)
    obs2 = _reset_observations(venv2)

    for _step in range(400):
        action = [venv1.action_space.sample() for _ in range(venv1.num_envs)]
        assert np.all(np.equal(obs1, obs2))

        obs1, rew1, term1, trunc1, info1 = venv1.step(action)
        obs2, rew2, term2, trunc2, info2 = venv2.step(action)

        # uses close rather than equal due to inconsistency in reporting rewards as float32 or float64
        assert np.allclose(rew1, rew2)
        assert np.all(np.equal(term1, term2))
        assert np.all(np.equal(trunc1, trunc2))
        if check_info:
            assert recursive_equal(info1, info2)
50 |
51 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_gym_supersuit_equivalency():
    """Compare concat_vec_envs_v1 against gym_vec_env_v0 on MountainCarContinuous."""
    copies = 3
    base_env = gymnasium.make("MountainCarContinuous-v0")
    supersuit_venv = concat_vec_envs_v1(base_env, copies)
    gym_venv = gym_vec_env_v0(base_env, copies)
    check_vec_env_equivalency(supersuit_venv, gym_venv)
61 |
62 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_inital_state_dissimilarity():
    """Two concatenated CartPole copies must not start from the same observation."""
    env = gymnasium.make("CartPole-v1")
    venv = concat_vec_envs_v1(env, 2)
    # reset() returns (observations, infos); index into the observations only,
    # otherwise the comparison is between the observations and the infos.
    observations, _infos = venv.reset()
    assert not np.equal(observations[0], observations[1]).all()
71 |
72 |
73 | # we really don't want to have a stable baselines dependency even in tests
74 | # def test_stable_baselines_supersuit_equivalency():
75 | # env = gymnasium.make("MountainCarContinuous-v0")
76 | # num_envs = 3
77 | # venv1 = supersuit_vec_env(env, num_envs, base_class='stable_baselines3')
78 | # venv2 = stable_baselines3_vec_env(env, num_envs)
79 | # check_vec_env_equivalency(venv1, venv2, check_info=False) # stable baselines does not implement info correctly
80 |
81 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_mutliproc_single_proc_equivalency():
    """Compare num_cpus=0 against num_cpus=4 concatenation on CartPole."""
    copies = 3
    base_env = gymnasium.make("CartPole-v1")
    # uses single threaded vector environment
    single_proc = concat_vec_envs_v1(base_env, copies, num_cpus=0)
    # uses multiprocessing vector environment
    multi_proc = concat_vec_envs_v1(base_env, copies, num_cpus=4)
    check_vec_env_equivalency(single_proc, multi_proc)
93 |
94 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_multiagent_mutliproc_single_proc_equivalency():
    """Compare num_cpus=0 against num_cpus=4 concatenation on simple_spread."""
    copies = 3
    markov_env = pettingzoo_env_to_vec_env_v1(
        simple_spread_v3.parallel_env(max_cycles=10)
    )
    # uses single threaded vector environment
    single_proc = concat_vec_envs_v1(markov_env, copies, num_cpus=0)
    # uses multiprocessing vector environment
    multi_proc = concat_vec_envs_v1(markov_env, copies, num_cpus=4)
    check_vec_env_equivalency(single_proc, multi_proc)
107 |
108 |
@pytest.mark.skip(
    reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188"
)
def test_multiproc_buffer():
    """Values returned by a num_cpus=2 vector env must not change on a later step."""
    copies = 2
    vec_env = concat_vec_envs_v1(gymnasium.make("CartPole-v1"), copies, num_cpus=2)

    held = vec_env.reset()
    for _step in range(55):
        sampled = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]

        # Check we're not passing a thing that gets mutated
        snapshot = copy.deepcopy(held)
        step_result = vec_env.step(sampled)

        assert hash(str(snapshot)) == hash(str(held))

        # Unpack to five values, as step() is expected to return, and keep the
        # new observations for the next round.
        held, _rews, _terms, _truncs, _infos = step_result
128 |
--------------------------------------------------------------------------------
/test/aec_unwrapped_test.py:
--------------------------------------------------------------------------------
1 | from test.dummy_aec_env import DummyEnv
2 |
3 | import numpy as np
4 | from gymnasium import spaces
5 |
6 | from supersuit import (
7 | agent_indicator_v0,
8 | black_death_v3,
9 | clip_actions_v0,
10 | clip_reward_v0,
11 | color_reduction_v0,
12 | delay_observations_v0,
13 | dtype_v0,
14 | flatten_v0,
15 | frame_skip_v0,
16 | frame_stack_v1,
17 | max_observation_v0,
18 | nan_random_v0,
19 | nan_zeros_v0,
20 | normalize_obs_v0,
21 | pad_action_space_v0,
22 | pad_observations_v0,
23 | scale_actions_v0,
24 | sticky_actions_v0,
25 | )
26 |
27 |
def observation_homogenizable(env, agents):
    """Return True when every agent's observation space is a Box or a Discrete.

    An empty ``agents`` iterable yields True.
    """
    return all(
        isinstance(env.observation_space(agent), (spaces.Box, spaces.Discrete))
        for agent in agents
    )
36 |
37 |
def action_homogenizable(env, agents):
    """Return True when every agent's action space is a Box or a Discrete.

    An empty ``agents`` iterable yields True.
    """
    accepted = (spaces.Box, spaces.Discrete) if agents else ()
    for agent in agents:
        if not isinstance(env.action_space(agent), accepted):
            return False
    return True
46 |
47 |
def image_observation(env, agents):
    """Decide whether the image-only wrappers may be applied to ``env``.

    Returns False as soon as an agent's observation space is not a Box.
    For Box spaces, four checks are AND-ed together per agent: two on the
    shape and two requiring every ``low`` to be 0 and every ``high`` to be 255.
    An empty ``agents`` iterable yields True.
    """
    imagable = True
    for agent in agents:
        if isinstance(env.observation_space(agent), spaces.Box):
            # NOTE(review): this compares `low.shape` with the int 3. If `low`
            # is a numpy array its shape is a tuple, so this is always False
            # and every check below is short-circuited. Presumably
            # `len(...shape) == 3` was intended -- confirm before changing,
            # since it would enable the image wrappers in unwrapped_check.
            imagable = imagable and (env.observation_space(agent).low.shape == 3)
            # NOTE(review): `len()` of a single shape entry would raise
            # TypeError for an int; it is only avoided because the line above
            # short-circuits. Presumably `shape[2] == 3` was intended.
            imagable = imagable and (len(env.observation_space(agent).shape[2]) == 3)
            imagable = imagable and (env.observation_space(agent).low == 0).all()
            imagable = imagable and (env.observation_space(agent).high == 255).all()
        else:
            return False
    return imagable
59 |
60 |
def box_action(env, agents):
    """Return True when every agent's action space is a Box (True if no agents)."""
    return all(isinstance(env.action_space(agent), spaces.Box) for agent in agents)
66 |
67 |
def not_dict_observation(env, agents):
    """Return True unless every agent's observation space is a Dict.

    With no agents the "every agent" condition holds vacuously, so the
    result is False.
    """
    every_agent_is_dict = all(
        isinstance(env.observation_space(agent), spaces.Dict) for agent in agents
    )
    return not every_agent_is_dict
73 |
74 |
def not_discrete_observation(env, agents):
    """Return True unless every agent's observation space is a Discrete.

    With no agents the result is False.
    """
    for agent in agents:
        if not isinstance(env.observation_space(agent), spaces.Discrete):
            # One non-Discrete agent is enough to answer.
            return True
    return False
82 |
83 |
def not_multibinary_observation(env, agents):
    """Return True unless every agent's observation space is a MultiBinary.

    With no agents the result is False.
    """
    every_agent_is_multibinary = all(
        isinstance(env.observation_space(agent), spaces.MultiBinary)
        for agent in agents
    )
    return not every_agent_is_multibinary
91 |
92 |
def unwrapped_check(env):
    """Stack every applicable wrapper on ``env`` and check ``unwrapped`` still works.

    The env is reset, then wrappers are layered on in groups. Each group is
    guarded by a predicate that is evaluated against ``env`` as wrapped so
    far, so the order of the statements matters. Finally ``env.unwrapped``
    must be an instance of exactly ``DummyEnv``.

    Raises AssertionError when unwrapping does not reach the DummyEnv.
    """
    env.reset()
    # Agent list captured once, after reset, and reused by every predicate.
    agents = env.agents

    # Image-only observation wrappers.
    if image_observation(env, agents):
        env = max_observation_v0(env, 2)
        env = color_reduction_v0(env, mode="full")
        env = normalize_obs_v0(env)

    # Wrappers that require Box action spaces.
    if box_action(env, agents):
        env = clip_actions_v0(env)
        env = scale_actions_v0(env, 0.5)

    # Wrappers that require Box/Discrete observation spaces.
    if observation_homogenizable(env, agents):
        env = pad_observations_v0(env)
        env = frame_stack_v1(env, 2)
        env = agent_indicator_v0(env)
        env = black_death_v3(env)

    # Wrappers applied unless all observation spaces are Dict, all are
    # Discrete, or all are MultiBinary.
    if (
        not_dict_observation(env, agents)
        and not_discrete_observation(env, agents)
        and not_multibinary_observation(env, agents)
    ):
        env = dtype_v0(env, np.float16)
        env = flatten_v0(env)
        env = frame_skip_v0(env, 2)

    if action_homogenizable(env, agents):
        env = pad_action_space_v0(env)

    # Wrappers applied unconditionally.
    env = clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = delay_observations_v0(env, 2)
    env = sticky_actions_v0(env, 0.5)
    env = nan_random_v0(env)
    env = nan_zeros_v0(env)

    assert env.unwrapped.__class__ == DummyEnv, f"Failed to unwrap {env}"
131 |
132 |
def test_unwrapped():
    """Run unwrapped_check for every observation-space / action-space pairing.

    Both agents ("a0" and "a1") share the same space object in each setup.
    """
    agent_ids = [f"a{i}" for i in range(2)]

    observation_bases = [
        spaces.Box(low=-1.0, high=1.0, shape=[2], dtype=np.float32),
        spaces.Box(low=0, high=255, shape=[64, 64, 3], dtype=np.int16),
        spaces.Discrete(5),
        spaces.MultiBinary([3, 4]),
    ]
    action_bases = [
        spaces.Box(-3.0, 3.0, [3], np.float32),
        spaces.Discrete(5),
        spaces.MultiDiscrete([4, 5]),
    ]

    observation_spaces = [dict.fromkeys(agent_ids, base) for base in observation_bases]
    action_spaces = [dict.fromkeys(agent_ids, base) for base in action_bases]

    for obs_space in observation_spaces:
        for act_space in action_spaces:
            # A fresh sample per agent serves as the dummy env's observation.
            sampled_obs = {agent: obs_space[agent].sample() for agent in obs_space}
            unwrapped_check(DummyEnv(sampled_obs, obs_space, act_space))
157 |
--------------------------------------------------------------------------------
/supersuit/vector/markov_vector_wrapper.py:
--------------------------------------------------------------------------------
1 | import gymnasium.vector
2 | import numpy as np
3 | from gymnasium.vector.utils import concatenate, create_empty_array, iterate
4 |
5 |
class MarkovVectorEnv(gymnasium.vector.VectorEnv):
    """Expose a PettingZoo parallel env as a vector env with one slot per agent.

    Slot ``i`` of every returned array or list corresponds to
    ``par_env.possible_agents[i]``.
    """

    def __init__(self, par_env, black_death=False):
        """
        parameters:
            - par_env: the pettingzoo Parallel environment that will be converted to a gymnasium vector environment
            - black_death: whether to give zero valued observations and 0 rewards when an agent is done, allowing for environments with multiple numbers of agents.
                            Is equivalent to adding the black death wrapper, but somewhat more efficient.

        The resulting object will be a valid vector environment that has a num_envs
        parameter equal to the max number of agents, will return an array of observations,
        rewards, dones, etc, and will reset environment automatically when it finishes

        Raises AssertionError when the agents do not all share the same
        observation space or the same action space.
        """
        self.par_env = par_env
        self.metadata = par_env.metadata
        self.render_mode = par_env.unwrapped.render_mode
        # The first agent's spaces serve as the single-env spaces; the
        # assertions below guarantee every other agent matches them.
        self.observation_space = par_env.observation_space(par_env.possible_agents[0])
        self.action_space = par_env.action_space(par_env.possible_agents[0])
        assert all(
            self.observation_space == par_env.observation_space(agent)
            for agent in par_env.possible_agents
        ), "observation spaces not consistent. Perhaps you should wrap with `supersuit.multiagent_wrappers.pad_observations_v0`?"
        assert all(
            self.action_space == par_env.action_space(agent)
            for agent in par_env.possible_agents
        ), "action spaces not consistent. Perhaps you should wrap with `supersuit.multiagent_wrappers.pad_action_space_v0`?"
        self.num_envs = len(par_env.possible_agents)
        self.black_death = black_death

    def concat_obs(self, obs_dict):
        """Stack per-agent observations into one batch ordered by possible_agents.

        Raises AssertionError when an agent in ``possible_agents`` has no
        entry in ``obs_dict``.
        """
        obs_list = []
        for agent in self.par_env.possible_agents:
            if agent not in obs_dict:
                raise AssertionError(
                    "environment has agent death. Not allowed for pettingzoo_env_to_vec_env_v1 unless black_death is True"
                )
            obs_list.append(obs_dict[agent])

        return concatenate(
            self.observation_space,
            obs_list,
            create_empty_array(self.observation_space, self.num_envs),
        )

    def step_async(self, actions):
        """Store ``actions`` for the next ``step_wait`` call."""
        self._saved_actions = actions

    def step_wait(self):
        """Run ``step`` with the actions stored by ``step_async``."""
        return self.step(self._saved_actions)

    def reset(self, seed=None, options=None):
        """Reset the parallel env.

        Returns ``(observations, infos)`` where ``infos`` is a list with one
        dict per possible agent (empty when the env gave none for that agent).
        """
        # TODO: should this be changed to infos?
        _observations, infos = self.par_env.reset(seed=seed, options=options)
        observations = self.concat_obs(_observations)
        infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]
        return observations, infs

    def step(self, actions):
        """Step the parallel env with one action per possible agent.

        Only agents currently in ``par_env.agents`` receive their action.
        When every agent is terminated or truncated, the final observations
        are stored under ``"terminal_observation"`` in the infos, the env is
        reset, and the returned observations are those of the new episode.

        Returns ``(observations, rewards, terminations, truncations, infos)``.
        Rewards are float32; terminations and truncations are uint8 arrays.

        Raises AssertionError when ``black_death`` is False and the set of
        active agents differs from ``possible_agents`` after the step.
        """
        actions = list(iterate(self.action_space, actions))
        agent_set = set(self.par_env.agents)
        act_dict = {
            agent: actions[i]
            for i, agent in enumerate(self.par_env.possible_agents)
            if agent in agent_set
        }
        observations, rewards, terms, truncs, infos = self.par_env.step(act_dict)

        # adds last observation to info where user can get it
        terminations = np.fromiter(terms.values(), dtype=bool)
        truncations = np.fromiter(truncs.values(), dtype=bool)
        env_done = (terminations | truncations).all()
        if env_done:
            for agent, obs in observations.items():
                # setdefault tolerates an env that omitted this agent from
                # infos, matching the infos.get(agent, {}) lookups below.
                infos.setdefault(agent, {})["terminal_observation"] = obs

        rews = np.array(
            [rewards.get(agent, 0) for agent in self.par_env.possible_agents],
            dtype=np.float32,
        )
        tms = np.array(
            [terms.get(agent, False) for agent in self.par_env.possible_agents],
            dtype=np.uint8,
        )
        tcs = np.array(
            [truncs.get(agent, False) for agent in self.par_env.possible_agents],
            dtype=np.uint8,
        )
        infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]

        if env_done:
            observations, reset_infs = self.reset()
        else:
            observations = self.concat_obs(observations)
            # empty infos for reset infs
            reset_infs = [{} for _ in range(len(self.par_env.possible_agents))]
        # combine standard infos and reset infos
        infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)]

        assert (
            self.black_death or self.par_env.agents == self.par_env.possible_agents
        ), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True"
        return observations, rews, tms, tcs, infs

    def render(self):
        """Delegate rendering to the wrapped parallel env."""
        return self.par_env.render()

    def close(self):
        """Delegate closing to the wrapped parallel env."""
        return self.par_env.close()

    def env_is_wrapped(self, wrapper_class):
        """
        env_is_wrapped only supports vector and gymnasium environments
        currently, not pettingzoo environments
        """
        return [False] * self.num_envs
120 |
--------------------------------------------------------------------------------
/supersuit/lambda_wrappers/observation_lambda.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import gymnasium
4 | import numpy as np
5 | from gymnasium.spaces import Box, Space
6 |
7 | from supersuit.utils.base_aec_wrapper import BaseWrapper
8 | from supersuit.utils.wrapper_chooser import WrapperChooser
9 |
10 |
class aec_observation_lambda(BaseWrapper):
    """AEC wrapper that passes every observation through a user function.

    ``change_observation_fn`` is called as ``fn(obs, space, agent)`` and, if
    that raises TypeError, as ``fn(obs, space)``. ``change_obs_space_fn`` is
    optional; when omitted, the new space is a Box whose bounds come from
    transforming the wrapped space's ``low`` and ``high``.
    """

    def __init__(self, env, change_observation_fn, change_obs_space_fn=None):
        assert callable(
            change_observation_fn
        ), "change_observation_fn needs to be a function. It is {}".format(
            change_observation_fn
        )
        space_fn_acceptable = change_obs_space_fn is None or callable(
            change_obs_space_fn
        )
        assert (
            space_fn_acceptable
        ), "change_obs_space_fn needs to be a function. It is {}".format(
            change_obs_space_fn
        )

        # Both callbacks are stored before the base-class constructor runs.
        self.change_observation_fn = change_observation_fn
        self.change_obs_space_fn = change_obs_space_fn

        super().__init__(env)

        # call any validation logic in observation_space for each agent
        for agent in getattr(self, "possible_agents", ()):
            self.observation_space(agent)

    def _modify_action(self, agent, action):
        # Actions pass through untouched; only observations are transformed.
        return action

    def _check_wrapper_params(self):
        # An explicit space function lifts the Box-only restriction.
        if self.change_obs_space_fn is not None:
            return
        for agent in getattr(self, "possible_agents", ()):
            assert isinstance(
                self.observation_space(agent), Box
            ), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        wrapped_space = self.env.observation_space(agent)

        if self.change_obs_space_fn is not None:
            try:
                return self.change_obs_space_fn(wrapped_space, agent)
            except TypeError:
                return self.change_obs_space_fn(wrapped_space)

        transform = self.change_observation_fn
        try:
            from_low = transform(wrapped_space.low, wrapped_space, agent)
            from_high = transform(wrapped_space.high, wrapped_space, agent)
        except TypeError:
            from_low = transform(wrapped_space.low, wrapped_space)
            from_high = transform(wrapped_space.high, wrapped_space)

        # Order the bounds element-wise so a decreasing transform still
        # yields low <= high.
        lower = np.minimum(from_low, from_high)
        upper = np.maximum(from_low, from_high)
        return Box(low=lower, high=upper, dtype=lower.dtype)

    def _modify_observation(self, agent, observation):
        wrapped_space = self.env.observation_space(agent)
        try:
            return self.change_observation_fn(observation, wrapped_space, agent)
        except TypeError:
            return self.change_observation_fn(observation, wrapped_space)
71 |
72 |
class gym_observation_lambda(gymnasium.Wrapper):
    """Gymnasium wrapper that passes every observation through a user function.

    parameters:
        - env: the gymnasium environment to wrap
        - change_observation_fn: called as ``fn(observation, observation_space)``
          with the wrapped env's observation space; returns the new observation
        - change_obs_space_fn: optional, called as ``fn(observation_space)`` and
          must return a gymnasium space. When omitted, the wrapped space must be
          a Box and the new space is a Box built from the transformed bounds.
    """

    def __init__(self, env, change_observation_fn, change_obs_space_fn=None):
        assert callable(
            change_observation_fn
        ), "change_observation_fn needs to be a function. It is {}".format(
            change_observation_fn
        )
        assert change_obs_space_fn is None or callable(
            change_obs_space_fn
        ), "change_obs_space_fn needs to be a function. It is {}".format(
            change_obs_space_fn
        )
        self.change_observation_fn = change_observation_fn
        self.change_obs_space_fn = change_obs_space_fn

        super().__init__(env)
        self._check_wrapper_params()
        self._modify_spaces()

    def _check_wrapper_params(self):
        """Require a Box observation space when no space function was given."""
        if self.change_obs_space_fn is None:
            space = self.observation_space
            assert isinstance(
                space, Box
            ), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"

    def _modify_spaces(self):
        """Replace ``self.observation_space`` with the transformed space."""
        space = self.observation_space

        if self.change_obs_space_fn is None:
            trans_low = self.change_observation_fn(space.low, space)
            trans_high = self.change_observation_fn(space.high, space)
            # Order the bounds element-wise, as the AEC wrapper does, so a
            # decreasing transform (e.g. negation) cannot give low > high.
            new_low = np.minimum(trans_low, trans_high)
            new_high = np.maximum(trans_low, trans_high)
            new_space = Box(low=new_low, high=new_high, dtype=new_low.dtype)
        else:
            new_space = self.change_obs_space_fn(space)
            assert isinstance(
                new_space, Space
            ), "output of change_obs_space_fn to observation_lambda_wrapper must be a gymnasium space"
        self.observation_space = new_space

    def _modify_observation(self, observation):
        # The callback receives the wrapped env's space, not the transformed one.
        return self.change_observation_fn(observation, self.env.observation_space)

    def step(self, action):
        """Step the wrapped env and transform the returned observation."""
        observation, rew, termination, truncation, info = self.env.step(action)
        observation = self._modify_observation(observation)
        return observation, rew, termination, truncation, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and transform the returned observation."""
        observation, infos = self.env.reset(seed=seed, options=options)
        observation = self._modify_observation(observation)
        return observation, infos
125 |
126 |
# Public entry point: the chooser is given both the gymnasium and the AEC
# implementation of the observation-lambda wrapper.
observation_lambda_v0 = WrapperChooser(
    gym_wrapper=gym_observation_lambda,
    aec_wrapper=aec_observation_lambda,
)
130 |
--------------------------------------------------------------------------------
/test/generated_agents_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from pettingzoo.test import api_test, parallel_test
4 | from pettingzoo.test.example_envs import (
5 | generated_agents_env_v0,
6 | generated_agents_parallel_v0,
7 | )
8 |
9 | import supersuit
10 | from supersuit import dtype_v0
11 |
12 |
# Wrapped AEC envs built from the parallel generated-agents example
# (via its .env() constructor); one parametrized case each.
wrappers = [
    # observation transforms
    dtype_v0(generated_agents_parallel_v0.env(), np.int32),
    supersuit.flatten_v0(generated_agents_parallel_v0.env()),
    supersuit.normalize_obs_v0(
        dtype_v0(generated_agents_parallel_v0.env(), np.float32),
        env_min=-1,
        env_max=5.0,
    ),
    supersuit.frame_stack_v1(generated_agents_parallel_v0.env(), 8),
    # reward transforms
    supersuit.reward_lambda_v0(generated_agents_parallel_v0.env(), lambda x: x / 10),
    supersuit.clip_reward_v0(generated_agents_parallel_v0.env()),
    # NaN-action handling
    supersuit.nan_noop_v0(generated_agents_parallel_v0.env(), 0),
    supersuit.nan_zeros_v0(generated_agents_parallel_v0.env()),
    supersuit.nan_random_v0(generated_agents_parallel_v0.env()),
    # temporal wrappers
    supersuit.frame_skip_v0(generated_agents_parallel_v0.env(), 4),
    supersuit.sticky_actions_v0(generated_agents_parallel_v0.env(), 0.75),
    supersuit.delay_observations_v0(generated_agents_parallel_v0.env(), 3),
    supersuit.max_observation_v0(generated_agents_parallel_v0.env(), 3),
]
32 |
33 |
34 | # TODO: fix errors: AssertionError: action is not in action space
@pytest.mark.skip(
    reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)"
)
@pytest.mark.parametrize("env", wrappers)
def test_pettingzoo_aec_api_par_gen(env):
    """Run the PettingZoo AEC API test for 50 cycles on each wrapped env."""
    cycles = 50
    api_test(env, num_cycles=cycles)
41 |
42 |
# Wrapped AEC envs built from the AEC generated-agents example;
# one parametrized case each.
wrappers = [
    # observation transforms
    dtype_v0(generated_agents_env_v0.env(), np.int32),
    supersuit.flatten_v0(generated_agents_env_v0.env()),
    supersuit.normalize_obs_v0(
        dtype_v0(generated_agents_env_v0.env(), np.float32),
        env_min=-1,
        env_max=5.0,
    ),
    supersuit.frame_stack_v1(generated_agents_env_v0.env(), 8),
    # reward transforms
    supersuit.reward_lambda_v0(generated_agents_env_v0.env(), lambda x: x / 10),
    supersuit.clip_reward_v0(generated_agents_env_v0.env()),
    # NaN-action handling
    supersuit.nan_noop_v0(generated_agents_env_v0.env(), 0),
    supersuit.nan_zeros_v0(generated_agents_env_v0.env()),
    supersuit.nan_random_v0(generated_agents_env_v0.env()),
    # temporal wrappers
    supersuit.frame_skip_v0(generated_agents_env_v0.env(), 4),
    supersuit.sticky_actions_v0(generated_agents_env_v0.env(), 0.75),
    supersuit.delay_observations_v0(generated_agents_env_v0.env(), 3),
    supersuit.max_observation_v0(generated_agents_env_v0.env(), 3),
]
60 |
61 |
62 | # TODO fix error: ValueError: operands could not be broadcast together with shapes (42,) (10,)
@pytest.mark.skip(
    reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)"
)
@pytest.mark.parametrize("env", wrappers)
def test_pettingzoo_aec_api_aec_gen(env):
    """Run the PettingZoo AEC API test for 50 cycles on each wrapped env."""
    cycles = 50
    api_test(env, num_cycles=cycles)
69 |
70 |
# Wrapped parallel envs built from the parallel generated-agents example.
# `wrappers` is rebound to the same list as `parallel_wrappers`.
parallel_wrappers = wrappers = [
    # observation transforms
    dtype_v0(generated_agents_parallel_v0.parallel_env(), np.int32),
    supersuit.flatten_v0(generated_agents_parallel_v0.parallel_env()),
    supersuit.normalize_obs_v0(
        dtype_v0(generated_agents_parallel_v0.parallel_env(), np.float32),
        env_min=-1,
        env_max=5.0,
    ),
    supersuit.frame_stack_v1(generated_agents_parallel_v0.parallel_env(), 8),
    # reward transforms
    supersuit.reward_lambda_v0(
        generated_agents_parallel_v0.parallel_env(),
        lambda x: x / 10,
    ),
    supersuit.clip_reward_v0(generated_agents_parallel_v0.parallel_env()),
    # NaN-action handling
    supersuit.nan_noop_v0(generated_agents_parallel_v0.parallel_env(), 0),
    supersuit.nan_zeros_v0(generated_agents_parallel_v0.parallel_env()),
    supersuit.nan_random_v0(generated_agents_parallel_v0.parallel_env()),
    # temporal wrappers
    supersuit.frame_skip_v0(generated_agents_parallel_v0.parallel_env(), 4, 0),
    supersuit.sticky_actions_v0(generated_agents_parallel_v0.parallel_env(), 0.75),
    supersuit.delay_observations_v0(generated_agents_parallel_v0.parallel_env(), 3),
    supersuit.max_observation_v0(generated_agents_parallel_v0.parallel_env(), 3),
]
92 |
93 |
94 | # TODO: fix normalizing obs issue: ValueError: operands could not be broadcast together with shapes (48,) (20,)
@pytest.mark.skip(
    reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)"
)
@pytest.mark.parametrize("env", parallel_wrappers)
def test_pettingzoo_parallel_api_gen(env):
    """Run the PettingZoo parallel API test for 50 cycles on each wrapped env."""
    cycles = 50
    parallel_test.parallel_api_test(env, num_cycles=cycles)
101 |
102 |
# Deferred constructors: each lambda builds a wrapper only when called, so
# the error it raises can be caught inside the test body.
wrapper_fns = [
    # AEC variants
    lambda: supersuit.pad_action_space_v0(generated_agents_parallel_v0.env()),
    lambda: supersuit.pad_observations_v0(generated_agents_parallel_v0.env()),
    lambda: supersuit.agent_indicator_v0(generated_agents_parallel_v0.env()),
    lambda: supersuit.vectorize_aec_env_v0(generated_agents_parallel_v0.env(), 2),
] + [
    # parallel variants
    lambda: supersuit.pad_action_space_v0(generated_agents_parallel_v0.parallel_env()),
    lambda: supersuit.pad_observations_v0(generated_agents_parallel_v0.parallel_env()),
    lambda: supersuit.agent_indicator_v0(generated_agents_parallel_v0.parallel_env()),
    lambda: supersuit.pettingzoo_env_to_vec_env_v1(
        generated_agents_parallel_v0.parallel_env()
    ),
]
115 |
116 |
@pytest.mark.parametrize("wrapper_fn", wrapper_fns)
def test_pettingzoo_missing_optional_error_message(wrapper_fn):
    """Each constructor must raise an AssertionError whose message contains " must have "."""
    expected_fragment = " must have "
    with pytest.raises(AssertionError, match=expected_fragment):
        wrapper_fn()
121 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/utils/shared_wrapper_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import gymnasium
4 | from pettingzoo.utils import BaseParallelWrapper
5 | from pettingzoo.utils.wrappers import OrderEnforcingWrapper as BaseWrapper
6 |
7 | from supersuit.utils.wrapper_chooser import WrapperChooser
8 |
9 |
class shared_wrapper_aec(BaseWrapper):
    """AEC wrapper that gives every agent its own modifier object.

    A modifier (an instance of ``modifier_class``) transforms that agent's
    observation space, action space, actions and observations.
    """

    def __init__(self, env, modifier_class):
        """Wrap ``env``; ``modifier_class`` is called with no arguments per agent."""
        super().__init__(env)

        self.modifier_class = modifier_class
        self.modifiers = {}
        self._cur_seed = None
        self._cur_options = None

        # Agents known up front get their modifiers immediately; any others
        # are picked up by the add_modifiers() calls in reset() and step().
        if hasattr(self.env, "possible_agents"):
            self.add_modifiers(self.env.possible_agents)

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        """Return the wrapped env's observation space after the agent's modifier."""
        raw_space = self.env.observation_space(agent)
        return self.modifiers[agent].modify_obs_space(raw_space)

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        """Return the wrapped env's action space after the agent's modifier."""
        raw_space = self.env.action_space(agent)
        return self.modifiers[agent].modify_action_space(raw_space)

    def add_modifiers(self, agents_list):
        """Create, initialise and reset a modifier for each agent lacking one."""
        for agent in agents_list:
            if agent in self.modifiers:
                continue

            self.modifiers[agent] = self.modifier_class()
            # Calling the space accessors pushes the raw spaces through the
            # freshly created modifier.
            self.observation_space(agent)
            self.action_space(agent)
            self.modifiers[agent].reset(
                seed=self._cur_seed, options=self._cur_options
            )

            # Bump the seed so the next new modifier is seeded differently.
            if self._cur_seed is not None:
                self._cur_seed += 1

    def reset(self, seed=None, options=None):
        """Reset all modifiers and the env, then record the first observation."""
        self._cur_seed = seed
        self._cur_options = options

        for modifier in self.modifiers.values():
            modifier.reset(seed=seed, options=options)
        super().reset(seed=seed, options=options)

        self.add_modifiers(self.agents)
        selected = self.modifiers[self.agent_selection]
        selected.modify_obs(super().observe(self.agent_selection))

    def step(self, action):
        """Step the env with the modified action and record the next observation."""
        action = self.modifiers[self.agent_selection].modify_action(action)
        agent_is_done = (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        )
        if agent_is_done:
            # Finished agents are stepped with None.
            action = None
        super().step(action)

        self.add_modifiers(self.agents)
        selected = self.modifiers[self.agent_selection]
        selected.modify_obs(super().observe(self.agent_selection))

    def observe(self, agent):
        """Return the last observation stored by the agent's modifier."""
        return self.modifiers[agent].get_last_obs()
74 |
75 |
class shared_wrapper_parr(BaseParallelWrapper):
    """Parallel-API wrapper that gives every agent its own modifier object.

    A modifier (an instance of ``modifier_class``) transforms that agent's
    observation space, action space, actions and observations.
    """

    def __init__(self, env, modifier_class):
        """Wrap ``env``; ``modifier_class`` is called with no arguments per agent."""
        super().__init__(env)

        self.modifier_class = modifier_class
        self.modifiers = {}
        self._cur_seed = None
        self._cur_options = None

        # Agents known up front get their modifiers immediately; any others
        # are picked up by the add_modifiers() calls in reset() and step().
        if hasattr(self.env, "possible_agents"):
            self.add_modifiers(self.env.possible_agents)

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        """Return the wrapped env's observation space after the agent's modifier."""
        raw_space = self.env.observation_space(agent)
        return self.modifiers[agent].modify_obs_space(raw_space)

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        """Return the wrapped env's action space after the agent's modifier."""
        raw_space = self.env.action_space(agent)
        return self.modifiers[agent].modify_action_space(raw_space)

    def add_modifiers(self, agents_list):
        """Create, initialise and reset a modifier for each agent lacking one."""
        for agent in agents_list:
            if agent in self.modifiers:
                continue

            self.modifiers[agent] = self.modifier_class()
            # Calling the space accessors pushes the raw spaces through the
            # freshly created modifier.
            self.observation_space(agent)
            self.action_space(agent)
            self.modifiers[agent].reset(
                seed=self._cur_seed, options=self._cur_options
            )

            # Bump the seed so the next new modifier is seeded differently.
            if self._cur_seed is not None:
                self._cur_seed += 1

    def reset(self, seed=None, options=None):
        """Reset the env and every modifier; return modified observations and infos."""
        self._cur_seed = seed
        self._cur_options = options

        observations, infos = super().reset(seed=seed, options=options)
        self.add_modifiers(self.agents)
        # Every modifier, including ones just created, is reset with the
        # caller's seed and options.
        for modifier in self.modifiers.values():
            modifier.reset(seed=seed, options=options)

        modified = {}
        for agent, obs in observations.items():
            modified[agent] = self.modifiers[agent].modify_obs(obs)
        return modified, infos

    def step(self, actions):
        """Step with modified actions; return the usual 5-tuple with modified observations."""
        env_actions = {}
        for agent, action in actions.items():
            env_actions[agent] = self.modifiers[agent].modify_action(action)

        observations, rewards, terminations, truncations, infos = super().step(
            env_actions
        )
        self.add_modifiers(self.agents)

        modified = {}
        for agent, obs in observations.items():
            modified[agent] = self.modifiers[agent].modify_obs(obs)
        return modified, rewards, terminations, truncations, infos
137 |
138 |
class shared_wrapper_gym(gymnasium.Wrapper):
    """Gymnasium wrapper that routes spaces, actions and observations through one modifier."""

    def __init__(self, env, modifier_class):
        """Wrap ``env`` and replace both spaces with their modified versions."""
        super().__init__(env)
        self.modifier = modifier_class()
        modified_obs_space = self.modifier.modify_obs_space(self.observation_space)
        self.observation_space = modified_obs_space
        modified_act_space = self.modifier.modify_action_space(self.action_space)
        self.action_space = modified_act_space

    def reset(self, seed=None, options=None):
        """Reset the modifier, then the env; return the modified observation and info."""
        self.modifier.reset(seed=seed, options=options)
        raw_obs, info = super().reset(seed=seed, options=options)
        return self.modifier.modify_obs(raw_obs), info

    def step(self, action):
        """Step with the modified action; return the 5-tuple with a modified observation."""
        env_action = self.modifier.modify_action(action)
        raw_obs, reward, terminated, truncated, info = super().step(env_action)
        return self.modifier.modify_obs(raw_obs), reward, terminated, truncated, info
156 |
157 |
# Single public entry point covering all three environment APIs: one
# implementation each for AEC, Gymnasium and parallel environments is handed
# to WrapperChooser (presumably it dispatches on the type of the wrapped env;
# see supersuit.utils.wrapper_chooser to confirm).
shared_wrapper = WrapperChooser(
    aec_wrapper=shared_wrapper_aec,
    gym_wrapper=shared_wrapper_gym,
    parallel_wrapper=shared_wrapper_parr,
)
163 |
--------------------------------------------------------------------------------
/supersuit/aec_vector/vector_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pettingzoo.utils.agent_selector import agent_selector
3 |
4 | from .base_aec_vec_env import VectorAECEnv
5 |
6 |
class SyncAECVectorEnv(VectorAECEnv):
    """Step several AEC environments in lockstep inside the calling process.

    Per-agent attributes (``rewards``, ``_cumulative_rewards``,
    ``terminations``, ``truncations``, ``infos``) are dictionaries keyed by
    agent whose values hold one entry per sub-environment.
    """

    def __init__(self, env_constructors):
        """Build one environment per constructor.

        Args:
            env_constructors: Non-empty sequence of zero-argument callables,
                each returning an environment.
        """
        assert len(env_constructors) >= 1
        assert callable(
            env_constructors[0]
        ), "env_constructor must be a callable object (i.e function) that create an environment"

        self.envs = [env_constructor() for env_constructor in env_constructors]
        self.num_envs = len(env_constructors)
        # The first environment is the reference for spaces and agent names.
        self.env = self.envs[0]
        self.max_num_agents = self.env.max_num_agents
        self.possible_agents = self.env.possible_agents
        self._agent_selector = agent_selector(self.possible_agents)

    def action_space(self, agent):
        """Return the action space of ``agent`` in the reference environment."""
        return self.env.action_space(agent)

    def observation_space(self, agent):
        """Return the observation space of ``agent`` in the reference environment."""
        return self.env.observation_space(agent)

    def _find_active_agent(self):
        """Advance the selector until some sub-environment is waiting on that agent."""
        cur_selection = self.agent_selection
        while not any(cur_selection == env.agent_selection for env in self.envs):
            cur_selection = self._agent_selector.next()
        return cur_selection

    def _per_env_values(self, attr, agent, default):
        """Return ``env.<attr>[agent]`` for every env, or ``default`` where the agent is absent."""
        values = []
        for env in self.envs:
            table = getattr(env, attr)
            values.append(table[agent] if agent in table else default)
        return values

    def _collect_dicts(self):
        """Rebuild the batched per-agent dictionaries from the sub-environments.

        Agents missing from a sub-environment count as reward 0, terminated,
        truncated and with an empty info dict.
        """
        agents = self.possible_agents
        self.rewards = {
            agent: np.array(
                self._per_env_values("rewards", agent, 0), dtype=np.float32
            )
            for agent in agents
        }
        self._cumulative_rewards = {
            agent: np.array(
                self._per_env_values("_cumulative_rewards", agent, 0),
                dtype=np.float32,
            )
            for agent in agents
        }
        self.terminations = {
            agent: np.array(
                self._per_env_values("terminations", agent, True), dtype=np.uint8
            )
            for agent in agents
        }
        self.truncations = {
            agent: np.array(
                self._per_env_values("truncations", agent, True), dtype=np.uint8
            )
            for agent in agents
        }
        # A fresh {} is created per missing entry, so info dicts are never shared.
        self.infos = {
            agent: [env.infos[agent] if agent in env.infos else {} for env in self.envs]
            for agent in agents
        }

    def reset(self, seed=None, options=None):
        """Reset every sub-environment and the batched bookkeeping.

        When ``seed`` is given, environment ``i`` is reset with ``seed + i``.
        Nothing is returned; read observations through ``observe``/``last``.
        """
        for i, env in enumerate(self.envs):
            env.reset(seed=None if seed is None else seed + i, options=options)

        self.agent_selection = self._agent_selector.reset()
        self.agent_selection = self._find_active_agent()

        self._collect_dicts()
        self.envs_terminations = np.zeros(self.num_envs)
        self.envs_truncations = np.zeros(self.num_envs)

    def observe(self, agent):
        """Return the stacked observations of ``agent`` across all sub-environments.

        Environments that no longer track the agent contribute an all-zero
        array shaped like the lower bound of the agent's observation space.
        """
        observations = []
        for env in self.envs:
            obs = (
                env.observe(agent)
                if (agent in env.terminations) or (agent in env.truncations)
                else np.zeros_like(env.observation_space(agent).low)
            )
            observations.append(obs)
        return np.stack(observations)

    def last(self, observe=True):
        """Return batched data for the currently selected agent.

        Returns:
            An 8-tuple ``(obs, cumulative_rewards, terminations, truncations,
            envs_terminations, envs_truncations, passes, infos)``. ``passes``
            marks the sub-environments whose own selected agent differs from
            the batched selection. ``obs`` is None when ``observe`` is False.
        """
        # NOTE(review): within this class envs_terminations / envs_truncations
        # are only assigned in reset(); step() records per-env completion in
        # envs_dones instead, so the values returned here stay at zero.
        # Confirm the intended semantics before relying on them.
        passes = np.array(
            [env.agent_selection != self.agent_selection for env in self.envs],
            dtype=np.uint8,
        )
        last_agent = self.agent_selection
        obs = self.observe(last_agent) if observe else None
        return (
            obs,
            self._cumulative_rewards[last_agent],
            self.terminations[last_agent],
            self.truncations[last_agent],
            self.envs_terminations,
            self.envs_truncations,
            passes,
            self.infos[last_agent],
        )

    def step(self, actions, observe=True):
        """Apply one action per sub-environment for the selected agent.

        A sub-environment in which every agent is terminated or truncated is
        reset instead of stepped. One whose own selected agent differs from
        the batched selection is left untouched.

        Args:
            actions: One action per sub-environment.
            observe: Unused by this implementation.
        """
        assert len(actions) == len(
            self.envs
        ), f"{len(actions)} actions given, but there are {len(self.envs)} environments!"
        old_agent = self.agent_selection

        envs_dones = []
        for i, (act, env) in enumerate(zip(actions, self.envs)):
            # Prior to the truncation API update, the env was reset if env.agents was an empty list
            # After the truncation API update, the env needs to be reset if every agent is terminated OR truncated
            terminations = np.fromiter(env.terminations.values(), dtype=bool)
            truncations = np.fromiter(env.truncations.values(), dtype=bool)
            env_done = (terminations | truncations).all()
            envs_dones.append(env_done)

            if env_done:
                env.reset()
            elif env.agent_selection == old_agent:
                # The original test was `isinstance(type(act), np.ndarray)`,
                # which can never be true. Copy array actions so the env does
                # not keep a view into the caller's batch.
                if isinstance(act, np.ndarray):
                    act = np.array(act)
                agent_is_dead = (
                    self.terminations[old_agent][i] or self.truncations[old_agent][i]
                )
                # if the agent is dead, set action to None
                env.step(None if agent_is_dead else act)

        self.agent_selection = self._agent_selector.next()
        self.agent_selection = self._find_active_agent()

        self.envs_dones = np.array(envs_dones)
        self._collect_dicts()
163 |
--------------------------------------------------------------------------------
/supersuit/generic_wrappers/frame_skip.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | import numpy as np
3 | from pettingzoo.utils.wrappers import BaseParallelWrapper, BaseWrapper
4 |
5 | from supersuit.utils.frame_skip import check_transform_frameskip
6 | from supersuit.utils.make_defaultdict import make_defaultdict
7 | from supersuit.utils.wrapper_chooser import WrapperChooser
8 |
9 |
class frame_skip_gym(gymnasium.Wrapper):
    """Gymnasium wrapper that repeats each action for a randomly drawn number of frames."""

    def __init__(self, env, num_frames):
        """Wrap ``env``; ``num_frames`` is normalised by ``check_transform_frameskip``."""
        super().__init__(env)
        self.num_frames = check_transform_frameskip(num_frames)

    def step(self, action):
        """Repeat ``action`` and return the last transition with the summed reward.

        The repeat count is drawn from the unwrapped env's ``np_random`` using
        the stored ``(low, high)`` pair. Repetition stops early as soon as a
        step reports termination or truncation.
        """
        lo, hi = self.num_frames
        repeats = int(self.env.unwrapped.np_random.integers(lo, hi + 1))
        reward_sum = 0.0

        for _ in range(repeats):
            obs, rew, term, trunc, info = super().step(action)
            reward_sum += rew
            episode_over = term or trunc
            if episode_over:
                break

        return obs, reward_sum, term, trunc, info
27 |
28 |
class StepAltWrapper(BaseWrapper):
    """BaseWrapper whose action and observation hooks pass values through unchanged."""

    def _modify_action(self, agent, action):
        """Return ``action`` unchanged; ``agent`` is ignored."""
        return action

    def _modify_observation(self, agent, observation):
        """Return ``observation`` unchanged; ``agent`` is ignored."""
        return observation
35 |
36 |
class frame_skip_aec(StepAltWrapper):
    """Frame skip for AEC environments.

    Each agent's most recent action is stored and replayed on the wrapped
    environment until a per-agent countdown, initialised to ``num_frames``,
    reaches zero.
    """

    def __init__(self, env, num_frames):
        """Wrap ``env``; ``num_frames`` must be a positive int (no ranges here)."""
        super().__init__(env)
        assert isinstance(
            num_frames, int
        ), "multi-agent frame skip only takes in an integer"
        assert num_frames > 0
        # Called for validation only; the return value is discarded.
        check_transform_frameskip(num_frames)
        self.num_frames = num_frames

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and rebuild this wrapper's own bookkeeping."""
        super().reset(seed=seed, options=options)
        # Copy the agent list; step() appends to self.agents.
        self.agents = self.env.agents[:]
        # The wrapper keeps its own per-agent tables, separate from the wrapped
        # env's (presumably make_defaultdict also supplies a default for agents
        # added later; see supersuit.utils.make_defaultdict to confirm).
        self.terminations = make_defaultdict({agent: False for agent in self.agents})
        self.truncations = make_defaultdict({agent: False for agent in self.agents})
        self.rewards = make_defaultdict({agent: 0.0 for agent in self.agents})
        self._cumulative_rewards = make_defaultdict(
            {agent: 0.0 for agent in self.agents}
        )
        self.infos = make_defaultdict({agent: {} for agent in self.agents})
        # Remaining replays per agent.
        self.skip_num = make_defaultdict({agent: 0 for agent in self.agents})
        # Action currently being replayed per agent; None means "none pending".
        self.old_actions = make_defaultdict({agent: None for agent in self.agents})
        # Observation captured for an agent just before its dead step.
        self._final_observations = make_defaultdict(
            {agent: None for agent in self.agents}
        )

    def observe(self, agent):
        """Return the captured final observation if there is one, else the live one."""
        fin_observe = self._final_observations[agent]
        return fin_observe if fin_observe is not None else super().observe(agent)

    def step(self, action):
        """Store ``action`` for the selected agent and replay pending actions."""
        self._has_updated = True

        # if agent is dead, perform a None step
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            # Only forward the dead step when the wrapped env is waiting on
            # this same agent.
            if self.env.agents and self.agent_selection == self.env.agent_selection:
                self.env.step(None)
            self._was_dead_step(action)
            return

        cur_agent = self.agent_selection
        self._cumulative_rewards[cur_agent] = 0
        self.rewards = make_defaultdict({a: 0.0 for a in self.agents})
        self.skip_num[
            cur_agent
        ] = (
            self.num_frames
        )  # set the skip num to the param inputted in the frame_skip wrapper
        self.old_actions[cur_agent] = action

        while (
            self.old_actions[self.env.agent_selection] is not None
        ):  # plays the role of `for x in range(num_skips):` in frame_skip_gym.step
            step_agent = self.env.agent_selection

            # Only step agents the wrapped env still tracks.
            if (step_agent in self.env.terminations) or (
                step_agent in self.env.truncations
            ):
                # Only `info` is used from this call.
                observe, reward, termination, truncation, info = self.env.last(
                    observe=False
                )
                action = self.old_actions[step_agent]
                self.env.step(action)

                # Accumulate the rewards produced by the skipped frame.
                for agent in self.env.agents:
                    self.rewards[agent] += self.env.rewards[agent]
                self.infos[self.env.agent_selection] = info

                # Drain agents that are now done: record their flags and final
                # observation, then give them their None step.
                while self.env.agents and (
                    self.env.terminations[self.env.agent_selection]
                    or self.env.truncations[self.env.agent_selection]
                ):
                    dead_agent = self.env.agent_selection
                    self.terminations[dead_agent] = self.env.terminations[dead_agent]
                    self.truncations[dead_agent] = self.env.truncations[dead_agent]
                    self._final_observations[dead_agent] = self.env.observe(dead_agent)
                    self.env.step(None)
                step_agent = self.env.agent_selection

            self.skip_num[step_agent] -= 1
            if self.skip_num[step_agent] == 0:
                self.old_actions[
                    step_agent
                ] = None  # if it is time to skip, set action to None, effectively breaking the while loop

        # Mirror the wrapped env's state for live agents and adopt any agent
        # this wrapper has not seen yet.
        my_agent_set = set(self.agents)
        for agent in self.env.agents:
            self.terminations[agent] = self.env.terminations[agent]
            self.truncations[agent] = self.env.truncations[agent]
            self.infos[agent] = self.env.infos[agent]
            if agent not in my_agent_set:
                self.agents.append(agent)
        self.agent_selection = self.env.agent_selection
        self._accumulate_rewards()
        self._deads_step_first()
138 |
139 |
class frame_skip_par(BaseParallelWrapper):
    """Parallel-API wrapper that repeats the action dict for a random number of frames."""

    def __init__(self, env, num_frames, default_action=None):
        """Wrap ``env``.

        ``num_frames`` is normalised by ``check_transform_frameskip``.
        ``default_action`` is used for agents that appear while frames are
        being skipped.
        """
        super().__init__(env)
        self.num_frames = check_transform_frameskip(num_frames)
        self.default_action = default_action

    def step(self, action):
        """Repeat ``action``; return the latest per-agent values with summed rewards."""
        # Work on a copy: default actions for new agents are added below.
        action = {**action}
        lo, hi = self.num_frames
        repeats = int(self.env.unwrapped.np_random.integers(lo, hi + 1))
        initial_agents = set(action.keys())

        reward_sums = make_defaultdict({agent: 0.0 for agent in self.agents})
        latest_terms = {}
        latest_truncs = {}
        latest_infos = {}
        latest_obs = {}

        for _ in range(repeats):
            obs, rews, term, trunc, info = super().step(action)

            for agent, rew in rews.items():
                reward_sums[agent] += rew
                latest_terms[agent] = term[agent]
                latest_truncs[agent] = trunc[agent]
                latest_infos[agent] = info[agent]
                latest_obs[agent] = obs[agent]

            # Agents generated mid-skip have no action yet; give them the default.
            for agent in self.env.agents:
                if agent in action:
                    continue
                assert (
                    self.default_action is not None
                ), "parallel environments that use frame_skip_v0 must provide a `default_action` argument for steps between an agent being generated and an agent taking its first step"
                action[agent] = self.default_action

            everyone_done = (
                np.fromiter(term.values(), dtype=bool)
                | np.fromiter(trunc.values(), dtype=bool)
            ).all()
            if everyone_done:
                break

        # delete any values created by agents which were
        # generated and deleted before they took any actions
        surviving_agents = set(self.agents)
        for agent in list(reward_sums):
            if agent in surviving_agents or agent in initial_agents:
                continue
            del reward_sums[agent]
            del latest_terms[agent]
            del latest_truncs[agent]
            del latest_infos[agent]
            del latest_obs[agent]

        return latest_obs, reward_sums, latest_terms, latest_truncs, latest_infos
199 |
200 |
# Public entry point for frame skipping: one implementation each for AEC,
# Gymnasium and parallel environments is handed to WrapperChooser (presumably
# it dispatches on the type of the wrapped env; see
# supersuit.utils.wrapper_chooser to confirm).
frame_skip_v0 = WrapperChooser(
    aec_wrapper=frame_skip_aec,
    gym_wrapper=frame_skip_gym,
    parallel_wrapper=frame_skip_par,
)
206 |
--------------------------------------------------------------------------------
/supersuit/vector/multiproc_vec.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import multiprocessing as mp
3 | import time
4 | import traceback
5 |
6 | import gymnasium.vector
7 | import numpy as np
8 | from gymnasium.vector.utils import (
9 | concatenate,
10 | create_empty_array,
11 | create_shared_memory,
12 | iterate,
13 | read_from_shared_memory,
14 | write_to_shared_memory,
15 | )
16 |
17 | from .utils.shared_array import SharedArray
18 |
19 |
def compress_info(infos):
    """Return ``(index, info)`` pairs for the truthy entries of ``infos``.

    Empty or falsy infos are dropped; the index records each kept entry's
    position in the original sequence.
    """
    kept = []
    for position, info in enumerate(infos):
        if info:
            kept.append((position, info))
    return kept
23 |
24 |
def decompress_info(num_envs, idx_starts, comp_infos):
    """Expand per-worker compressed infos into one info dict per environment.

    Args:
        num_envs: Total number of environments, i.e. the length of the result.
        idx_starts: Global index of each worker's first environment. Extra
            trailing entries are ignored.
        comp_infos: For each worker, the ``(local_index, info)`` pairs
            produced by ``compress_info``.

    Returns:
        A list of ``num_envs`` infos; environments without an entry get their
        own empty dict.
    """
    # One fresh dict per slot: `[{}] * num_envs` would alias a single dict, so
    # a caller mutating one empty info would silently change all of them.
    all_info = [{} for _ in range(num_envs)]
    for idx_start, worker_infos in zip(idx_starts, comp_infos):
        for i, info in worker_infos:
            all_info[idx_start + i] = info
    return all_info
31 |
32 |
def write_observations(vec_env, env_start_idx, shared_obs, obs):
    """Write each sub-environment's observation into the shared buffer.

    The batched ``obs`` is split with ``iterate`` and entry ``i`` is written
    at slot ``env_start_idx + i`` of ``shared_obs``.
    """
    per_env_obs = list(iterate(vec_env.observation_space, obs))
    for offset in range(vec_env.num_envs):
        slot = env_start_idx + offset
        write_to_shared_memory(
            vec_env.observation_space, slot, per_env_obs[offset], shared_obs
        )
42 |
43 |
def numpy_deepcopy(buf):
    """Return a deep copy of a nested structure of numpy arrays.

    Args:
        buf: A ``np.ndarray``, or a dict / tuple nesting of them.

    Returns:
        A structure of the same shape in which every array has been copied.

    Raises:
        ValueError: If ``buf`` or any nested value is not a dict, tuple or
            ``np.ndarray``.
    """
    if isinstance(buf, dict):
        return {name: numpy_deepcopy(v) for name, v in buf.items()}
    if isinstance(buf, tuple):
        return tuple(numpy_deepcopy(v) for v in buf)
    if isinstance(buf, np.ndarray):
        return buf.copy()
    # The previous message was the bare, truncated string "numpy_deepcopy ".
    raise ValueError(
        "numpy_deepcopy only supports dicts, tuples and numpy arrays, "
        f"got {type(buf).__name__}"
    )
53 |
54 |
def async_loop(
    vec_env_constr, inpt_p, pipe, shared_obs, shared_rews, shared_terms, shared_truncs
):
    """Worker-process loop: build a vector env and serve instructions from ``pipe``.

    Handshake: the worker sends ``vec_env.num_envs`` and receives the global
    index of its first environment. It then handles one instruction per
    iteration and sends one reply per instruction:

    * ``("reset", (seed, options))``: reset, then reply with compressed infos.
    * ``("step", actions)``: step, then reply with compressed infos.
    * ``("env_is_wrapped", wrapper_class)``: reply with the env's answer.
    * ``"render"``: reply with the frame when the render mode is
      ``"rgb_array"``, otherwise with ``[]``.
    * ``"close"``: close the env, reply with ``[]`` and keep serving.
    * ``("close", data)``: close the env and exit without a reply.
    * ``"terminate"``: exit without a reply.

    Observations, rewards, terminations and truncations are written into the
    shared buffers, not sent over the pipe. Any exception is sent to the
    parent as an ``(exception, traceback_text)`` tuple and ends the loop.

    Args:
        vec_env_constr: Zero-argument callable returning the vector env.
        inpt_p: The parent's end of the pipe; closed immediately.
        pipe: This worker's end of the pipe.
        shared_obs: Shared observation buffer.
        shared_rews: Shared reward array (exposes ``np_arr``).
        shared_terms: Shared termination array (exposes ``np_arr``).
        shared_truncs: Shared truncation array (exposes ``np_arr``).
    """
    inpt_p.close()
    try:
        vec_env = vec_env_constr()

        pipe.send(vec_env.num_envs)
        env_start_idx = pipe.recv()
        env_end_idx = env_start_idx + vec_env.num_envs
        while True:
            instr = pipe.recv()
            comp_infos = []

            if instr == "close":
                vec_env.close()

            elif isinstance(instr, tuple):
                name, data = instr

                if name == "close":
                    # Tuple form of "close". Previously it fell through to the
                    # "bad tuple instruction name" error, so the env was never
                    # closed. No reply is sent for this form, so exit at once.
                    vec_env.close()
                    return

                elif name == "reset":
                    observations, infos = vec_env.reset(seed=data[0], options=data[1])
                    comp_infos = compress_info(infos)

                    write_observations(vec_env, env_start_idx, shared_obs, observations)
                    shared_terms.np_arr[env_start_idx:env_end_idx] = False
                    shared_truncs.np_arr[env_start_idx:env_end_idx] = False
                    shared_rews.np_arr[env_start_idx:env_end_idx] = 0.0

                elif name == "step":
                    actions = data
                    actions = concatenate(
                        vec_env.action_space,
                        actions,
                        create_empty_array(vec_env.action_space, n=len(actions)),
                    )
                    observations, rewards, terms, truncs, infos = vec_env.step(actions)
                    write_observations(vec_env, env_start_idx, shared_obs, observations)
                    shared_terms.np_arr[env_start_idx:env_end_idx] = terms
                    shared_truncs.np_arr[env_start_idx:env_end_idx] = truncs
                    shared_rews.np_arr[env_start_idx:env_end_idx] = rewards
                    comp_infos = compress_info(infos)

                elif name == "env_is_wrapped":
                    comp_infos = vec_env.env_is_wrapped(data)

                else:
                    # f-string so a non-str name cannot turn this into a TypeError.
                    raise AssertionError(f"bad tuple instruction name: {name}")
            elif instr == "render":
                render_result = vec_env.render()
                if vec_env.render_mode == "rgb_array":
                    comp_infos = render_result
            elif instr == "terminate":
                return
            else:
                raise AssertionError(f"bad instruction: {instr}")
            pipe.send(comp_infos)
    except BaseException as e:
        tb = traceback.format_exc()
        pipe.send((e, tb))
115 |
116 |
class ProcConcatVec(gymnasium.vector.VectorEnv):
    """Vector env that runs each sub vector env in its own worker process.

    Instructions and infos travel over one pipe per worker. Observations,
    rewards, terminations and truncations are exchanged through shared
    buffers, in which every worker owns a contiguous range of slots.
    """

    def __init__(
        self, vec_env_constrs, observation_space, action_space, tot_num_envs, metadata
    ):
        """Start one worker per constructor and assign each its slot range.

        Args:
            vec_env_constrs: Callables, each building one worker's vector env.
            observation_space: Single-environment observation space.
            action_space: Single-environment action space.
            tot_num_envs: Total environments across all workers; must equal
                the sum reported by the workers.
            metadata: Stored as ``self.metadata``.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs = tot_num_envs
        self.metadata = metadata

        self.shared_obs = create_shared_memory(self.observation_space, n=self.num_envs)
        self.shared_act = create_shared_memory(self.action_space, n=self.num_envs)
        self.shared_rews = SharedArray((num_envs,), dtype=np.float32)
        self.shared_terms = SharedArray((num_envs,), dtype=np.uint8)
        self.shared_truncs = SharedArray((num_envs,), dtype=np.uint8)

        self.observations_buffers = read_from_shared_memory(
            self.observation_space, self.shared_obs, n=self.num_envs
        )

        # Seconds close() waits for workers to exit before killing them.
        self.graceful_shutdown_timeout = 10

        # Register the lists before spawning anything, so close() can still
        # clean up the workers already started if a later spawn fails.
        self.pipes = []
        self.procs = []
        for constr in vec_env_constrs:
            inpt, outpt = mp.Pipe()
            constr = gymnasium.vector.async_vector_env.CloudpickleWrapper(constr)
            proc = mp.Process(
                target=async_loop,
                args=(
                    constr,
                    inpt,
                    outpt,
                    self.shared_obs,
                    self.shared_rews,
                    self.shared_terms,
                    self.shared_truncs,
                ),
            )
            proc.start()
            # The worker owns this end; the parent keeps only `inpt`.
            outpt.close()
            self.pipes.append(inpt)
            self.procs.append(proc)

        # Handshake: every worker reports its env count and is told the
        # global index of its first environment.
        num_envs = 0
        env_nums = self._receive_info()
        idx_starts = []
        for pipe, cnum_env in zip(self.pipes, env_nums):
            cur_env_idx = num_envs
            num_envs += cnum_env
            pipe.send(cur_env_idx)
            idx_starts.append(cur_env_idx)
        # Trailing sentinel so that idx_starts[i : i + 2] is worker i's range.
        idx_starts.append(num_envs)

        assert num_envs == tot_num_envs
        self.idx_starts = idx_starts

    def reset(self, seed=None, options=None):
        """Reset every worker; return copies of the observations and the worker replies.

        When ``seed`` is given, worker ``i`` receives ``seed + i``.
        """
        for i, pipe in enumerate(self.pipes):
            worker_seed = None if seed is None else seed + i
            pipe.send(("reset", (worker_seed, options)))

        info = self._receive_info()

        return numpy_deepcopy(self.observations_buffers), copy.deepcopy(info)

    def step_async(self, actions):
        """Send each worker the slice of ``actions`` for its environments."""
        actions = list(iterate(self.action_space, actions))
        for i, pipe in enumerate(self.pipes):
            start, end = self.idx_starts[i : i + 2]
            pipe.send(("step", actions[start:end]))

    def _receive_info(self):
        """Collect one reply per worker, in worker order.

        A tuple reply is treated as an ``(exception, traceback_text)`` pair:
        the traceback is printed and the exception re-raised.
        """
        all_data = []
        for cin in self.pipes:
            data = cin.recv()
            if isinstance(data, tuple):
                e, tb = data
                print(tb)
                raise e
            all_data.append(data)
        return all_data

    def step_wait(self):
        """Wait for all workers and return ``(obs, rewards, terms, truncs, infos)``.

        Every returned value is a copy, so it is not changed by later steps.
        """
        compressed_infos = self._receive_info()
        infos = decompress_info(self.num_envs, self.idx_starts, compressed_infos)
        rewards = self.shared_rews.np_arr
        terms = self.shared_terms.np_arr
        truncs = self.shared_truncs.np_arr
        return (
            numpy_deepcopy(self.observations_buffers),
            rewards.copy(),
            # astype() already returns a new array, so no extra copy is needed.
            terms.astype(bool),
            truncs.astype(bool),
            copy.deepcopy(infos),
        )

    def step(self, actions):
        """Synchronous step: ``step_async`` followed by ``step_wait``."""
        self.step_async(actions)
        return self.step_wait()

    def __del__(self):
        # __init__ may have failed before the worker lists were created.
        if hasattr(self, "pipes") and hasattr(self, "procs"):
            self.close()

    def render(self):
        """Render through the first worker only and return its reply."""
        self.pipes[0].send("render")
        render_result = self.pipes[0].recv()

        # A tuple reply is an (exception, traceback_text) pair.
        if isinstance(render_result, tuple):
            e, tb = render_result
            print(tb)
            raise e

        return render_result

    def close(self):
        """Ask workers to close, wait up to the timeout, then kill stragglers.

        Safe to call more than once: a failed send on an already-closed pipe
        skips the waiting phase and goes straight to the final cleanup.
        """
        try:
            for pipe, proc in zip(self.pipes, self.procs):
                if proc.is_alive():
                    pipe.send(("close", None))
        except OSError:
            pass
        else:
            # One shared deadline for all workers, not one timeout each.
            deadline = (
                None
                if self.graceful_shutdown_timeout is None
                else time.monotonic() + self.graceful_shutdown_timeout
            )
            for proc in self.procs:
                timeout = None if deadline is None else deadline - time.monotonic()
                if timeout is not None and timeout <= 0:
                    break
                proc.join(timeout)
        for pipe, proc in zip(self.pipes, self.procs):
            if proc.is_alive():
                proc.kill()
            pipe.close()

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Ask every worker whether its envs are wrapped; return the flattened answers.

        ``indices`` is accepted but ignored; every worker is always queried.
        """
        for pipe in self.pipes:
            pipe.send(("env_is_wrapped", wrapper_class))

        results = self._receive_info()
        return [flag for worker_flags in results for flag in worker_flags]
265 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This repository is licensed as follows:
2 | All assets in this repository are the copyright of the Farama Foundation, except
3 | where prohibited. Contributors to the repository transfer copyright of their work
4 | to the Farama Foundation.
5 |
6 | Some code in this repository has been taken from other open source projects
7 | and was originally released under the MIT or Apache 2.0 licenses, with
8 | copyright held by another party. We've attributed these authors and they
9 | retain their copyright to the extent required by law. Everything else
10 | is owned by the Farama Foundation. The Secret Code font was also released under
11 | the MIT license by Matthew Welch (http://www.squaregear.net/fonts/).
12 | The MIT and Apache 2.0 licenses are included below.
13 |
14 | The Farama Foundation releases the elements of this repository they copyright to
15 | under the MIT license.
16 |
17 | --------------------------------------------------------------------------------
18 |
19 | MIT License
20 |
21 | Permission is hereby granted, free of charge, to any person obtaining a copy
22 | of this software and associated documentation files (the "Software"), to deal
23 | in the Software without restriction, including without limitation the rights
24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25 | copies of the Software, and to permit persons to whom the Software is
26 | furnished to do so, subject to the following conditions:
27 |
28 | The above copyright notice and this permission notice shall be included in all
29 | copies or substantial portions of the Software.
30 |
31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37 | SOFTWARE.
38 |
39 | --------------------------------------------------------------------------------
40 |
41 | Apache License
42 | Version 2.0, January 2004
43 | http://www.apache.org/licenses/
44 |
45 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
46 |
47 | 1. Definitions.
48 |
49 | "License" shall mean the terms and conditions for use, reproduction,
50 | and distribution as defined by Sections 1 through 9 of this document.
51 |
52 | "Licensor" shall mean the copyright owner or entity authorized by
53 | the copyright owner that is granting the License.
54 |
55 | "Legal Entity" shall mean the union of the acting entity and all
56 | other entities that control, are controlled by, or are under common
57 | control with that entity. For the purposes of this definition,
58 | "control" means (i) the power, direct or indirect, to cause the
59 | direction or management of such entity, whether by contract or
60 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
61 | outstanding shares, or (iii) beneficial ownership of such entity.
62 |
63 | "You" (or "Your") shall mean an individual or Legal Entity
64 | exercising permissions granted by this License.
65 |
66 | "Source" form shall mean the preferred form for making modifications,
67 | including but not limited to software source code, documentation
68 | source, and configuration files.
69 |
70 | "Object" form shall mean any form resulting from mechanical
71 | transformation or translation of a Source form, including but
72 | not limited to compiled object code, generated documentation,
73 | and conversions to other media types.
74 |
75 | "Work" shall mean the work of authorship, whether in Source or
76 | Object form, made available under the License, as indicated by a
77 | copyright notice that is included in or attached to the work
78 | (an example is provided in the Appendix below).
79 |
80 | "Derivative Works" shall mean any work, whether in Source or Object
81 | form, that is based on (or derived from) the Work and for which the
82 | editorial revisions, annotations, elaborations, or other modifications
83 | represent, as a whole, an original work of authorship. For the purposes
84 | of this License, Derivative Works shall not include works that remain
85 | separable from, or merely link (or bind by name) to the interfaces of,
86 | the Work and Derivative Works thereof.
87 |
88 | "Contribution" shall mean any work of authorship, including
89 | the original version of the Work and any modifications or additions
90 | to that Work or Derivative Works thereof, that is intentionally
91 | submitted to Licensor for inclusion in the Work by the copyright owner
92 | or by an individual or Legal Entity authorized to submit on behalf of
93 | the copyright owner. For the purposes of this definition, "submitted"
94 | means any form of electronic, verbal, or written communication sent
95 | to the Licensor or its representatives, including but not limited to
96 | communication on electronic mailing lists, source code control systems,
97 | and issue tracking systems that are managed by, or on behalf of, the
98 | Licensor for the purpose of discussing and improving the Work, but
99 | excluding communication that is conspicuously marked or otherwise
100 | designated in writing by the copyright owner as "Not a Contribution."
101 |
102 | "Contributor" shall mean Licensor and any individual or Legal Entity
103 | on behalf of whom a Contribution has been received by Licensor and
104 | subsequently incorporated within the Work.
105 |
106 | 2. Grant of Copyright License. Subject to the terms and conditions of
107 | this License, each Contributor hereby grants to You a perpetual,
108 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
109 | copyright license to reproduce, prepare Derivative Works of,
110 | publicly display, publicly perform, sublicense, and distribute the
111 | Work and such Derivative Works in Source or Object form.
112 |
113 | 3. Grant of Patent License. Subject to the terms and conditions of
114 | this License, each Contributor hereby grants to You a perpetual,
115 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
116 | (except as stated in this section) patent license to make, have made,
117 | use, offer to sell, sell, import, and otherwise transfer the Work,
118 | where such license applies only to those patent claims licensable
119 | by such Contributor that are necessarily infringed by their
120 | Contribution(s) alone or by combination of their Contribution(s)
121 | with the Work to which such Contribution(s) was submitted. If You
122 | institute patent litigation against any entity (including a
123 | cross-claim or counterclaim in a lawsuit) alleging that the Work
124 | or a Contribution incorporated within the Work constitutes direct
125 | or contributory patent infringement, then any patent licenses
126 | granted to You under this License for that Work shall terminate
127 | as of the date such litigation is filed.
128 |
129 | 4. Redistribution. You may reproduce and distribute copies of the
130 | Work or Derivative Works thereof in any medium, with or without
131 | modifications, and in Source or Object form, provided that You
132 | meet the following conditions:
133 |
134 | (a) You must give any other recipients of the Work or
135 | Derivative Works a copy of this License; and
136 |
137 | (b) You must cause any modified files to carry prominent notices
138 | stating that You changed the files; and
139 |
140 | (c) You must retain, in the Source form of any Derivative Works
141 | that You distribute, all copyright, patent, trademark, and
142 | attribution notices from the Source form of the Work,
143 | excluding those notices that do not pertain to any part of
144 | the Derivative Works; and
145 |
146 | (d) If the Work includes a "NOTICE" text file as part of its
147 | distribution, then any Derivative Works that You distribute must
148 | include a readable copy of the attribution notices contained
149 | within such NOTICE file, excluding those notices that do not
150 | pertain to any part of the Derivative Works, in at least one
151 | of the following places: within a NOTICE text file distributed
152 | as part of the Derivative Works; within the Source form or
153 | documentation, if provided along with the Derivative Works; or,
154 | within a display generated by the Derivative Works, if and
155 | wherever such third-party notices normally appear. The contents
156 | of the NOTICE file are for informational purposes only and
157 | do not modify the License. You may add Your own attribution
158 | notices within Derivative Works that You distribute, alongside
159 | or as an addendum to the NOTICE text from the Work, provided
160 | that such additional attribution notices cannot be construed
161 | as modifying the License.
162 |
163 | You may add Your own copyright statement to Your modifications and
164 | may provide additional or different license terms and conditions
165 | for use, reproduction, or distribution of Your modifications, or
166 | for any such Derivative Works as a whole, provided Your use,
167 | reproduction, and distribution of the Work otherwise complies with
168 | the conditions stated in this License.
169 |
170 | 5. Submission of Contributions. Unless You explicitly state otherwise,
171 | any Contribution intentionally submitted for inclusion in the Work
172 | by You to the Licensor shall be under the terms and conditions of
173 | this License, without any additional terms or conditions.
174 | Notwithstanding the above, nothing herein shall supersede or modify
175 | the terms of any separate license agreement you may have executed
176 | with Licensor regarding such Contributions.
177 |
178 | 6. Trademarks. This License does not grant permission to use the trade
179 | names, trademarks, service marks, or product names of the Licensor,
180 | except as required for reasonable and customary use in describing the
181 | origin of the Work and reproducing the content of the NOTICE file.
182 |
183 | 7. Disclaimer of Warranty. Unless required by applicable law or
184 | agreed to in writing, Licensor provides the Work (and each
185 | Contributor provides its Contributions) on an "AS IS" BASIS,
186 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
187 | implied, including, without limitation, any warranties or conditions
188 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
189 | PARTICULAR PURPOSE. You are solely responsible for determining the
190 | appropriateness of using or redistributing the Work and assume any
191 | risks associated with Your exercise of permissions under this License.
192 |
193 | 8. Limitation of Liability. In no event and under no legal theory,
194 | whether in tort (including negligence), contract, or otherwise,
195 | unless required by applicable law (such as deliberate and grossly
196 | negligent acts) or agreed to in writing, shall any Contributor be
197 | liable to You for damages, including any direct, indirect, special,
198 | incidental, or consequential damages of any character arising as a
199 | result of this License or out of the use or inability to use the
200 | Work (including but not limited to damages for loss of goodwill,
201 | work stoppage, computer failure or malfunction, or any and all
202 | other commercial damages or losses), even if such Contributor
203 | has been advised of the possibility of such damages.
204 |
205 | 9. Accepting Warranty or Additional Liability. While redistributing
206 | the Work or Derivative Works thereof, You may choose to offer,
207 | and charge a fee for, acceptance of support, warranty, indemnity,
208 | or other liability obligations and/or rights consistent with this
209 | License. However, in accepting such obligations, You may act only
210 | on Your own behalf and on Your sole responsibility, not on behalf
211 | of any other Contributor, and only if You agree to indemnify,
212 | defend, and hold each Contributor harmless for any liability
213 | incurred by, or claims asserted against, such Contributor by reason
214 | of your accepting any such warranty or additional liability.
215 |
216 | END OF TERMS AND CONDITIONS
217 |
--------------------------------------------------------------------------------
/test/pettingzoo_api_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from pettingzoo.butterfly import (
4 | cooperative_pong_v5,
5 | knights_archers_zombies_v10,
6 | pistonball_v6,
7 | )
8 | from pettingzoo.classic import connect_four_v3
9 | from pettingzoo.mpe import simple_push_v3, simple_spread_v3, simple_world_comm_v3
10 | from pettingzoo.sisl import pursuit_v4
11 | from pettingzoo.test import api_test, parallel_api_test, seed_test
12 | from pettingzoo.utils.all_modules import (
13 | atari_environments,
14 | butterfly_environments,
15 | classic_environments,
16 | mpe_environments,
17 | sisl_environments,
18 | )
19 |
20 | import supersuit
21 | from supersuit import (
22 | dtype_v0,
23 | frame_skip_v0,
24 | frame_stack_v2,
25 | pad_action_space_v0,
26 | sticky_actions_v0,
27 | )
28 | from supersuit.utils.convert_box import convert_box
29 |
30 |
# Per-family lists of the environment entries exposed by PettingZoo's
# ``all_modules`` registries; the tests below concatenate them to build
# their ``env_fn`` parametrizations.
atari = list(atari_environments.values())
butterfly = list(butterfly_environments.values())
classic = list(classic_environments.values())
mpe = list(mpe_environments.values())
sisl = list(sisl_environments.values())
# NOTE(review): ``all`` shadows the builtin of the same name and is not
# referenced anywhere else in this file -- consider renaming or removing it.
all = atari + butterfly + classic + mpe + sisl

# NOTE(review): neither of the two lists below is referenced by any test in
# this file.
BUTTERFLY_MPE_CLASSIC = [
    knights_archers_zombies_v10,
    simple_push_v3,
    connect_four_v3,
    simple_spread_v3,
]
BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3, simple_spread_v3]
45 |
46 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_frame_stack(env_fn):
    """Wrap an AEC env in ``frame_stack_v2`` and run ``api_test`` on it."""
    api_test(frame_stack_v2(env_fn.env()))
52 |
53 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_frame_stack_parallel(env_fn):
    """Wrap a parallel env in ``frame_stack_v2`` and run ``parallel_api_test``."""
    parallel_api_test(frame_stack_v2(env_fn.parallel_env()))
59 |
60 |
@pytest.mark.parametrize("env_fn", [simple_push_v3])
def test_frame_skip(env_fn):
    """Check the underlying step counter of an AEC env under ``frame_skip_v0``.

    Before each of the first 25 agent turns, the unwrapped env's ``steps``
    attribute must equal three times the number of completed pairs of turns.
    """
    skipped_env = frame_skip_v0(env_fn.raw_env(max_cycles=100), 3)
    skipped_env.reset()
    for turn, _ in enumerate(skipped_env.agent_iter(25)):
        assert skipped_env.unwrapped.steps == (turn // 2) * 3
        acting_agent = skipped_env.agent_selection
        skipped_env.step(skipped_env.action_space(acting_agent).sample())
72 |
73 |
@pytest.mark.parametrize("env_fn", [simple_push_v3])
def test_frame_skip_parallel(env_fn):
    """Check the underlying step counter of a parallel env under ``frame_skip_v0``.

    While agents remain, the unwrapped env's ``steps`` attribute must equal
    three times half the running total of ``num_agents`` accumulated after
    each parallel step.
    """
    skipped_env = frame_skip_v0(env_fn.parallel_env(max_cycles=100), 3)
    skipped_env.reset()
    acted = 0
    while skipped_env.agents:
        assert skipped_env.unwrapped.steps == (acted // 2) * 3
        joint_action = {}
        for agent in skipped_env.agents:
            joint_action[agent] = skipped_env.action_space(agent).sample()
        skipped_env.step(joint_action)
        acted += skipped_env.num_agents
85 |
86 |
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_pad_action_space(env_fn):
    """Wrap an AEC env in ``pad_action_space_v0`` and run ``api_test`` on it."""
    wrapped_env = pad_action_space_v0(env_fn.env())
    api_test(wrapped_env)


def test_sticky_actions_seed():
    """Run ``seed_test`` on ``sticky_actions_v0`` over ``simple_world_comm_v3``.

    This check does not use any parametrized environment, so it is kept as a
    standalone test rather than being repeated inside
    ``test_pad_action_space`` once per ``env_fn``.
    """
    seed_test(lambda: sticky_actions_v0(simple_world_comm_v3.env(), 0.5), 100)
93 |
94 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_action_space_parallel(env_fn):
    """Wrap a parallel env in ``pad_action_space_v0`` and run ``parallel_api_test``."""
    parallel_api_test(pad_action_space_v0(env_fn.parallel_env()))
100 |
101 |
@pytest.mark.parametrize(
    "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction(env_fn):
    """Apply ``color_reduction_v0`` in ``"R"`` mode to an AEC env and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.color_reduction_v0(base_env, "R"))
108 |
109 |
@pytest.mark.parametrize(
    "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4]
)
def test_color_reduction_parallel(env_fn):
    """Apply ``color_reduction_v0`` in ``"R"`` mode to a parallel env and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.color_reduction_v0(base_env, "R"))
116 |
117 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize(
    "wrapper_kwargs",
    [dict(x_size=5, y_size=10), dict(x_size=5, y_size=10, linear_interp=True)],
)
def test_resize_dtype(env_fn, wrapper_kwargs):
    """Stack ``dtype_v0`` (uint8) and ``resize_v1`` on an AEC env and run ``api_test``.

    The env is created with ``vector_state=False``; ``wrapper_kwargs`` is
    forwarded to ``resize_v1`` unchanged.
    """
    uint8_env = dtype_v0(env_fn.env(vector_state=False), np.uint8)
    resized_env = supersuit.resize_v1(uint8_env, **wrapper_kwargs)
    api_test(resized_env)
128 |
129 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize(
    "wrapper_kwargs",
    [dict(x_size=5, y_size=10), dict(x_size=5, y_size=10, linear_interp=True)],
)
def test_resize_dtype_parallel(env_fn, wrapper_kwargs):
    """Stack ``dtype_v0`` (uint8) and ``resize_v1`` on a parallel env and run ``parallel_api_test``.

    The env is created with ``vector_state=False``; ``wrapper_kwargs`` is
    forwarded to ``resize_v1`` unchanged.
    """
    uint8_env = dtype_v0(env_fn.parallel_env(vector_state=False), np.uint8)
    resized_env = supersuit.resize_v1(uint8_env, **wrapper_kwargs)
    parallel_api_test(resized_env)
140 |
141 |
@pytest.mark.parametrize(
    "env_fn",
    atari
    + butterfly
    + [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"],
)
def test_dtype(env_fn):
    """Wrap an AEC env in ``dtype_v0`` with ``np.int32`` and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.dtype_v0(base_env, np.int32))
151 |
152 |
@pytest.mark.parametrize("env_fn", atari + butterfly + sisl)
def test_dtype_parallel(env_fn):
    """Wrap a parallel env in ``dtype_v0`` with ``np.int32`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.dtype_v0(base_env, np.int32))
157 |
158 |
# Uses the same environment families as test_flatten_parallel.
# NOTE(review): classic environments are left out; presumably their
# dict-style (action-mask) observations are not handled by flatten_v0 --
# confirm before adding them back.
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_flatten(env_fn):
    """Wrap the parametrized AEC env in ``flatten_v0`` and run ``api_test``.

    The environment is built from ``env_fn`` so that every parametrized entry
    is actually exercised, instead of re-testing one fixed environment.
    """
    env = supersuit.flatten_v0(env_fn.env())
    api_test(env)
163 |
164 |
165 | # Classic environments don't have parallel envs so this doesn't apply
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_flatten_parallel(env_fn):
    """Wrap a parallel env in ``flatten_v0`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.flatten_v0(base_env))
170 |
171 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_reshape(env_fn):
    """Reshape AEC observations to ``(512 * 512, 3)`` via ``reshape_v0`` and run ``api_test``."""
    target_shape = (512 * 512, 3)
    image_env = env_fn.env(vector_state=False)
    api_test(supersuit.reshape_v0(image_env, target_shape))
176 |
177 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_reshape_parallel(env_fn):
    """Reshape parallel observations to ``(512 * 512, 3)`` via ``reshape_v0`` and run ``parallel_api_test``."""
    target_shape = (512 * 512, 3)
    image_env = env_fn.parallel_env(vector_state=False)
    parallel_api_test(supersuit.reshape_v0(image_env, target_shape))
182 |
183 |
184 | # MPE environment has infinite bounds for observation space (only environments with finite bounds can be passed to normalize_obs)
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_normalize_obs(env_fn):
    """Stack ``dtype_v0`` (float32) and ``normalize_obs_v0`` on an AEC env and run ``api_test``.

    ``normalize_obs_v0`` is called with ``env_min=-1`` and ``env_max=5.0``.
    """
    float_env = dtype_v0(env_fn.env(), np.float32)
    normalized_env = supersuit.normalize_obs_v0(float_env, env_min=-1, env_max=5.0)
    api_test(normalized_env)
191 |
192 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_normalize_obs_parallel(env_fn):
    """Stack ``dtype_v0`` (float32) and ``normalize_obs_v0`` on a parallel env and run ``parallel_api_test``.

    ``normalize_obs_v0`` is called with ``env_min=-1`` and ``env_max=5.0``.
    """
    float_env = dtype_v0(env_fn.parallel_env(), np.float32)
    normalized_env = supersuit.normalize_obs_v0(float_env, env_min=-1, env_max=5.0)
    parallel_api_test(normalized_env)
199 |
200 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_observations(env_fn):
    """Wrap an AEC env in ``pad_observations_v0`` and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.pad_observations_v0(base_env))
205 |
206 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_pad_observations_parallel(env_fn):
    """Wrap a parallel env in ``pad_observations_v0`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.pad_observations_v0(base_env))
211 |
212 |
@pytest.mark.skip(
    reason="Black death wrapper is only designed for parallel envs, AEC envs should simply skip the agent by setting env.agent_selection manually"
)
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_black_death(env_fn):
    """Wrap an AEC env in ``black_death_v3`` and run ``api_test`` (currently skipped)."""
    base_env = env_fn.env()
    api_test(supersuit.black_death_v3(base_env))
220 |
221 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_black_death_parallel(env_fn):
    """Wrap a parallel env in ``black_death_v3`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.black_death_v3(base_env))
226 |
227 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize("env_kwargs", [dict(type_only=True), dict(type_only=False)])
def test_agent_indicator(env_fn, env_kwargs):
    """Wrap an AEC env in ``agent_indicator_v0`` and run ``api_test``.

    ``env_kwargs`` (the ``type_only`` flag) is forwarded to the wrapper.
    """
    base_env = env_fn.env()
    api_test(supersuit.agent_indicator_v0(base_env, **env_kwargs))
233 |
234 |
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
@pytest.mark.parametrize("env_kwargs", [dict(type_only=True), dict(type_only=False)])
def test_agent_indicator_parallel(env_fn, env_kwargs):
    """Wrap a parallel env in ``agent_indicator_v0`` and run ``parallel_api_test``.

    ``env_kwargs`` (the ``type_only`` flag) is forwarded to the wrapper.
    """
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.agent_indicator_v0(base_env, **env_kwargs))
240 |
241 |
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_reward_lambda(env_fn):
    """Wrap an AEC env in ``reward_lambda_v0`` (reward divided by 10) and run ``api_test``."""

    def tenth(x):
        return x / 10

    api_test(supersuit.reward_lambda_v0(env_fn.env(), tenth))
246 |
247 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_reward_lambda_parallel(env_fn):
    """Wrap a parallel env in ``reward_lambda_v0`` (reward divided by 10) and run ``parallel_api_test``."""

    def tenth(x):
        return x / 10

    parallel_api_test(supersuit.reward_lambda_v0(env_fn.parallel_env(), tenth))
252 |
253 |
@pytest.mark.parametrize(
    "env_fn",
    [v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"]
    + mpe
    + sisl,
)
def test_observation_lambda(env_fn):
    """Wrap an AEC env in ``observation_lambda_v0`` (observation minus 1) and run ``api_test``."""

    def minus_one(obs, obs_space):
        return obs - 1

    api_test(supersuit.observation_lambda_v0(env_fn.env(), minus_one))
263 |
264 |
265 | # Example using observation lambda with an action masked environment (flattening the obs space while keeping action mask)
@pytest.mark.parametrize("env_fn", [connect_four_v3])
def test_observation_lambda_action_mask(env_fn):
    """Use ``observation_lambda_v0`` on an env whose observations carry an action mask.

    The ``"observation"`` entry is reshaped and the ``"action_mask"`` entry is
    inverted; the test then checks the first agent's observation, its mask and
    its observation space before running ``api_test``.
    """
    # Probe an unwrapped env to derive the target shape.
    probe_env = env_fn.env()
    probe_env.reset()
    raw_obs = probe_env.observe(probe_env.possible_agents[0])

    # Collapse the first two dimensions: (6, 7, 2) -> (42, 2)
    newshape = raw_obs["observation"].reshape((-1, 2)).shape

    def _reshape(obs):
        return obs.reshape(newshape)

    def change_obs_space_fn(obs_space):
        # Only the "observation" sub-space is replaced; the rest is kept.
        reshaped_box = convert_box(_reshape, old_box=obs_space["observation"])
        obs_space["observation"] = reshaped_box
        return obs_space

    def change_observation_fn(observation, old_obs_space):
        observation["observation"] = observation["observation"].reshape(newshape)
        # Flip the mask so legal actions become illegal and vice versa.
        observation["action_mask"] = 1 - observation["action_mask"]
        return observation

    env = supersuit.observation_lambda_v0(
        env_fn.env(),
        change_obs_space_fn=change_obs_space_fn,
        change_observation_fn=change_observation_fn,
    )

    env.reset()
    obs = env.observe(env.possible_agents[0])
    expected_shape = (42, 2)

    observed_shape = obs["observation"].shape
    assert observed_shape == expected_shape, "New observation should be shape (42, 2)"

    mask_is_zero = np.array_equal(obs["action_mask"], np.zeros(7))
    assert mask_is_zero, "New action mask should be all zeros"

    space_shape = env.observation_space(env.possible_agents[0])["observation"].shape
    assert space_shape == expected_shape, "Observation space should be (42, 2)"

    api_test(env)
309 |
310 |
# Uses the same environment list and observation function as
# test_observation_lambda, but on parallel envs.
@pytest.mark.parametrize(
    "env_fn",
    [v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"]
    + mpe
    + sisl,
)
def test_observation_lambda_parallel(env_fn):
    """Wrap a parallel env in ``observation_lambda_v0`` and run ``parallel_api_test``.

    The observation function subtracts 1 from every observation. Previously
    this test had an empty body and passed without checking anything.
    """
    env = supersuit.observation_lambda_v0(
        env_fn.parallel_env(), lambda obs, obs_space: obs - 1
    )
    parallel_api_test(env)
314 |
315 |
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_clip_reward(env_fn):
    """Wrap an AEC env in ``clip_reward_v0`` and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.clip_reward_v0(base_env))
320 |
321 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_clip_reward_parallel(env_fn):
    """Wrap a parallel env in ``clip_reward_v0`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.clip_reward_v0(base_env))
326 |
327 |
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_nan_noop(env_fn):
    """Wrap an AEC env in ``nan_noop_v0`` (second argument 0) and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.nan_noop_v0(base_env, 0))
332 |
333 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_noop_parallel(env_fn):
    """Wrap a parallel env in ``nan_noop_v0`` (second argument 0) and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.nan_noop_v0(base_env, 0))
338 |
339 |
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl)
def test_nan_zeros(env_fn):
    """Wrap an AEC env in ``nan_zeros_v0`` and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.nan_zeros_v0(base_env))
344 |
345 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_zeros_parallel(env_fn):
    """Wrap a parallel env in ``nan_zeros_v0`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.nan_zeros_v0(base_env))
350 |
351 |
352 | # Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_random(env_fn):
    """Wrap an AEC env in ``nan_random_v0`` and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.nan_random_v0(base_env))
357 |
358 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_nan_random_parallel(env_fn):
    """Wrap a parallel env in ``nan_random_v0`` and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.nan_random_v0(base_env))
363 |
364 |
365 | # Note: hanabi v5 fails here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_sticky_actions(env_fn):
    """Wrap an AEC env in ``sticky_actions_v0`` (second argument 0.75) and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.sticky_actions_v0(base_env, 0.75))
370 |
371 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_sticky_actions_parallel(env_fn):
    """Wrap a parallel env in ``sticky_actions_v0`` (second argument 0.75) and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.sticky_actions_v0(base_env, 0.75))
376 |
377 |
378 | # Note: hanabi_v5 and texas_holdem_v4 fail here
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_delay_observations(env_fn):
    """Wrap an AEC env in ``delay_observations_v0`` (second argument 3) and run ``api_test``."""
    base_env = env_fn.env()
    api_test(supersuit.delay_observations_v0(base_env, 3))
383 |
384 |
@pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl)
def test_delay_observations_parallel(env_fn):
    """Wrap a parallel env in ``delay_observations_v0`` (second argument 3) and run ``parallel_api_test``."""
    base_env = env_fn.parallel_env()
    parallel_api_test(supersuit.delay_observations_v0(base_env, 3))
389 |
390 |
# NOTE(review): only knights_archers_zombies_v10 is parametrized here; other
# environment families could be added once max_observation_v0 is confirmed
# to support them.
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_max_observation(env_fn):
    """Wrap an AEC env in ``max_observation_v0`` (second argument 3) and run ``api_test``.

    The environment is built from ``env_fn`` so the parametrization matches
    what is tested, rather than rebuilding the same fixed environment once
    per entry of a larger, ignored parameter list.
    """
    env = supersuit.max_observation_v0(env_fn.env(), 3)
    api_test(env)
395 |
396 |
# NOTE(review): only knights_archers_zombies_v10 is parametrized here; other
# environment families could be added once max_observation_v0 is confirmed
# to support them.
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10])
def test_max_observation_parallel(env_fn):
    """Wrap a parallel env in ``max_observation_v0`` (second argument 3) and run ``parallel_api_test``.

    The environment is built from ``env_fn`` so the parametrization matches
    what is tested, rather than rebuilding the same fixed environment once
    per entry of a larger, ignored parameter list.
    """
    env = supersuit.max_observation_v0(env_fn.parallel_env(), 3)
    parallel_api_test(env)
401 |
--------------------------------------------------------------------------------