├── test ├── __init__.py ├── test_vector │ ├── __init__.py │ ├── test_pettingzoo_to_vec_gymnasium_wrappers.py │ ├── test_render.py │ ├── test_env_is_wrapped.py │ ├── test_aec_vector_identity_env.py │ ├── test_aec_vector_values.py │ ├── test_vector_dict.py │ ├── test_pettingzoo_to_vec.py │ └── test_gym_vector.py ├── test_utils │ ├── test_basic_transforms │ │ ├── __init__.py │ │ ├── test_flatten.py │ │ ├── test_dtype.py │ │ ├── test_color_reduction.py │ │ ├── test_reshape.py │ │ ├── test_normalized_obs.py │ │ └── test_resize.py │ ├── test_action_transforms │ │ └── test_homogenize.py │ ├── test_frame_stack.py │ └── test_agent_indicator.py ├── test_autodep.py ├── pettingzoo_env_test.py ├── dummy_gym_env.py ├── vec_env_test.py ├── dummy_aec_env.py ├── parallel_env_test.py ├── gym_unwrapped_test.py ├── gym_mock_test.py ├── aec_unwrapped_test.py ├── generated_agents_test.py └── pettingzoo_api_test.py ├── supersuit ├── utils │ ├── __init__.py │ ├── action_transforms │ │ ├── __init__.py │ │ └── homogenize_ops.py │ ├── convert_box.py │ ├── make_defaultdict.py │ ├── basic_transforms │ │ ├── dtype.py │ │ ├── __init__.py │ │ ├── flatten.py │ │ ├── reshape.py │ │ ├── resize.py │ │ ├── color_reduction.py │ │ └── normalize_obs.py │ ├── frame_skip.py │ ├── accumulator.py │ ├── obs_delay.py │ ├── base_aec_wrapper.py │ ├── wrapper_chooser.py │ ├── frame_stack.py │ └── agent_indicator.py ├── vector │ ├── utils │ │ ├── __init__.py │ │ ├── space_wrapper.py │ │ └── shared_array.py │ ├── __init__.py │ ├── sb_vector_wrapper.py │ ├── constructors.py │ ├── single_vec_env.py │ ├── sb3_vector_wrapper.py │ ├── vector_constructors.py │ ├── concat_vec_env.py │ ├── markov_vector_wrapper.py │ └── multiproc_vec.py ├── generic_wrappers │ ├── utils │ │ ├── __init__.py │ │ ├── base_modifier.py │ │ └── shared_wrapper_util.py │ ├── delay_observations.py │ ├── __init__.py │ ├── max_observation.py │ ├── sticky_actions.py │ ├── basic_wrappers.py │ ├── nan_wrappers.py │ ├── frame_stack.py │ └── frame_skip.py 
├── aec_vector │ ├── __init__.py │ ├── create.py │ ├── base_aec_vec_env.py │ └── vector_env.py ├── lambda_wrappers │ ├── __init__.py │ ├── reward_lambda.py │ ├── action_lambda.py │ └── observation_lambda.py ├── multiagent_wrappers │ ├── __init__.py │ ├── agent_indication.py │ ├── padding_wrappers.py │ └── black_death.py └── __init__.py ├── .github ├── FUNDING.yml └── workflows │ ├── pre-commit.yml │ ├── linux-test.yml │ └── build-publish.yml ├── .gitignore ├── supersuit-text.png ├── setup.py ├── .pre-commit-config.yaml ├── README.md ├── pyproject.toml └── LICENSE /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /supersuit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/test_vector/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /supersuit/vector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /supersuit/utils/action_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: Farama-Foundation 2 | -------------------------------------------------------------------------------- 
/test/test_utils/test_basic_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | pettingzoo 3 | dist/* 4 | build/* 5 | SuperSuit.egg-info/* 6 | venv/ 7 | -------------------------------------------------------------------------------- /supersuit-text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Farama-Foundation/SuperSuit/HEAD/supersuit-text.png -------------------------------------------------------------------------------- /supersuit/aec_vector/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_aec_vec_env import VectorAECEnv # NOQA 2 | from .create import vectorize_aec_env_v0 # NOQA 3 | -------------------------------------------------------------------------------- /supersuit/lambda_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .action_lambda import action_lambda_v1 # NOQA 2 | from .observation_lambda import observation_lambda_v0 # NOQA 3 | from .reward_lambda import reward_lambda_v0 # NOQA 4 | -------------------------------------------------------------------------------- /supersuit/multiagent_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_indication import agent_indicator_v0 # NOQA 2 | from .black_death import black_death_v3 # NOQA 3 | from .padding_wrappers import pad_action_space_v0, pad_observations_v0 # NOQA 4 | -------------------------------------------------------------------------------- /supersuit/utils/convert_box.py: -------------------------------------------------------------------------------- 1 | from gymnasium.spaces import Box 2 
| 3 | 4 | def convert_box(convert_obs_fn, old_box): 5 | new_low = convert_obs_fn(old_box.low) 6 | new_high = convert_obs_fn(old_box.high) 7 | return Box(low=new_low, high=new_high, dtype=new_low.dtype) 8 | -------------------------------------------------------------------------------- /supersuit/vector/__init__.py: -------------------------------------------------------------------------------- 1 | from .concat_vec_env import ConcatVecEnv # NOQA 2 | from .constructors import MakeCPUAsyncConstructor # NOQA 3 | from .markov_vector_wrapper import MarkovVectorEnv # NOQA 4 | from .multiproc_vec import ProcConcatVec # NOQA 5 | from .single_vec_env import SingleVecEnv # NOQA 6 | -------------------------------------------------------------------------------- /test/test_autodep.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import supersuit 4 | 5 | 6 | def test_bad_import(): 7 | with pytest.raises(supersuit.DeprecatedWrapper): 8 | from supersuit import action_lambda_v0 # noqa: F401 9 | with pytest.raises(supersuit.DeprecatedWrapper): 10 | supersuit.action_lambda_v0 11 | -------------------------------------------------------------------------------- /supersuit/utils/make_defaultdict.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import defaultdict 3 | 4 | 5 | def make_defaultdict(d): 6 | try: 7 | dd = defaultdict(type(next(iter(d.values())))) 8 | for k, v in d.items(): 9 | dd[k] = v 10 | return dd 11 | except StopIteration: 12 | warnings.warn("No agents left in the environment!") 13 | return {} 14 | -------------------------------------------------------------------------------- /test/pettingzoo_env_test.py: -------------------------------------------------------------------------------- 1 | from pettingzoo.mpe import simple_spread_v3 2 | from pettingzoo.test import parallel_api_test 3 | 4 | from supersuit.multiagent_wrappers 
import pad_action_space_v0 5 | 6 | 7 | def test_pad_actuon_space(): 8 | env = simple_spread_v3.parallel_env(max_cycles=25, continuous_actions=True) 9 | env = pad_action_space_v0(env) 10 | 11 | parallel_api_test(env, num_cycles=100) 12 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . import convert_box 4 | 5 | 6 | def check_param(obs_space, new_dtype): 7 | np.dtype(new_dtype) # type argument must be convertible to a numpy dtype 8 | 9 | 10 | def change_obs_space(obs_space, param): 11 | return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space) 12 | 13 | 14 | def change_observation(obs, obs_space, new_dtype): 15 | obs = obs.astype(new_dtype) 16 | return obs 17 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium.spaces import Box 2 | 3 | 4 | def convert_box(convert_obs_fn, old_box): 5 | new_low = convert_obs_fn(old_box.low) 6 | new_high = convert_obs_fn(old_box.high) 7 | return Box(low=new_low, high=new_high, dtype=new_low.dtype) 8 | 9 | 10 | from . import color_reduction # NOQA 11 | from . import dtype # NOQA 12 | from . import flatten # NOQA 13 | from . import normalize_obs # NOQA 14 | from . import reshape # NOQA 15 | from . 
import resize # NOQA 16 | -------------------------------------------------------------------------------- /test/dummy_gym_env.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | 3 | 4 | class DummyEnv(gymnasium.Env): 5 | def __init__(self, observation, observation_space, action_space): 6 | super().__init__() 7 | self._observation = observation 8 | self.observation_space = observation_space 9 | self.action_space = action_space 10 | 11 | def step(self, action): 12 | return self._observation, 1, False, False, {} 13 | 14 | def reset(self, seed=None, options=None): 15 | return self._observation, {} 16 | -------------------------------------------------------------------------------- /supersuit/vector/utils/space_wrapper.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | 4 | 5 | class SpaceWrapper: 6 | def __init__(self, space): 7 | if isinstance(space, gymnasium.spaces.Discrete): 8 | self.shape = () 9 | self.dtype = np.dtype(np.int64) 10 | elif isinstance(space, gymnasium.spaces.Box): 11 | self.shape = space.shape 12 | self.dtype = np.dtype(space.dtype) 13 | else: 14 | assert False, "ProcVectorEnv only support Box and Discrete types" 15 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/flatten.py: -------------------------------------------------------------------------------- 1 | from . import convert_box 2 | 3 | 4 | def check_param(obs_space, should_flatten): 5 | assert isinstance( 6 | should_flatten, bool 7 | ), f"should_flatten must be bool. 
It is {should_flatten}" 8 | 9 | 10 | def change_obs_space(obs_space, param): 11 | return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space) 12 | 13 | 14 | def change_observation(obs, obs_space, should_flatten): 15 | if should_flatten: 16 | obs = obs.flatten() 17 | return obs 18 | -------------------------------------------------------------------------------- /supersuit/utils/frame_skip.py: -------------------------------------------------------------------------------- 1 | def check_transform_frameskip(frame_skip): 2 | if ( 3 | isinstance(frame_skip, tuple) 4 | and len(frame_skip) == 2 5 | and isinstance(frame_skip[0], int) 6 | and isinstance(frame_skip[1], int) 7 | and 1 <= frame_skip[0] <= frame_skip[1] 8 | ): 9 | return frame_skip 10 | elif isinstance(frame_skip, int): 11 | return (frame_skip, frame_skip) 12 | else: 13 | assert ( 14 | False 15 | ), "frame_skip must be an int or a tuple of two ints, where the first values is at least one and not greater than the second" 16 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/delay_observations.py: -------------------------------------------------------------------------------- 1 | from supersuit.utils.obs_delay import Delayer 2 | 3 | from .utils.base_modifier import BaseModifier 4 | from .utils.shared_wrapper_util import shared_wrapper 5 | 6 | 7 | def delay_observations_v0(env, delay): 8 | class DelayObsModifier(BaseModifier): 9 | def reset(self, seed=None, options=None): 10 | self.delayer = Delayer(self.observation_space, delay) 11 | 12 | def modify_obs(self, obs): 13 | obs = self.delayer.add(obs) 14 | return BaseModifier.modify_obs(self, obs) 15 | 16 | return shared_wrapper(env, DelayObsModifier) 17 | -------------------------------------------------------------------------------- /test/vec_env_test.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 
3 | 4 | from supersuit import gym_vec_env_v0 5 | 6 | 7 | def test_vec_env_args(): 8 | env = gymnasium.make("Acrobot-v1") 9 | num_envs = 8 10 | vec_env = gym_vec_env_v0(env, num_envs) 11 | vec_env.reset() 12 | obs, rew, terminations, truncations, infos = vec_env.step( 13 | [0] + [1] * (vec_env.num_envs - 1) 14 | ) 15 | assert not np.any(np.equal(obs[0], obs[1])) 16 | 17 | 18 | def test_all_vec_env_fns(): 19 | num_envs = 8 20 | env = gymnasium.make("Acrobot-v1") 21 | gym_vec_env_v0(env, num_envs, False) 22 | gym_vec_env_v0(env, num_envs, True) 23 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/utils/base_modifier.py: -------------------------------------------------------------------------------- 1 | class BaseModifier: 2 | def __init__(self): 3 | pass 4 | 5 | def reset(self, seed=None, options=None): 6 | pass 7 | 8 | def modify_obs(self, obs): 9 | self.cur_obs = obs 10 | return obs 11 | 12 | def get_last_obs(self): 13 | return self.cur_obs 14 | 15 | def modify_obs_space(self, obs_space): 16 | self.observation_space = obs_space 17 | return obs_space 18 | 19 | def modify_action(self, act): 20 | return act 21 | 22 | def modify_action_space(self, act_space): 23 | self.action_space = act_space 24 | return act_space 25 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_wrappers import clip_reward_v0 # NOQA 2 | from .basic_wrappers import ( 3 | clip_actions_v0, 4 | color_reduction_v0, 5 | dtype_v0, 6 | flatten_v0, 7 | normalize_obs_v0, 8 | reshape_v0, 9 | resize_v1, 10 | scale_actions_v0, 11 | ) 12 | from .delay_observations import delay_observations_v0 # NOQA 13 | from .frame_skip import frame_skip_v0 # NOQA 14 | from .frame_stack import frame_stack_v1, frame_stack_v2 # NOQA 15 | from .max_observation import max_observation_v0 # NOQA 16 | from 
.nan_wrappers import nan_noop_v0, nan_random_v0, nan_zeros_v0 # NOQA 17 | from .sticky_actions import sticky_actions_v0 # NOQA 18 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. 3 | name: pre-commit 4 | on: 5 | pull_request: 6 | push: 7 | branches: [master] 8 | 9 | permissions: 10 | contents: read # to fetch code (actions/checkout) 11 | 12 | jobs: 13 | pre-commit: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - uses: actions/setup-python@v4 18 | - run: python -m pip install pre-commit 19 | - run: python -m pre_commit --version 20 | - run: python -m pre_commit install 21 | - run: python -m pre_commit run --all-files --show-diff-on-failure 22 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/max_observation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from supersuit.utils.accumulator import Accumulator 4 | 5 | from .utils.base_modifier import BaseModifier 6 | from .utils.shared_wrapper_util import shared_wrapper 7 | 8 | 9 | def max_observation_v0(env, memory): 10 | int(memory) # delay must be an int 11 | 12 | class MaxObsModifier(BaseModifier): 13 | def reset(self, seed=None, options=None): 14 | self.accumulator = Accumulator(self.observation_space, memory, np.maximum) 15 | 16 | def modify_obs(self, obs): 17 | self.accumulator.add(obs) 18 | return super().modify_obs(self.accumulator.get()) 19 | 20 | return shared_wrapper(env, MaxObsModifier) 21 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_flatten.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.flatten import change_observation, check_param 6 | 7 | 8 | test_obs_space = Box( 9 | low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32 10 | ) 11 | test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3) 12 | 13 | 14 | def test_param_check(): 15 | check_param(test_obs_space, True) 16 | with pytest.raises(AssertionError): 17 | check_param(test_obs_space, 6) 18 | 19 | 20 | def test_change_observation(): 21 | new_obs = change_observation(test_obs, test_obs_space, True) 22 | assert new_obs.shape == (4 * 4 * 3,) 23 | -------------------------------------------------------------------------------- /supersuit/utils/accumulator.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from functools import reduce 3 | 4 | import numpy as np 5 | 6 | 7 | class Accumulator: 8 | def __init__(self, obs_space, memory, reduction): 9 | self.memory = memory 10 | self._obs_buffer = deque() 11 | self.reduction = reduction 12 | self.maxed_val = None 13 | 14 | def add(self, in_obs): 15 | self._obs_buffer.append(np.copy(in_obs)) 16 | if len(self._obs_buffer) > self.memory: 17 | self._obs_buffer.popleft() 18 | self.maxed_val = None 19 | 20 | def get(self): 21 | if self.maxed_val is None: 22 | self.maxed_val = reduce(self.reduction, (self._obs_buffer)) 23 | return self.maxed_val 24 | -------------------------------------------------------------------------------- /supersuit/vector/utils/shared_array.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import numpy as np 4 | 5 | 6 | class SharedArray: 7 | def __init__(self, shape, dtype): 8 | self.shared_arr = mp.Array( 9 | np.ctypeslib.as_ctypes_type(dtype), 
int(np.prod(shape)), lock=False 10 | ) 11 | self.dtype = dtype 12 | self.shape = shape 13 | self._set_np_arr() 14 | 15 | def _set_np_arr(self): 16 | self.np_arr = np.frombuffer(self.shared_arr, dtype=self.dtype).reshape( 17 | self.shape 18 | ) 19 | 20 | def __getstate__(self): 21 | return (self.shared_arr, self.dtype, self.shape) 22 | 23 | def __setstate__(self, state): 24 | self.shared_arr, self.dtype, self.shape = state 25 | self._set_np_arr() 26 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.dtype import change_observation, check_param 6 | 7 | 8 | test_obs_space = Box( 9 | low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32 10 | ) 11 | test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3) 12 | 13 | 14 | def test_param_check(): 15 | check_param(test_obs_space, np.uint8) 16 | check_param(test_obs_space, np.dtype("uint8")) 17 | with pytest.raises(TypeError): 18 | check_param(test_obs_space, 6) 19 | 20 | 21 | def test_change_observation(): 22 | new_obs = change_observation(test_obs, test_obs_space, np.float32) 23 | assert new_obs.dtype == np.float32 24 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/reshape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . import convert_box 4 | 5 | 6 | def check_param(obs_space, shape): 7 | assert isinstance(shape, tuple), f"shape must be tuple. 
It is {shape}" 8 | assert all( 9 | isinstance(el, int) for el in shape 10 | ), f"shape must be tuple of ints, is: {shape}" 11 | assert np.prod(shape) == np.prod( 12 | obs_space.shape 13 | ), "new shape {} must have as many elements as original shape {}".format( 14 | shape, obs_space.shape 15 | ) 16 | 17 | 18 | def change_obs_space(obs_space, param): 19 | return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space) 20 | 21 | 22 | def change_observation(obs, obs_space, shape): 23 | obs = obs.reshape(shape) 24 | return obs 25 | -------------------------------------------------------------------------------- /supersuit/aec_vector/create.py: -------------------------------------------------------------------------------- 1 | import cloudpickle 2 | from pettingzoo import AECEnv 3 | 4 | from .async_vector_env import AsyncAECVectorEnv 5 | from .vector_env import SyncAECVectorEnv 6 | 7 | 8 | def vectorize_aec_env_v0(aec_env, num_envs, num_cpus=0): 9 | assert isinstance( 10 | aec_env, AECEnv 11 | ), "pettingzoo_env_to_vec_env takes in a pettingzoo AECEnv." 12 | assert hasattr( 13 | aec_env, "possible_agents" 14 | ), "environment passed to vectorize_aec_env must have possible_agents attribute." 
15 | 16 | def env_fn(): 17 | return cloudpickle.loads(cloudpickle.dumps(aec_env)) 18 | 19 | env_list = [env_fn] * num_envs 20 | 21 | if num_cpus == 0 or num_cpus == 1: 22 | return SyncAECVectorEnv(env_list) 23 | else: 24 | return AsyncAECVectorEnv(env_list, num_cpus) 25 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_color_reduction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.color_reduction import ( 6 | change_observation, 7 | check_param, 8 | ) 9 | 10 | 11 | test_obs_space = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3)) 12 | bad_test_obs_space = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 4)) 13 | test_obs = np.zeros([4, 4, 3]) + np.arange(3) 14 | 15 | 16 | def test_param_check(): 17 | with pytest.raises(AssertionError): 18 | check_param(test_obs_space, "bob") 19 | with pytest.raises(AssertionError): 20 | check_param(bad_test_obs_space, "R") 21 | check_param(test_obs_space, "G") 22 | 23 | 24 | def test_change_observation(): 25 | new_obs = change_observation(test_obs, test_obs_space, "B") 26 | assert np.all(np.equal(new_obs, 2 * np.ones([4, 4]))) 27 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/sticky_actions.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | 3 | from .utils.base_modifier import BaseModifier 4 | from .utils.shared_wrapper_util import shared_wrapper 5 | 6 | 7 | def sticky_actions_v0(env, repeat_action_probability): 8 | assert 0 <= repeat_action_probability < 1 9 | 10 | class StickyActionsModifier(BaseModifier): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def reset(self, seed=None, options=None): 15 | self.np_random, _ = 
gymnasium.utils.seeding.np_random(seed) 16 | self.old_action = None 17 | 18 | def modify_action(self, action): 19 | if ( 20 | self.old_action is not None 21 | and self.np_random.uniform() < repeat_action_probability 22 | ): 23 | action = self.old_action 24 | self.old_action = action 25 | return action 26 | 27 | return shared_wrapper(env, StickyActionsModifier) 28 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_reshape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.reshape import change_observation, check_param 6 | 7 | 8 | test_obs_space = Box( 9 | low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3), dtype=np.float32 10 | ) 11 | test_obs = np.zeros([4, 4, 3], dtype=np.float64) + np.arange(3) 12 | 13 | 14 | def test_param_check(): 15 | check_param(test_obs_space, (8, 6)) 16 | with pytest.raises(AssertionError): 17 | check_param(test_obs_space, (8, 7)) 18 | with pytest.raises(AssertionError): 19 | check_param(test_obs_space, "bob") 20 | with pytest.raises(AssertionError): 21 | check_param(test_obs_space, ("bob", 5)) 22 | 23 | 24 | def test_change_observation(): 25 | new_obs = change_observation(test_obs, test_obs_space, (8, 6)) 26 | assert new_obs.shape == (8, 6) 27 | new_obs = change_observation(test_obs, test_obs_space, (4 * 4 * 3,)) 28 | assert new_obs.shape == (4 * 4 * 3,) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import setup 4 | 5 | 6 | CWD = pathlib.Path(__file__).absolute().parent 7 | 8 | 9 | def get_version(): 10 | """Gets the supersuit version.""" 11 | path = CWD / "supersuit" / "__init__.py" 12 | content = path.read_text() 13 | 14 | for 
line in content.splitlines(): 15 | if line.startswith("__version__"): 16 | return line.strip().split()[-1].strip().strip('"') 17 | raise RuntimeError("bad version data in __init__.py") 18 | 19 | 20 | def get_description(): 21 | """Gets the description from the readme.""" 22 | with open("README.md") as fh: 23 | long_description = "" 24 | header_count = 0 25 | for line in fh: 26 | if line.startswith("##"): 27 | header_count += 1 28 | if header_count < 2: 29 | long_description += line 30 | else: 31 | break 32 | return long_description 33 | 34 | 35 | setup(name="supersuit", version=get_version(), long_description=get_description()) 36 | -------------------------------------------------------------------------------- /supersuit/utils/obs_delay.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | 5 | 6 | class Delayer: 7 | def __init__(self, obs_space, delay): 8 | self.delay = delay 9 | self.obs_queue = deque() 10 | self.obs_space = obs_space 11 | 12 | def add(self, in_obs): 13 | self.obs_queue.append(in_obs) 14 | if len(self.obs_queue) > self.delay: 15 | return self.obs_queue.popleft() 16 | else: 17 | if isinstance(in_obs, np.ndarray): 18 | return np.zeros_like(in_obs) 19 | elif ( 20 | isinstance(in_obs, dict) 21 | and "observation" in in_obs.keys() 22 | and "action_mask" in in_obs.keys() 23 | ): 24 | return { 25 | "observation": np.zeros_like(in_obs["observation"]), 26 | "action_mask": np.ones_like(in_obs["action_mask"]), 27 | } 28 | else: 29 | raise TypeError( 30 | "Observation must be of type np.ndarray or dictionary with keys 'observation' and 'action_mask'" 31 | ) 32 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/resize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . 
import convert_box 4 | 5 | 6 | def check_param(obs_space, resize): 7 | xsize, ysize, linear_interp = resize 8 | assert all( 9 | isinstance(ds, int) and ds > 0 for ds in [xsize, ysize] 10 | ), "resize x and y sizes must be integers greater than zero." 11 | assert isinstance( 12 | linear_interp, bool 13 | ), "resize linear_interp parameter must be bool." 14 | assert len(obs_space.shape) == 3 or len(obs_space.shape) == 2 15 | 16 | 17 | def change_obs_space(obs_space, param): 18 | return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space) 19 | 20 | 21 | def change_observation(obs, obs_space, resize): 22 | import tinyscaler 23 | 24 | xsize, ysize, linear_interp = resize 25 | if len(obs.shape) == 2: 26 | obs = obs.reshape(obs.shape + (1,)) 27 | interp_method = "bilinear" if linear_interp else "nearest" 28 | obs = tinyscaler.scale( 29 | src=np.ascontiguousarray(obs), size=(xsize, ysize), mode=interp_method 30 | ) 31 | if len(obs_space.shape) == 2: 32 | obs = obs.reshape(obs.shape[:2]) 33 | return obs 34 | -------------------------------------------------------------------------------- /supersuit/multiagent_wrappers/agent_indication.py: -------------------------------------------------------------------------------- 1 | from pettingzoo.utils.env import AECEnv, ParallelEnv 2 | 3 | from supersuit import observation_lambda_v0 4 | from supersuit.utils import agent_indicator as agent_ider 5 | 6 | 7 | def agent_indicator_v0(env, type_only=False): 8 | assert isinstance(env, AECEnv) or isinstance( 9 | env, ParallelEnv 10 | ), "agent_indicator_v0 only accepts an AECEnv or ParallelEnv" 11 | assert hasattr( 12 | env, "possible_agents" 13 | ), "environment passed to agent indicator wrapper must have the possible_agents attribute." 
14 | 15 | indicator_map = agent_ider.get_indicator_map(env.possible_agents, type_only) 16 | num_indicators = len(set(indicator_map.values())) 17 | 18 | obs_spaces = [env.observation_space(agent) for agent in env.possible_agents] 19 | agent_ider.check_params(obs_spaces) 20 | 21 | return observation_lambda_v0( 22 | env, 23 | lambda obs, obs_space, agent: agent_ider.change_observation( 24 | obs, 25 | obs_space, 26 | (indicator_map[agent], num_indicators), 27 | ), 28 | lambda obs_space: agent_ider.change_obs_space(obs_space, num_indicators), 29 | ) 30 | -------------------------------------------------------------------------------- /supersuit/vector/sb_vector_wrapper.py: -------------------------------------------------------------------------------- 1 | from stable_baselines.common.vec_env.base_vec_env import VecEnv 2 | 3 | 4 | class SBVecEnvWrapper(VecEnv): 5 | def __init__(self, venv): 6 | self.venv = venv 7 | self.num_envs = venv.num_envs 8 | self.observation_space = venv.observation_space 9 | self.action_space = venv.action_space 10 | 11 | def reset(self, seed=None, options=None): 12 | if seed is not None: 13 | self.seed(seed=seed) 14 | 15 | return self.venv.reset() 16 | 17 | def step_async(self, actions): 18 | self.venv.step_async(actions) 19 | 20 | def step_wait(self): 21 | return self.venv.step_wait() 22 | 23 | def step(self, actions): 24 | return self.venv.step(actions) 25 | 26 | def close(self): 27 | del self.venv 28 | 29 | def seed(self, seed=None): 30 | self.venv.seed(seed) 31 | 32 | def get_attr(self, attr_name, indices=None): 33 | raise NotImplementedError() 34 | 35 | def set_attr(self, attr_name, value, indices=None): 36 | raise NotImplementedError() 37 | 38 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 39 | raise NotImplementedError() 40 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/color_reduction.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . import convert_box 4 | 5 | 6 | COLOR_RED_LIST = ["full", "R", "G", "B"] 7 | GRAYSCALE_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32) 8 | 9 | 10 | def check_param(space, color_reduction): 11 | assert isinstance( 12 | color_reduction, str 13 | ), f"color_reduction must be str. It is {color_reduction}" 14 | assert color_reduction in COLOR_RED_LIST, "color_reduction must be in {}".format( 15 | COLOR_RED_LIST 16 | ) 17 | assert ( 18 | len(space.low.shape) == 3 and space.low.shape[2] == 3 19 | ), "To apply color_reduction, shape must be a 3d image with last dimension of size 3. Shape is {}".format( 20 | space.low.shape 21 | ) 22 | 23 | 24 | def change_obs_space(obs_space, param): 25 | return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space) 26 | 27 | 28 | def change_observation(obs, obs_space, color_reduction): 29 | if color_reduction == "R": 30 | obs = obs[:, :, 0] 31 | if color_reduction == "G": 32 | obs = obs[:, :, 1] 33 | if color_reduction == "B": 34 | obs = obs[:, :, 2] 35 | if color_reduction == "full": 36 | obs = (obs.astype(np.float32) @ GRAYSCALE_WEIGHTS).astype(np.uint8) 37 | return obs 38 | -------------------------------------------------------------------------------- /test/test_utils/test_action_transforms/test_homogenize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box, Discrete 3 | 4 | from supersuit.utils.action_transforms.homogenize_ops import ( 5 | check_homogenize_spaces, 6 | dehomogenize_actions, 7 | homogenize_observations, 8 | homogenize_spaces, 9 | ) 10 | 11 | 12 | box_spaces = [ 13 | Box(low=np.float32(0), high=np.float32(1), shape=(5, 4)), 14 | Box(low=np.float32(0), high=np.float32(1), shape=(10, 2)), 15 | ] 16 | discrete_spaces = [Discrete(5), Discrete(7)] 17 | 18 | 19 | def 
test_param_check(): 20 | check_homogenize_spaces(box_spaces) 21 | check_homogenize_spaces(discrete_spaces) 22 | 23 | 24 | def test_homogenize_spaces(): 25 | hom_space_box = homogenize_spaces(box_spaces) 26 | hom_space_discrete = homogenize_spaces(discrete_spaces) 27 | assert hom_space_box.shape == (10, 4) 28 | assert hom_space_discrete.n == 7 29 | 30 | 31 | def test_dehomogenize_actions(): 32 | action = np.ones([10, 4]) 33 | assert dehomogenize_actions(box_spaces[0], action).shape == (5, 4) 34 | assert dehomogenize_actions(discrete_spaces[0], 5) == 0 35 | assert dehomogenize_actions(discrete_spaces[0], 4) == 4 36 | 37 | 38 | def test_homogenize_observations(): 39 | obs = np.zeros([5, 4]) 40 | hom_space_box = homogenize_spaces(box_spaces) 41 | assert homogenize_observations(hom_space_box, obs).shape == (10, 4) 42 | -------------------------------------------------------------------------------- /test/test_vector/test_pettingzoo_to_vec_gymnasium_wrappers.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | import pytest 4 | from gymnasium.wrappers import NormalizeObservation 5 | from pettingzoo.butterfly import pistonball_v6 6 | 7 | import supersuit as ss 8 | 9 | 10 | @pytest.mark.parametrize("env_fn", [pistonball_v6]) 11 | def test_vec_env_normalize_obs(env_fn): 12 | env = env_fn.parallel_env() 13 | env = ss.pettingzoo_env_to_vec_env_v1(env) 14 | env = ss.concat_vec_envs_v1(env, 10, base_class="gymnasium") 15 | obs, info = env.reset() 16 | 17 | # Create a "dummy" class that adds gym.Env as a base class of the env 18 | # to satisfy the assertion. 
19 | class PatchedEnv(env.__class__, gym.Env): 20 | pass 21 | 22 | env.__class__ = PatchedEnv 23 | 24 | env = NormalizeObservation(env) 25 | normalized_obs, normalized_info = env.reset() 26 | 27 | obs_range = np.amax(obs) - np.amin(obs) 28 | normalized_obs_range = np.amax(normalized_obs) - np.amin(normalized_obs) 29 | assert obs_range > 1, "Regular observation space should be greater than 1." 30 | assert ( 31 | normalized_obs_range < 1.0e-4 32 | ), "Normalized observation space should be smaller than 1.0e-4." 33 | assert ( 34 | obs_range > normalized_obs_range 35 | ), "Normalized observation space has more range than regular observation space." 36 | -------------------------------------------------------------------------------- /test/test_vector/test_render.py: -------------------------------------------------------------------------------- 1 | from pettingzoo.butterfly import pistonball_v6 2 | from stable_baselines3.common.vec_env import VecVideoRecorder 3 | 4 | import supersuit as ss 5 | 6 | 7 | def schedule(episode_idx): 8 | print(episode_idx) 9 | return episode_idx <= 1 10 | 11 | 12 | def make_sb3_record_env(): 13 | env = pistonball_v6.parallel_env(render_mode="rgb_array") 14 | print(env.render_mode) 15 | env = ss.pettingzoo_env_to_vec_env_v1(env) 16 | envs = ss.concat_vec_envs_v1(env, 1, num_cpus=0, base_class="stable_baselines3") 17 | envs = VecVideoRecorder(envs, "/tmp", schedule) 18 | return envs 19 | 20 | 21 | def test_record_video_sb3(): 22 | envs = make_sb3_record_env() 23 | envs.reset() 24 | for _ in range(100): 25 | envs.step([envs.action_space.sample() for _ in range(envs.num_envs)]) 26 | envs.close() 27 | 28 | 29 | def make_env(): 30 | env = pistonball_v6.parallel_env(render_mode="rgb_array") 31 | env = ss.pettingzoo_env_to_vec_env_v1(env) 32 | return env 33 | 34 | 35 | def test_vector_render_multiproc(): 36 | env = make_env() 37 | num_envs = 1 38 | venv = ss.concat_vec_envs_v1( 39 | env, num_envs, num_cpus=num_envs, base_class="stable_baselines3" 
40 | ) 41 | venv.reset() 42 | arr = venv.render() 43 | venv.reset() 44 | assert len(arr.shape) == 3 and arr.shape[2] == 3 45 | venv.reset() 46 | try: 47 | venv.close() 48 | except RuntimeError: 49 | pass 50 | -------------------------------------------------------------------------------- /test/test_vector/test_env_is_wrapped.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import pytest 3 | from pettingzoo.mpe import simple_spread_v3 4 | 5 | from supersuit import concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1 6 | from supersuit.generic_wrappers.frame_skip import frame_skip_gym 7 | 8 | 9 | def test_env_is_wrapped_true(): 10 | env = gymnasium.make("MountainCarContinuous-v0") 11 | env = frame_skip_gym(env, 4) 12 | num_envs = 3 13 | venv1 = concat_vec_envs_v1(env, num_envs) 14 | assert venv1.env_is_wrapped(frame_skip_gym) == [True] * 3 15 | 16 | 17 | def test_env_is_wrapped_false(): 18 | env = gymnasium.make("MountainCarContinuous-v0") 19 | num_envs = 3 20 | venv1 = concat_vec_envs_v1(env, num_envs) 21 | assert venv1.env_is_wrapped(frame_skip_gym) == [False] * 3 22 | 23 | 24 | @pytest.mark.skip( 25 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 26 | ) 27 | def test_env_is_wrapped_cpu(): 28 | env = gymnasium.make("MountainCarContinuous-v0") 29 | env = frame_skip_gym(env, 4) 30 | num_envs = 3 31 | venv1 = concat_vec_envs_v1(env, num_envs, num_cpus=2) 32 | assert venv1.env_is_wrapped(frame_skip_gym) == [True] * 3 33 | 34 | 35 | def test_env_is_wrapped_pettingzoo(): 36 | env = simple_spread_v3.parallel_env() 37 | venv1 = pettingzoo_env_to_vec_env_v1(env) 38 | num_envs = 3 39 | venv1 = concat_vec_envs_v1(venv1, num_envs) 40 | assert venv1.env_is_wrapped(frame_skip_gym) == [False] * 9 41 | -------------------------------------------------------------------------------- /supersuit/vector/constructors.py: 
-------------------------------------------------------------------------------- 1 | from .concat_vec_env import ConcatVecEnv 2 | from .multiproc_vec import ProcConcatVec 3 | 4 | 5 | class call_wrap: 6 | def __init__(self, fn, data): 7 | self.fn = fn 8 | self.data = data 9 | 10 | def __call__(self, *args): 11 | return self.fn(self.data) 12 | 13 | 14 | def MakeCPUAsyncConstructor(max_num_cpus): 15 | if max_num_cpus == 0 or max_num_cpus == 1: 16 | return ConcatVecEnv 17 | else: 18 | 19 | def constructor(env_fn_list, obs_space, act_space): 20 | example_env = env_fn_list[0]() 21 | envs_per_env = getattr(example_env, "num_envs", 1) 22 | 23 | num_fns = len(env_fn_list) 24 | envs_per_cpu = (num_fns + max_num_cpus - 1) // max_num_cpus 25 | 26 | env_cpu_div = [] 27 | num_envs_alloced = 0 28 | while num_envs_alloced < num_fns: 29 | start_idx = num_envs_alloced 30 | end_idx = min(num_fns, start_idx + envs_per_cpu) 31 | env_cpu_div.append(env_fn_list[start_idx:end_idx]) 32 | num_envs_alloced = end_idx 33 | 34 | cat_env_fns = [call_wrap(ConcatVecEnv, env_fns) for env_fns in env_cpu_div] 35 | return ProcConcatVec( 36 | cat_env_fns, 37 | obs_space, 38 | act_space, 39 | num_fns * envs_per_env, 40 | example_env.metadata, 41 | ) 42 | 43 | return constructor 44 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_normalized_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.normalize_obs import ( 6 | change_obs_space, 7 | change_observation, 8 | check_param, 9 | ) 10 | 11 | 12 | high_val = np.array([1, 2, 4]) 13 | test_val = np.array([1, 1, 1]) 14 | test_obs_space = Box( 15 | low=np.zeros(3, dtype=np.float32), high=high_val.astype(np.float32) 16 | ) 17 | bad_test_obs_space = Box( 18 | low=np.zeros(3, dtype=np.int32), high=high_val.astype(np.int32), 
dtype=np.int32 19 | ) 20 | bad_test_obs_space2 = Box( 21 | low=np.zeros(3, dtype=np.float32), high=np.inf * high_val.astype(np.float32) 22 | ) 23 | 24 | 25 | def test_param_check(): 26 | check_param(test_obs_space, (2, 3)) 27 | with pytest.raises(AssertionError): 28 | check_param(test_obs_space, (2, 2)) 29 | with pytest.raises(AssertionError): 30 | check_param(test_obs_space, ("bob", 2)) 31 | with pytest.raises(AssertionError): 32 | check_param(bad_test_obs_space, (2, 3)) 33 | with pytest.raises(AssertionError): 34 | check_param(bad_test_obs_space2, (2, 3)) 35 | 36 | 37 | def test_change_obs_space(): 38 | assert np.all( 39 | np.equal(change_obs_space(test_obs_space, (1, 2)).high, np.array([2, 2, 2])) 40 | ) 41 | 42 | 43 | def test_change_observation(): 44 | assert np.all( 45 | np.equal( 46 | change_observation(test_val, test_obs_space, (1, 2)), 47 | np.array([2, 1.5, 1.25]), 48 | ) 49 | ) 50 | -------------------------------------------------------------------------------- /supersuit/utils/base_aec_wrapper.py: -------------------------------------------------------------------------------- 1 | from pettingzoo.utils.wrappers import OrderEnforcingWrapper as PZBaseWrapper 2 | 3 | 4 | class BaseWrapper(PZBaseWrapper): 5 | def __init__(self, env): 6 | """ 7 | Creates a wrapper around `env`. Extend this class to create changes to the space. 
8 | """ 9 | super().__init__(env) 10 | 11 | self._check_wrapper_params() 12 | 13 | self._modify_spaces() 14 | 15 | def _check_wrapper_params(self): 16 | pass 17 | 18 | def _modify_spaces(self): 19 | pass 20 | 21 | def _modify_action(self, agent, action): 22 | raise NotImplementedError() 23 | 24 | def _modify_observation(self, agent, observation): 25 | raise NotImplementedError() 26 | 27 | def _update_step(self, agent): 28 | pass 29 | 30 | def reset(self, seed=None, options=None): 31 | super().reset(seed=seed, options=options) 32 | self._update_step(self.agent_selection) 33 | 34 | def observe(self, agent): 35 | obs = super().observe( 36 | agent 37 | ) # problem is in this line, the obs is sometimes a different size from the obs space 38 | observation = self._modify_observation(agent, obs) 39 | return observation 40 | 41 | def step(self, action): 42 | agent = self.env.agent_selection 43 | if not (self.terminations[agent] or self.truncations[agent]): 44 | action = self._modify_action(agent, action) 45 | 46 | super().step(action) 47 | 48 | self._update_step(self.agent_selection) 49 | -------------------------------------------------------------------------------- /.github/workflows/linux-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-22.04 19 | strategy: 20 | matrix: 21 | python-version: ['3.9', '3.10', '3.11', '3.12'] 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v2 27 | with: 28 | 
python-version: ${{ matrix.python-version }} 29 | - name: Install ubuntu dependencies 30 | run: | 31 | sudo apt-get install python3-opengl xvfb 32 | - name: Install python dependencies 33 | run: | 34 | pip install -e .[testing] 35 | # Install pettingzoo directly from master until it mints a new version to resolve 36 | # pygame version restriction. 37 | pip install "pettingzoo[all,atari] @ git+https://github.com/Farama-Foundation/PettingZoo.git@master" 38 | AutoROM -v 39 | - name: Test with pytest 40 | run: | 41 | xvfb-run -s "-screen 0 1400x900x24" pytest ./test 42 | - name: Test installation 43 | run: | 44 | python -m pip install --upgrade build 45 | python -m build --sdist 46 | pip install dist/*.tar.gz 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-symlinks 8 | - id: destroyed-symlinks 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | - id: check-toml 13 | - id: check-ast 14 | - id: check-added-large-files 15 | - id: check-merge-conflict 16 | - id: check-executables-have-shebangs 17 | - id: check-shebang-scripts-are-executable 18 | - id: detect-private-key 19 | - id: debug-statements 20 | - repo: https://github.com/codespell-project/codespell 21 | rev: v2.2.2 22 | hooks: 23 | - id: codespell 24 | args: 25 | - --ignore-words-list=magent 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 6.0.0 28 | hooks: 29 | - id: flake8 30 | args: 31 | - '--per-file-ignores=*/__init__.py:F401' 32 | - --ignore=E203,W503,E741 33 | - --max-complexity=30 34 | - --max-line-length=456 35 | - --show-source 36 | - --statistics 37 | - repo: https://github.com/asottile/pyupgrade 38 | 
rev: v3.3.1 39 | hooks: 40 | - id: pyupgrade 41 | args: ["--py37-plus"] 42 | - repo: https://github.com/PyCQA/isort 43 | rev: 5.12.0 44 | hooks: 45 | - id: isort 46 | args: ["--profile", "black"] 47 | - repo: https://github.com/python/black 48 | rev: 23.1.0 49 | hooks: 50 | - id: black 51 | -------------------------------------------------------------------------------- /supersuit/utils/basic_transforms/normalize_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box 3 | 4 | 5 | def check_param(obs_space, min_max_pair): 6 | assert np.dtype(obs_space.dtype) == np.dtype("float32") or np.dtype( 7 | obs_space.dtype 8 | ) == np.dtype("float64") 9 | assert ( 10 | isinstance(min_max_pair, tuple) and len(min_max_pair) == 2 11 | ), f"range_scale must be tuple of size 2. It is {min_max_pair}" 12 | try: 13 | min_res = float(min_max_pair[0]) 14 | max_res = float(min_max_pair[1]) 15 | except ValueError: 16 | assert False, "normalize_obs inputs must be numbers. They are {}".format( 17 | min_max_pair 18 | ) 19 | assert ( 20 | max_res > min_res 21 | ), "maximum must be greater than minimum value in normalize_obs" 22 | assert np.all(np.isfinite(obs_space.low)) and np.all( 23 | np.isfinite(obs_space.high) 24 | ), "Box observation_space of environment has infinite bounds! 
Only environments with finite bounds can be passed to normalize_obs_v0" 25 | 26 | 27 | def change_obs_space(obs_space, min_max_pair): 28 | min = np.float64(min_max_pair[0]).astype(obs_space.dtype) 29 | max = np.float64(min_max_pair[1]).astype(obs_space.dtype) 30 | return Box(low=min, high=max, shape=obs_space.shape, dtype=obs_space.dtype) 31 | 32 | 33 | def change_observation(obs, obs_space, min_max_pair): 34 | min_res, max_res = (float(x) for x in min_max_pair) 35 | old_size = obs_space.high - obs_space.low 36 | new_size = max_res - min_res 37 | result = (obs - obs_space.low) / old_size * new_size + min_res 38 | return result 39 | -------------------------------------------------------------------------------- /supersuit/multiagent_wrappers/padding_wrappers.py: -------------------------------------------------------------------------------- 1 | from pettingzoo.utils.env import AECEnv, ParallelEnv 2 | 3 | from supersuit import action_lambda_v1, observation_lambda_v0 4 | from supersuit.utils.action_transforms import homogenize_ops 5 | 6 | 7 | def pad_action_space_v0(env): 8 | assert isinstance(env, AECEnv) or isinstance( 9 | env, ParallelEnv 10 | ), "pad_action_space_v0 only accepts an AECEnv or ParallelEnv" 11 | assert hasattr( 12 | env, "possible_agents" 13 | ), "environment passed to pad_observations must have a possible_agents list." 
14 | spaces = [env.action_space(agent) for agent in env.possible_agents] 15 | homogenize_ops.check_homogenize_spaces(spaces) 16 | padded_space = homogenize_ops.homogenize_spaces(spaces) 17 | return action_lambda_v1( 18 | env, 19 | lambda action, act_space: homogenize_ops.dehomogenize_actions( 20 | act_space, action 21 | ), 22 | lambda act_space: padded_space, 23 | ) 24 | 25 | 26 | def pad_observations_v0(env): 27 | assert isinstance(env, AECEnv) or isinstance( 28 | env, ParallelEnv 29 | ), "pad_observations_v0 only accepts an AECEnv or ParallelEnv" 30 | assert hasattr( 31 | env, "possible_agents" 32 | ), "environment passed to pad_observations must have a possible_agents list." 33 | spaces = [env.observation_space(agent) for agent in env.possible_agents] 34 | homogenize_ops.check_homogenize_spaces(spaces) 35 | padded_space = homogenize_ops.homogenize_spaces(spaces) 36 | return observation_lambda_v0( 37 | env, 38 | lambda obs, obs_space: homogenize_ops.homogenize_observations( 39 | padded_space, obs 40 | ), 41 | lambda obs_space: padded_space, 42 | ) 43 | -------------------------------------------------------------------------------- /supersuit/vector/single_vec_env.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | 4 | 5 | class SingleVecEnv: 6 | def __init__(self, gym_env_fns, *args): 7 | assert len(gym_env_fns) == 1 8 | self.gym_env = gym_env_fns[0]() 9 | self.render_mode = self.gym_env.render_mode 10 | self.observation_space = self.gym_env.observation_space 11 | self.action_space = self.gym_env.action_space 12 | self.num_envs = 1 13 | self.metadata = self.gym_env.metadata 14 | 15 | def reset(self, seed=None, options=None): 16 | # TODO: should this include info 17 | return np.expand_dims(self.gym_env.reset(seed=seed, options=options), 0) 18 | 19 | def step_async(self, actions): 20 | self._saved_actions = actions 21 | 22 | def step_wait(self): 23 | return 
self.step(self._saved_actions) 24 | 25 | def render(self): 26 | return self.gym_env.render() 27 | 28 | def close(self): 29 | self.gym_env.close() 30 | 31 | def step(self, actions): 32 | observations, reward, term, trunc, info = self.gym_env.step(actions[0]) 33 | if term or trunc: 34 | observations = self.gym_env.reset() 35 | observations = np.expand_dims(observations, 0) 36 | rewards = np.array([reward], dtype=np.float32) 37 | terms = np.array([term], dtype=np.uint8) 38 | truncs = np.array([trunc], dtype=np.uint8) 39 | infos = [info] 40 | return observations, rewards, terms, truncs, infos 41 | 42 | def env_is_wrapped(self, wrapper_class): 43 | env_tmp = self.gym_env 44 | while isinstance(env_tmp, gymnasium.Wrapper): 45 | if isinstance(env_tmp, wrapper_class): 46 | return [True] 47 | env_tmp = env_tmp.env 48 | return [False] 49 | -------------------------------------------------------------------------------- /supersuit/utils/wrapper_chooser.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | from pettingzoo.utils.conversions import aec_to_parallel, parallel_to_aec 3 | from pettingzoo.utils.env import AECEnv, ParallelEnv 4 | 5 | 6 | class WrapperChooser: 7 | def __init__(self, aec_wrapper=None, gym_wrapper=None, parallel_wrapper=None): 8 | assert ( 9 | aec_wrapper is not None or parallel_wrapper is not None 10 | ), "either the aec wrapper or the parallel wrapper must be defined for all supersuit environments" 11 | self.aec_wrapper = aec_wrapper 12 | self.gym_wrapper = gym_wrapper 13 | self.parallel_wrapper = parallel_wrapper 14 | 15 | def __call__(self, env, *args, **kwargs): 16 | if isinstance(env, gymnasium.Env): 17 | if self.gym_wrapper is None: 18 | raise ValueError( 19 | f"{self.wrapper_name} does not apply to gymnasium environments, pettingzoo environments only" 20 | ) 21 | return self.gym_wrapper(env, *args, **kwargs) 22 | elif isinstance(env, AECEnv): 23 | if self.aec_wrapper is not None: 24 | return 
self.aec_wrapper(env, *args, **kwargs) 25 | else: 26 | return parallel_to_aec( 27 | self.parallel_wrapper(aec_to_parallel(env), *args, **kwargs) 28 | ) 29 | elif isinstance(env, ParallelEnv): 30 | if self.parallel_wrapper is not None: 31 | return self.parallel_wrapper(env, *args, **kwargs) 32 | else: 33 | return aec_to_parallel( 34 | self.aec_wrapper(parallel_to_aec(env), *args, **kwargs) 35 | ) 36 | else: 37 | raise ValueError( 38 | "environment passed to supersuit wrapper must either be a gymnasium environment or a pettingzoo environment" 39 | ) 40 | -------------------------------------------------------------------------------- /test/test_utils/test_basic_transforms/test_resize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box 4 | 5 | from supersuit.utils.basic_transforms.resize import change_observation, check_param 6 | 7 | 8 | test_shape = (6, 4, 3) 9 | high_val = np.ones(test_shape) + np.arange(4).reshape(1, 4, 1) 10 | test_obs_space = Box(low=high_val - 1, high=high_val, dtype=np.uint8) 11 | test_val = high_val - 0.5 12 | 13 | 14 | def test_param_check(): 15 | check_param(test_obs_space, (2, 2, False)) 16 | with pytest.raises(AssertionError): 17 | check_param(test_obs_space, (-2, 2, 2)) 18 | with pytest.raises(AssertionError): 19 | check_param(test_obs_space, (2.0, 2, 2)) 20 | 21 | 22 | def test_change_observation(): 23 | cur_val = test_val 24 | cur_val = cur_val.astype(np.uint8) 25 | new_obs = change_observation(cur_val, test_obs_space, (3, 2, False)) 26 | new_obs = change_observation(cur_val, test_obs_space, (3, 2, True)) 27 | test_obs = np.array( 28 | [ 29 | [ 30 | [0, 0, 0], 31 | [1, 1, 1], 32 | [2, 2, 2], 33 | ], 34 | [ 35 | [0, 0, 0], 36 | [1, 1, 1], 37 | [2, 2, 2], 38 | ], 39 | ] 40 | ).astype(np.uint8) 41 | 42 | assert new_obs.dtype == np.uint8 43 | assert np.all(np.equal(new_obs, test_obs)) 44 | 45 | test_shape = (6, 4) 46 | high_val = 
np.ones(test_shape).astype(np.float64) 47 | obs_spae = Box(low=high_val - 1, high=high_val) 48 | new_obs = change_observation(high_val - 0.5, obs_spae, (3, 2, False)) 49 | assert new_obs.shape == (2, 3) 50 | assert new_obs.dtype == np.float64 51 | 52 | test_shape = (6, 5, 4) 53 | high_val = np.ones(test_shape).astype(np.uint8) 54 | obs_spae = Box(low=high_val - 1, high=high_val, dtype=np.uint8) 55 | new_obs = change_observation(high_val, obs_spae, (5, 2, False)) 56 | assert new_obs.shape == (2, 5, 4) 57 | assert new_obs.dtype == np.uint8 58 | -------------------------------------------------------------------------------- /.github/workflows/build-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build and (if release) publish Python distributions to PyPI 2 | # For more information see: 3 | # - https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 4 | # - https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ 5 | # 6 | 7 | name: build-publish 8 | 9 | on: 10 | workflow_dispatch: 11 | push: 12 | branches: [master] 13 | pull_request: 14 | paths: 15 | - .github/workflows/build-publish.yml 16 | release: 17 | types: [published] 18 | 19 | jobs: 20 | build: 21 | name: Build sdist and wheel 22 | runs-on: ubuntu-latest 23 | permissions: 24 | contents: read 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: 3.x 31 | - name: Install dependencies 32 | run: python -m pip install --upgrade setuptools wheel build 33 | - name: Build sdist and wheel 34 | run: python -m build --sdist --wheel 35 | - name: Store sdist and wheel 36 | uses: actions/upload-artifact@v4 37 | with: 38 | name: artifact 39 | path: dist 40 | publish: 41 | name: Publish to PyPI 42 | needs: [build] 43 | runs-on: ubuntu-latest 44 | environment: pypi 45 
| if: github.event_name == 'release' && github.event.action == 'published' 46 | steps: 47 | - name: Download dist 48 | uses: actions/download-artifact@v4 49 | with: 50 | name: artifact 51 | path: dist 52 | - name: Publish 53 | uses: pypa/gh-action-pypi-publish@release/v1 54 | with: 55 | password: ${{ secrets.PYPI_API_TOKEN }} 56 | -------------------------------------------------------------------------------- /supersuit/aec_vector/base_aec_vec_env.py: -------------------------------------------------------------------------------- 1 | class VectorAECEnv: 2 | def reset(self, seed=None, options=None): 3 | """ 4 | resets all environments 5 | """ 6 | 7 | def observe(self, agent): 8 | """ 9 | returns observation for agent from all environments (if agent is alive, else all zeros) 10 | """ 11 | 12 | def last(self, observe=True): 13 | """ 14 | returns list of observations, rewards, dones, env_dones, passes, infos 15 | 16 | each of the following is a list over environments that holds the value for the current agent (env.agent_selection) 17 | 18 | dones: are True when the current agent is done 19 | env_dones: is True when all agents are done, and the environment will reset 20 | passes: is true when the agent is not stepping this turn (because it is dead or not currently stepping for some other reason) 21 | infos: list of infos for the agent 22 | """ 23 | 24 | def step(self, actions, observe=True): 25 | """ 26 | steps the current agent with the following actions. 27 | Unlike a regular AECEnv, the actions cannot be None 28 | """ 29 | 30 | def agent_iter(self, max_iter): 31 | """ 32 | Unlike aec agent_iter, this does not stop on environment done. Instead, 33 | vector environment resets specific envs when done. 34 | 35 | Instead, just continues until max_iter is reached. 
36 | """ 37 | return AECIterable(self, max_iter) 38 | 39 | 40 | class AECIterable: 41 | def __init__(self, env, max_iter): 42 | self.env = env 43 | self.max_iter = max_iter 44 | 45 | def __iter__(self): 46 | return AECIterator(self.env, self.max_iter) 47 | 48 | 49 | class AECIterator: 50 | def __init__(self, env, max_iter): 51 | self.env = env 52 | self.iters_til_term = max_iter 53 | self.env._is_iterating = True 54 | 55 | def __next__(self): 56 | if self.iters_til_term <= 0: 57 | raise StopIteration 58 | self.iters_til_term -= 1 59 | return self.env.agent_selection 60 | -------------------------------------------------------------------------------- /test/dummy_aec_env.py: -------------------------------------------------------------------------------- 1 | from pettingzoo import AECEnv 2 | from pettingzoo.utils.agent_selector import agent_selector 3 | 4 | 5 | class DummyEnv(AECEnv): 6 | metadata = {"render_modes": ["human"], "is_parallelizable": True} 7 | 8 | def __init__(self, observations, observation_spaces, action_spaces): 9 | super().__init__() 10 | self._observations = observations 11 | self._observation_spaces = observation_spaces 12 | self.render_mode = None 13 | self.agents = sorted([x for x in observation_spaces.keys()]) 14 | self.possible_agents = self.agents[:] 15 | self._agent_selector = agent_selector(self.agents) 16 | self.agent_selection = self._agent_selector.reset() 17 | self._action_spaces = action_spaces 18 | 19 | self.steps = 0 20 | 21 | def observation_space(self, agent): 22 | return self._observation_spaces[agent] 23 | 24 | def action_space(self, agent): 25 | return self._action_spaces[agent] 26 | 27 | def observe(self, agent): 28 | return self._observations[agent] 29 | 30 | def step(self, action, observe=True): 31 | if ( 32 | self.terminations[self.agent_selection] 33 | or self.truncations[self.agent_selection] 34 | ): 35 | return self._was_dead_step(action) 36 | self._cumulative_rewards[self.agent_selection] = 0 37 | self.agent_selection =
self._agent_selector.next() 38 | self.steps += 1 39 | if self.steps >= 5 * len(self.agents): 40 | self.truncations = {a: True for a in self.agents} 41 | 42 | self._accumulate_rewards() 43 | self._deads_step_first() 44 | 45 | def reset(self, seed=None, options=None): 46 | self.agents = self.possible_agents[:] 47 | self._agent_selector = agent_selector(self.agents) 48 | self.agent_selection = self._agent_selector.reset() 49 | self.rewards = {a: 1 for a in self.agents} 50 | self._cumulative_rewards = {a: 0 for a in self.agents} 51 | self.terminations = {a: False for a in self.agents} 52 | self.truncations = {a: False for a in self.agents} 53 | self.infos = {a: {} for a in self.agents} 54 | self.steps = 0 55 | 56 | def close(self): 57 | pass 58 | -------------------------------------------------------------------------------- /supersuit/multiagent_wrappers/black_death.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | from pettingzoo.utils.wrappers import BaseParallelWrapper 4 | 5 | from supersuit.utils.wrapper_chooser import WrapperChooser 6 | 7 | 8 | class black_death_par(BaseParallelWrapper): 9 | def __init__(self, env): 10 | super().__init__(env) 11 | 12 | def _check_valid_for_black_death(self): 13 | for agent in self.agents: 14 | space = self.observation_space(agent) 15 | assert isinstance( 16 | space, gymnasium.spaces.Box 17 | ), f"observation sapces for black death must be Box spaces, is {space}" 18 | 19 | def reset(self, seed=None, options=None): 20 | obss, infos = self.env.reset(seed=seed, options=options) 21 | 22 | self.agents = self.env.agents[:] 23 | self._check_valid_for_black_death() 24 | black_obs = { 25 | agent: np.zeros_like(self.observation_space(agent).low) 26 | for agent in self.agents 27 | if agent not in obss 28 | } 29 | return {**obss, **black_obs}, infos 30 | 31 | def step(self, actions): 32 | active_actions = {agent: actions[agent] for agent in self.env.agents} 33
| obss, rews, terms, truncs, infos = self.env.step(active_actions) 34 | black_obs = { 35 | agent: np.zeros_like(self.observation_space(agent).low) 36 | for agent in self.agents 37 | if agent not in obss 38 | } 39 | black_rews = {agent: 0.0 for agent in self.agents if agent not in obss} 40 | black_infos = {agent: {} for agent in self.agents if agent not in obss} 41 | terminations = np.fromiter(terms.values(), dtype=bool) 42 | truncations = np.fromiter(truncs.values(), dtype=bool) 43 | env_is_done = (terminations | truncations).all() # done once every active agent is terminated OR truncated (was "&", which required both at once and almost never fired) 44 | total_obs = {**black_obs, **obss} 45 | total_rews = {**black_rews, **rews} 46 | total_infos = {**black_infos, **infos} 47 | total_dones = {agent: env_is_done for agent in self.agents} 48 | if env_is_done: 49 | self.agents.clear() 50 | return total_obs, total_rews, total_dones, total_dones, total_infos 51 | 52 | 53 | black_death_v3 = WrapperChooser(parallel_wrapper=black_death_par) 54 | -------------------------------------------------------------------------------- /test/parallel_env_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box, Discrete 3 | from pettingzoo.utils import ParallelEnv 4 | 5 | import supersuit 6 | 7 | 8 | class DummyParEnv(ParallelEnv): 9 | metadata = {"render_modes": ["human"]} 10 | 11 | def __init__(self, observations, observation_spaces, action_spaces): 12 | super().__init__() 13 | self._observations = observations 14 | self._observation_spaces = observation_spaces 15 | 16 | self.agents = [x for x in observation_spaces.keys()] 17 | self.possible_agents = self.agents 18 | self.agent_selection = self.agents[0] 19 | self._action_spaces = action_spaces 20 | 21 | self.rewards = {a: 1 for a in self.agents} 22 | self.terminations = {a: False for a in self.agents} 23 | self.truncations = {a: False for a in self.agents} 24 | self.infos = {a: {} for a in self.agents} 25 | 26 | def observation_space(self, agent): 27 | return
self._observation_spaces[agent] 28 | 29 | def action_space(self, agent): 30 | return self._action_spaces[agent] 31 | 32 | def step(self, actions): 33 | for agent, action in actions.items(): 34 | assert action in self.action_space(agent) 35 | return ( 36 | self._observations, 37 | self.rewards, 38 | self.terminations, 39 | self.truncations, 40 | self.infos, 41 | ) 42 | 43 | def reset(self, seed=None, options=None): 44 | return self._observations, self.infos 45 | 46 | def close(self): 47 | pass 48 | 49 | 50 | base_obs = { 51 | f"a{idx}": np.zeros([8, 8, 3], dtype=np.float32) + np.arange(3) + idx 52 | for idx in range(2) 53 | } 54 | base_obs_space = { 55 | f"a{idx}": Box(low=np.float32(0.0), high=np.float32(10.0), shape=[8, 8, 3]) 56 | for idx in range(2) 57 | } 58 | base_act_spaces = {f"a{idx}": Discrete(5) for idx in range(2)} 59 | 60 | 61 | def test_basic(): 62 | env = DummyParEnv(base_obs, base_obs_space, base_act_spaces) 63 | env = supersuit.delay_observations_v0(env, 4) 64 | env = supersuit.dtype_v0(env, np.uint8) 65 | env.reset() 66 | for i in range(10): 67 | action = {agent: env.action_space(agent).sample() for agent in env.agents} 68 | env.step(action) 69 | -------------------------------------------------------------------------------- /supersuit/vector/sb3_vector_wrapper.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, List, Optional 3 | 4 | import numpy as np 5 | from stable_baselines3.common.vec_env import VecEnvWrapper 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices 7 | 8 | 9 | class SB3VecEnvWrapper(VecEnvWrapper): 10 | def __init__(self, venv): 11 | self.venv = venv 12 | self.num_envs = venv.num_envs 13 | self.observation_space = venv.observation_space 14 | self.action_space = venv.action_space 15 | self.render_mode = venv.render_mode 16 | self.reset_infos = [] 17 | 18 | def reset(self, seed=None, options=None): 19 | if seed is not None: 20 |
self.seed(seed=seed) 21 | # Note: SB3's vector envs return only observations on reset, and store infos in `self.reset_infos` 22 | observations, self.reset_infos = self.venv.reset() 23 | return observations 24 | 25 | def step_wait(self): 26 | observations, rewards, terminations, truncations, infos = self.venv.step_wait() 27 | # Note: SB3 expects dones to be an np.array 28 | dones = np.array( 29 | [terminations[i] or truncations[i] for i in range(len(terminations))] 30 | ) 31 | return observations, rewards, dones, infos 32 | 33 | def env_is_wrapped(self, wrapper_class, indices=None): 34 | # ignores indices 35 | return self.venv.env_is_wrapped(wrapper_class) 36 | 37 | def getattr_recursive(self, name): 38 | raise AttributeError(name) 39 | 40 | def getattr_depth_check(self, name, already_found): 41 | return None 42 | 43 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 44 | attr = self.venv.get_attr(attr_name) 45 | # Note: SB3 expects render_mode to be returned as an array, with values for each env 46 | if attr_name == "render_mode": 47 | return [attr for _ in range(self.num_envs)] 48 | else: 49 | return attr 50 | 51 | def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: 52 | warnings.warn( 53 | "PettingZoo environments do not take the `render(mode)` argument, to change rendering mode, re-initialize the environment using the `render_mode` argument." 54 | ) 55 | return self.venv.render() 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | **Aug 11, 2025: This project is semi-deprecated, and is unmaintained except for being kept operational with new versions of PettingZoo until the relevant functionality can be merged into pettingzoo.wrappers.** 6 | 7 | SuperSuit introduces a collection of small functions which can wrap reinforcement learning environments to do preprocessing ('microwrappers'). 8 | We support Gymnasium for single agent environments and PettingZoo for multi-agent environments (both AECEnv and ParallelEnv environments). 9 | 10 | 11 | Using it with Gymnasium to convert space invaders to have a grey scale observation space and stack the last 4 frames looks like: 12 | 13 | ``` 14 | import gymnasium 15 | from supersuit import color_reduction_v0, frame_stack_v1 16 | 17 | env = gymnasium.make('SpaceInvaders-v0') 18 | 19 | env = frame_stack_v1(color_reduction_v0(env, 'full'), 4) 20 | ``` 21 | 22 | Similarly, using SuperSuit with PettingZoo environments looks like: 23 | 24 | ``` 25 | from pettingzoo.butterfly import pistonball_v0 26 | env = pistonball_v0.env() 27 | 28 | env = frame_stack_v1(color_reduction_v0(env, 'full'), 4) 29 | ``` 30 | 31 | 32 | **Please note**: Once the planned wrapper rewrite of Gymnasium is complete and the vector API is stabilized, this project will be deprecated and rewritten as part of a new wrappers package in PettingZoo and the vectorized API will be redone, taking inspiration from the functionality currently in Gymnasium. 33 | 34 | ## Installing SuperSuit 35 | To install SuperSuit from PyPI: 36 | 37 | ``` 38 | python3 -m venv env 39 | source env/bin/activate 40 | pip install --upgrade pip 41 | pip install supersuit 42 | ``` 43 | 44 | Alternatively, to install SuperSuit from source, clone this repo, `cd` to it, and then: 45 | 46 | ``` 47 | python3 -m venv env 48 | source env/bin/activate 49 | pip install --upgrade pip 50 | pip install -e .
51 | ``` 52 | 53 | ## Citation 54 | 55 | If you use this in your research, please cite: 56 | 57 | ``` 58 | @article{SuperSuit, 59 | Title = {SuperSuit: Simple Microwrappers for Reinforcement Learning Environments}, 60 | Author = {Terry, J. K and Black, Benjamin and Hari, Ananth}, 61 | journal={arXiv preprint arXiv:2008.08932}, 62 | year={2020} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /supersuit/lambda_wrappers/reward_lambda.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | from pettingzoo.utils import BaseWrapper as PettingzooWrap 3 | 4 | from supersuit.utils.make_defaultdict import make_defaultdict 5 | from supersuit.utils.wrapper_chooser import WrapperChooser 6 | 7 | 8 | class aec_reward_lambda(PettingzooWrap): 9 | def __init__(self, env, change_reward_fn): 10 | assert callable( 11 | change_reward_fn 12 | ), f"change_reward_fn needs to be a function. It is {change_reward_fn}" 13 | self._change_reward_fn = change_reward_fn 14 | 15 | super().__init__(env) 16 | 17 | def _check_wrapper_params(self): 18 | pass 19 | 20 | def _modify_spaces(self): 21 | pass 22 | 23 | def reset(self, seed=None, options=None): 24 | super().reset(seed=seed, options=options) 25 | self.rewards = { 26 | agent: self._change_reward_fn(reward) 27 | for agent, reward in self.env.rewards.items() # you don't want to unwrap here, because another reward wrapper might have been applied 28 | } 29 | self._cumulative_rewards = self.__cumulative_rewards = make_defaultdict({a: 0 for a in self.agents}) # fresh per-episode tally bound to _cumulative_rewards (as step() does), so reset rewards are not added to a stale or wrapped-env dict 30 | self._accumulate_rewards() 31 | 32 | def step(self, action): 33 | agent = self.env.agent_selection 34 | super().step(action) 35 | self.rewards = { 36 | agent: self._change_reward_fn(reward) 37 | for agent, reward in self.env.rewards.items() # you don't want to unwrap here, because another reward wrapper might have been applied 38 | } 39 | self.__cumulative_rewards[agent] = 0 40 | self._cumulative_rewards =
self.__cumulative_rewards 41 | self._accumulate_rewards() 42 | 43 | 44 | class gym_reward_lambda(gymnasium.Wrapper): 45 | def __init__(self, env, change_reward_fn): 46 | assert callable( 47 | change_reward_fn 48 | ), f"change_reward_fn needs to be a function. It is {change_reward_fn}" 49 | self._change_reward_fn = change_reward_fn 50 | 51 | super().__init__(env) 52 | 53 | def step(self, action): 54 | obs, rew, termination, truncation, info = super().step(action) 55 | return obs, self._change_reward_fn(rew), termination, truncation, info 56 | 57 | 58 | reward_lambda_v0 = WrapperChooser( 59 | aec_wrapper=aec_reward_lambda, gym_wrapper=gym_reward_lambda 60 | ) 61 | -------------------------------------------------------------------------------- /supersuit/__init__.py: -------------------------------------------------------------------------------- 1 | from supersuit.generic_wrappers import clip_actions_v0 # NOQA 2 | from supersuit.generic_wrappers import ( 3 | clip_reward_v0, 4 | color_reduction_v0, 5 | delay_observations_v0, 6 | dtype_v0, 7 | flatten_v0, 8 | frame_skip_v0, 9 | frame_stack_v1, 10 | max_observation_v0, 11 | normalize_obs_v0, 12 | reshape_v0, 13 | resize_v1, 14 | sticky_actions_v0, 15 | ) 16 | 17 | from .aec_vector import vectorize_aec_env_v0 18 | from .generic_wrappers import * # NOQA 19 | from .lambda_wrappers import observation_lambda_v0 # NOQA 20 | from .lambda_wrappers import action_lambda_v1, reward_lambda_v0 21 | from .multiagent_wrappers import black_death_v3 # NOQA 22 | from .multiagent_wrappers import ( 23 | agent_indicator_v0, 24 | pad_action_space_v0, 25 | pad_observations_v0, 26 | ) 27 | from .vector.vector_constructors import concat_vec_envs_v1 # NOQA 28 | from .vector.vector_constructors import ( 29 | gym_vec_env_v0, 30 | pettingzoo_env_to_vec_env_v1, 31 | stable_baselines3_vec_env_v0, 32 | stable_baselines_vec_env_v0, 33 | ) 34 | 35 | 36 | class DeprecatedWrapper(ImportError): 37 | pass 38 | 39 | 40 | def __getattr__(wrapper_name): 41 |
""" 42 | Gives error that looks like this when trying to import old version of wrapper: 43 | File "./supersuit/__init__.py", line 38, in __getattr__ 44 | raise DeprecatedWrapper(f"{base}{version_num} is now deprecated, use {base}{act_version_num} instead") 45 | supersuit.DeprecatedWrapper: concat_vec_envs_v0 is now deprecated, use concat_vec_envs_v1 instead 46 | """ 47 | start_v = wrapper_name.rfind("_v") + 2 48 | version = wrapper_name[start_v:] 49 | base = wrapper_name[:start_v] 50 | try: 51 | version_num = int(version) 52 | is_valid_version = True 53 | except ValueError: 54 | is_valid_version = False 55 | 56 | globs = globals() 57 | if is_valid_version: 58 | for act_version_num in range(1000): 59 | if f"{base}{act_version_num}" in globs: 60 | if version_num < act_version_num: 61 | raise DeprecatedWrapper( 62 | f"{base}{version_num} is now deprecated, use {base}{act_version_num} instead" 63 | ) 64 | 65 | raise ImportError(f"cannot import name '{wrapper_name}' from 'supersuit'") 66 | 67 | 68 | __version__ = "3.10.0" 69 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/basic_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box 3 | 4 | from supersuit.lambda_wrappers import ( 5 | action_lambda_v1, 6 | observation_lambda_v0, 7 | reward_lambda_v0, 8 | ) 9 | from supersuit.utils import basic_transforms 10 | 11 | 12 | def basic_obs_wrapper(env, module, param): 13 | def change_space(space): 14 | module.check_param(space, param) 15 | space = module.change_obs_space(space, param) 16 | return space 17 | 18 | def change_obs(obs, obs_space): 19 | return module.change_observation(obs, obs_space, param) 20 | 21 | return observation_lambda_v0(env, change_obs, change_space) 22 | 23 | 24 | def color_reduction_v0(env, mode="full"): 25 | return basic_obs_wrapper(env, basic_transforms.color_reduction, mode) 26 | 27 | 28 | def
resize_v1(env, x_size, y_size, linear_interp=True): 29 | scale_tuple = (x_size, y_size, linear_interp) 30 | return basic_obs_wrapper(env, basic_transforms.resize, scale_tuple) 31 | 32 | 33 | def dtype_v0(env, dtype): 34 | return basic_obs_wrapper(env, basic_transforms.dtype, dtype) 35 | 36 | 37 | def flatten_v0(env): 38 | return basic_obs_wrapper(env, basic_transforms.flatten, True) 39 | 40 | 41 | def reshape_v0(env, shape): 42 | return basic_obs_wrapper(env, basic_transforms.reshape, shape) 43 | 44 | 45 | def normalize_obs_v0(env, env_min=0.0, env_max=1.0): 46 | return basic_obs_wrapper(env, basic_transforms.normalize_obs, (env_min, env_max)) 47 | 48 | 49 | def clip_actions_v0(env): 50 | return action_lambda_v1( 51 | env, 52 | lambda action, act_space: np.clip(action, act_space.low, act_space.high) 53 | if action is not None 54 | else None, 55 | lambda act_space: act_space, 56 | ) 57 | 58 | 59 | def scale_actions_v0(env, scale): 60 | def change_act_space(act_space): 61 | assert isinstance( 62 | act_space, Box 63 | ), "scale_actions_v0 only works with a Box action space" 64 | return Box(low=act_space.low * scale, high=act_space.high * scale) 65 | 66 | return action_lambda_v1( 67 | env, 68 | lambda action, act_space: np.asarray(action) * scale 69 | if action is not None 70 | else None, 71 | lambda act_space: change_act_space(act_space), 72 | ) 73 | 74 | 75 | def clip_reward_v0(env, lower_bound=-1, upper_bound=1): 76 | return reward_lambda_v0(env, lambda rew: max(min(rew, upper_bound), lower_bound)) 77 | -------------------------------------------------------------------------------- /test/test_vector/test_aec_vector_identity_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pettingzoo.butterfly import knights_archers_zombies_v10 3 | 4 | from supersuit import vectorize_aec_env_v0 5 | 6 | 7 | def recursive_equal(info1, info2): 8 | if info1 == info2: 9 | return True 10 | return False 11 | 12 | 13 |
def test_identical(): 14 | def env_fn(): 15 | return knights_archers_zombies_v10.env() # ,20) 16 | 17 | n_envs = 2 18 | # single threaded 19 | env1 = vectorize_aec_env_v0(knights_archers_zombies_v10.env(), n_envs) 20 | env2 = vectorize_aec_env_v0(knights_archers_zombies_v10.env(), n_envs, num_cpus=1) 21 | env1.reset(seed=42) 22 | env2.reset(seed=42) 23 | 24 | def policy(obs, agent): 25 | return [env1.action_space(agent).sample() for i in range(env1.num_envs)] 26 | 27 | envs_done = 0 28 | for agent in env1.agent_iter(200000): 29 | assert env1.agent_selection == env2.agent_selection 30 | agent = env1.agent_selection 31 | ( 32 | obs1, 33 | rew1, 34 | agent_term1, 35 | agent_trunc1, 36 | env_term1, 37 | env_trunc1, 38 | agent_passes1, 39 | infos1, 40 | ) = env1.last() 41 | ( 42 | obs2, 43 | rew2, 44 | agent_term2, 45 | agent_trunc2, 46 | env_term2, 47 | env_trunc2, 48 | agent_passes2, 49 | infos2, 50 | ) = env2.last() 51 | assert np.all(np.equal(obs1, obs2)) 52 | assert np.all(np.equal(agent_term1, agent_term2)) 53 | assert np.all(np.equal(agent_trunc1, agent_trunc2)) 54 | assert np.all(np.equal(agent_passes1, agent_passes2)) 55 | assert np.all(np.equal(env_term1, env_term2)) 56 | assert np.all(np.equal(env_trunc1, env_trunc2)) 57 | assert np.all(np.equal(obs1, obs2)) 58 | assert all( 59 | np.all(np.equal(r1, r2)) 60 | for r1, r2 in zip(env1.rewards.values(), env2.rewards.values()) 61 | ) 62 | assert recursive_equal(infos1, infos2) 63 | actions = policy(obs1, agent) 64 | env1.step(actions) 65 | env2.step(actions) 66 | # env.envs[0].render() 67 | for j in range(2): 68 | # if agent_passes[j]: 69 | # print("pass") 70 | if rew1[j] != 0: 71 | print(j, agent, rew1, agent_term1[j], agent_trunc1[j]) 72 | if env_term1[j] or env_trunc1[j]: 73 | print(j, "done") 74 | envs_done += 1 75 | if envs_done == n_envs + 1: 76 | print("test passed") 77 | return 78 | -------------------------------------------------------------------------------- /test/gym_unwrapped_test.py:
-------------------------------------------------------------------------------- 1 | from test.dummy_gym_env import DummyEnv 2 | 3 | import numpy as np 4 | from gymnasium import spaces 5 | 6 | from supersuit import ( 7 | clip_actions_v0, 8 | clip_reward_v0, 9 | color_reduction_v0, 10 | delay_observations_v0, 11 | dtype_v0, 12 | flatten_v0, 13 | frame_skip_v0, 14 | frame_stack_v1, 15 | max_observation_v0, 16 | nan_random_v0, 17 | nan_zeros_v0, 18 | normalize_obs_v0, 19 | scale_actions_v0, 20 | sticky_actions_v0, 21 | ) 22 | 23 | 24 | def unwrapped_check(env): 25 | # image observations 26 | if isinstance(env.observation_space, spaces.Box): 27 | if ( 28 | (env.observation_space.low.shape == 3) # NOTE(review): compares a shape tuple to an int, so this image branch never runs (and len() of the int shape[2] on line 30 would raise if it did); presumably len(...shape) == 3 and shape[2] == 3 were intended 29 | and (env.observation_space.low == 0).all() 30 | and (len(env.observation_space.shape[2]) == 3) 31 | and (env.observation_space.high == 255).all() 32 | ): 33 | env = max_observation_v0(env, 2) 34 | env = color_reduction_v0(env, mode="full") 35 | env = normalize_obs_v0(env) 36 | 37 | # box action spaces 38 | if isinstance(env.action_space, spaces.Box): 39 | env = clip_actions_v0(env) 40 | env = scale_actions_v0(env, 0.5) 41 | 42 | # stackable observations 43 | if isinstance(env.observation_space, spaces.Box) or isinstance( 44 | env.observation_space, spaces.Discrete 45 | ): 46 | env = frame_stack_v1(env, 2) 47 | 48 | # not discrete and not multibinary observations 49 | if not isinstance(env.observation_space, spaces.Discrete) and not isinstance( 50 | env.observation_space, spaces.MultiBinary 51 | ): 52 | env = dtype_v0(env, np.float16) 53 | env = flatten_v0(env) 54 | env = frame_skip_v0(env, 2) 55 | 56 | # everything else 57 | env = clip_reward_v0(env, lower_bound=-1, upper_bound=1) 58 | env = delay_observations_v0(env, 2) 59 | env = sticky_actions_v0(env, 0.5) 60 | env = nan_random_v0(env) 61 | env = nan_zeros_v0(env) 62 | 63 | assert env.unwrapped.__class__ == DummyEnv, f"Failed to unwrap {env}" 64 | 65 | 66 | def test_unwrapped(): 67 | observation_spaces = [] 68 |
observation_spaces.append( 69 | spaces.Box(low=-1.0, high=1.0, shape=[2], dtype=np.float32) 70 | ) 71 | observation_spaces.append( 72 | spaces.Box(low=0, high=255, shape=[64, 64, 3], dtype=np.int16) 73 | ) 74 | observation_spaces.append(spaces.Discrete(5)) 75 | observation_spaces.append(spaces.MultiBinary([3, 4])) 76 | 77 | action_spaces = [] 78 | action_spaces.append(spaces.Box(-3.0, 3.0, [3], np.float32)) 79 | action_spaces.append(spaces.Discrete(5)) 80 | action_spaces.append(spaces.MultiDiscrete([3, 5])) 81 | 82 | for obs_space in observation_spaces: 83 | for act_space in action_spaces: 84 | env = DummyEnv(obs_space.sample(), obs_space, act_space) 85 | unwrapped_check(env) 86 | -------------------------------------------------------------------------------- /supersuit/lambda_wrappers/action_lambda.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import gymnasium 4 | from gymnasium.spaces import Space 5 | 6 | from supersuit.utils.base_aec_wrapper import BaseWrapper 7 | from supersuit.utils.wrapper_chooser import WrapperChooser 8 | 9 | 10 | class aec_action_lambda(BaseWrapper): 11 | def __init__(self, env, change_action_fn, change_space_fn): 12 | assert callable( 13 | change_action_fn 14 | ), f"change_action_fn needs to be a function. It is {change_action_fn}" 15 | assert callable( 16 | change_space_fn 17 | ), f"change_space_fn needs to be a function.
It is {change_space_fn}" 18 | 19 | self.change_action_fn = change_action_fn 20 | self.change_space_fn = change_space_fn 21 | 22 | super().__init__(env) 23 | if hasattr(self, "possible_agents"): 24 | for agent in self.possible_agents: 25 | # call any validation logic in this function 26 | self.action_space(agent) 27 | 28 | def _modify_observation(self, agent, observation): 29 | return observation 30 | 31 | @functools.lru_cache(maxsize=None) 32 | def action_space(self, agent): 33 | old_act_space = self.env.action_space(agent) 34 | try: 35 | return self.change_space_fn(old_act_space, agent) 36 | except TypeError: 37 | return self.change_space_fn(old_act_space) 38 | 39 | def _modify_action(self, agent, action): 40 | old_act_space = self.env.action_space(agent) 41 | try: 42 | return self.change_action_fn(action, old_act_space, agent) 43 | except TypeError: 44 | return self.change_action_fn(action, old_act_space) 45 | 46 | 47 | class gym_action_lambda(gymnasium.Wrapper): 48 | def __init__(self, env, change_action_fn, change_space_fn): 49 | assert callable( 50 | change_action_fn 51 | ), f"change_action_fn needs to be a function. It is {change_action_fn}" 52 | assert callable( 53 | change_space_fn 54 | ), f"change_space_fn needs to be a function.
It is {change_space_fn}" 55 | self.change_action_fn = change_action_fn 56 | self.change_space_fn = change_space_fn 57 | 58 | super().__init__(env) 59 | self._modify_spaces() 60 | 61 | def _modify_spaces(self): 62 | new_space = self.change_space_fn(self.action_space) 63 | assert isinstance( 64 | new_space, Space 65 | ), "output of change_space_fn argument to action_lambda_wrapper must be a gymnasium space" 66 | self.action_space = new_space 67 | 68 | def _modify_action(self, action): 69 | return self.change_action_fn(action, self.env.action_space) 70 | 71 | def step(self, action): 72 | return super().step(self._modify_action(action)) 73 | 74 | 75 | action_lambda_v1 = WrapperChooser( 76 | aec_wrapper=aec_action_lambda, gym_wrapper=gym_action_lambda 77 | ) 78 | -------------------------------------------------------------------------------- /test/test_utils/test_frame_stack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box, Discrete 3 | 4 | from supersuit.utils.frame_stack import stack_init, stack_obs, stack_obs_space 5 | 6 | 7 | stack_obs_space_3d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3)) 8 | stack_obs_space_2d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 3)) 9 | stack_obs_space_1d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(3,)) 10 | 11 | stack_discrete = Discrete(3) 12 | 13 | STACK_SIZE = 11 14 | 15 | 16 | def test_obs_space(): 17 | assert stack_obs_space(stack_obs_space_1d, STACK_SIZE).shape == (3 * STACK_SIZE,) 18 | assert stack_obs_space(stack_obs_space_2d, STACK_SIZE).shape == (4, 3, STACK_SIZE) 19 | assert stack_obs_space(stack_obs_space_3d, STACK_SIZE).shape == ( 20 | 4, 21 | 4, 22 | 3 * STACK_SIZE, 23 | ) 24 | assert stack_obs_space(stack_discrete, STACK_SIZE).n == 3**STACK_SIZE 25 | 26 | 27 | def stack_obs_helper(frame_stack_list, obs_space, stack_size): 28 | stack = stack_init( 29 | obs_space, stack_size 30 | ) #
stack_reset_obs(frame_stack_list[0], stack_size) 31 | for obs in frame_stack_list: 32 | stack = stack_obs(stack, obs, obs_space, stack_size) 33 | return stack 34 | 35 | 36 | def test_change_observation(): 37 | assert stack_obs_helper( 38 | [stack_obs_space_1d.low], stack_obs_space_1d, STACK_SIZE 39 | ).shape == (3 * STACK_SIZE,) 40 | assert stack_obs_helper( 41 | [stack_obs_space_1d.low, stack_obs_space_1d.high], 42 | stack_obs_space_1d, 43 | STACK_SIZE, 44 | ).shape == (3 * STACK_SIZE,) 45 | assert stack_obs_helper( 46 | [stack_obs_space_2d.low], stack_obs_space_2d, STACK_SIZE 47 | ).shape == (4, 3, STACK_SIZE) 48 | assert stack_obs_helper( 49 | [stack_obs_space_2d.low, stack_obs_space_2d.high], 50 | stack_obs_space_2d, 51 | STACK_SIZE, 52 | ).shape == (4, 3, STACK_SIZE) 53 | assert stack_obs_helper( 54 | [stack_obs_space_3d.low], stack_obs_space_3d, STACK_SIZE 55 | ).shape == (4, 4, 3 * STACK_SIZE) 56 | 57 | assert stack_obs_helper([1, 2], stack_discrete, STACK_SIZE) == 2 + 1 * 3 58 | 59 | stacked = stack_obs_helper( 60 | [stack_obs_space_2d.low, stack_obs_space_2d.high], stack_obs_space_2d, 3 61 | ) 62 | raw = np.stack( 63 | [ 64 | np.zeros_like(stack_obs_space_2d.high), 65 | stack_obs_space_2d.low, 66 | stack_obs_space_2d.high, 67 | ], 68 | axis=2, 69 | ) 70 | assert np.all(np.equal(stacked, raw)) 71 | 72 | stacked = stack_obs_helper( 73 | [stack_obs_space_3d.low, stack_obs_space_3d.high], stack_obs_space_3d, 3 74 | ) 75 | raw = np.concatenate( 76 | [ 77 | np.zeros_like(stack_obs_space_3d.high), 78 | stack_obs_space_3d.low, 79 | stack_obs_space_3d.high, 80 | ], 81 | axis=2, 82 | ) 83 | assert np.all(np.equal(stacked, raw)) 84 | -------------------------------------------------------------------------------- /test/test_vector/test_aec_vector_values.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from pettingzoo.butterfly import knights_archers_zombies_v10 5 | from
pettingzoo.classic import rps_v2 6 | from pettingzoo.mpe import simple_world_comm_v3 7 | 8 | from supersuit import vectorize_aec_env_v0 9 | 10 | 11 | def test_all(): 12 | NUM_ENVS = 5 13 | 14 | def test_vec_env(vec_env): 15 | vec_env.reset() 16 | ( 17 | obs, 18 | rew, 19 | agent_term, 20 | agent_trunc, 21 | env_term, 22 | env_trunc, 23 | agent_passes, 24 | infos, 25 | ) = vec_env.last() 26 | print(np.asarray(obs).shape) 27 | assert len(obs) == NUM_ENVS 28 | act_space = vec_env.action_space(vec_env.agent_selection) 29 | assert np.all(np.equal(obs, vec_env.observe(vec_env.agent_selection))) 30 | assert len(vec_env.observe(vec_env.agent_selection)) == NUM_ENVS 31 | vec_env.step([act_space.sample() for _ in range(NUM_ENVS)]) 32 | ( 33 | obs, 34 | rew, 35 | agent_term, 36 | agent_trunc, 37 | env_term, 38 | env_trunc, 39 | agent_passes, 40 | infos, 41 | ) = vec_env.last(observe=False) 42 | assert obs is None 43 | 44 | def test_infos(vec_env): 45 | vec_env.reset() 46 | infos = vec_env.infos[vec_env.agent_selection] 47 | assert infos[1]["legal_moves"] 48 | 49 | def test_seed(vec_env): 50 | vec_env.reset(seed=4) 51 | 52 | def test_some_done(vec_env): 53 | vec_env.reset() 54 | act_space = vec_env.action_space(vec_env.agent_selection) 55 | assert not any(done for dones in vec_env.dones.values() for done in dones) 56 | vec_env.step([act_space.sample() for _ in range(NUM_ENVS)]) 57 | assert any(rew != 0 for rews in vec_env.rewards.values() for rew in rews) 58 | any_done_first = any(done for dones in vec_env.dones.values() for done in dones) 59 | vec_env.step([act_space.sample() for _ in range(NUM_ENVS)]) 60 | any_done_second = any( 61 | done for dones in vec_env.dones.values() for done in dones 62 | ) 63 | assert any_done_first and any_done_second 64 | 65 | def select_action(vec_env, passes, i): 66 | my_info = vec_env.infos[vec_env.agent_selection][i] 67 | if False and not passes[i] and "legal_moves" in my_info: 68 | return random.choice(my_info["legal_moves"]) 69 | else: 70 |
act_space = vec_env.action_space(vec_env.agent_selection) 71 | return act_space.sample() 72 | 73 | for num_cpus in [0, 1]: 74 | test_vec_env(vectorize_aec_env_v0(rps_v2.env(), NUM_ENVS, num_cpus=num_cpus)) 75 | test_vec_env( 76 | vectorize_aec_env_v0( 77 | knights_archers_zombies_v10.env(), NUM_ENVS, num_cpus=num_cpus 78 | ) 79 | ) 80 | test_vec_env( 81 | vectorize_aec_env_v0( 82 | simple_world_comm_v3.env(), NUM_ENVS, num_cpus=num_cpus 83 | ) 84 | ) 85 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/nan_wrappers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import gymnasium 4 | import numpy as np 5 | 6 | from supersuit.lambda_wrappers import action_lambda_v1 7 | 8 | from .utils.base_modifier import BaseModifier 9 | from .utils.shared_wrapper_util import shared_wrapper 10 | 11 | 12 | def nan_random_v0(env): 13 | class NanRandomModifier(BaseModifier): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def reset(self, seed=None, options=None): 18 | self.np_random, seed = gymnasium.utils.seeding.np_random(seed) 19 | 20 | return super().reset(seed, options=options) 21 | 22 | def modify_action(self, action): 23 | if action is not None and np.isnan(action).any(): 24 | obs = self.cur_obs 25 | if isinstance(obs, dict) and "action_mask" in obs: # was "action mask", which never matched the obs["action_mask"] lookup on line 31 26 | warnings.warn( 27 | "[WARNING]: Step received an NaN action {}. Environment is {}. Taking a random action from 'action mask'.".format( 28 | action, self 29 | ) 30 | ) 31 | action = self.np_random.choice(np.flatnonzero(obs["action_mask"])) 32 | else: 33 | warnings.warn( 34 | "[WARNING]: Step received an NaN action {}. Environment is {}.
Taking a random action.".format( 35 | action, self 36 | ) 37 | ) 38 | action = self.action_space.sample() 39 | 40 | return action 41 | 42 | return shared_wrapper(env, NanRandomModifier) 43 | 44 | 45 | def nan_noop_v0(env, no_op_action): 46 | def on_action(action, action_space): 47 | if action is None: 48 | warnings.warn( 49 | "[WARNING]: Step received an None action {}. Environment is {}. Taking no operation action.".format( 50 | action, env 51 | ) 52 | ) 53 | return None 54 | if np.isnan(action).any(): 55 | warnings.warn( 56 | "[WARNING]: Step received an NaN action {}. Environment is {}. Taking no operation action.".format( 57 | action, env 58 | ) 59 | ) 60 | return no_op_action 61 | return action 62 | 63 | return action_lambda_v1(env, on_action, lambda act_space: act_space) 64 | 65 | 66 | def nan_zeros_v0(env): 67 | def on_action(action, action_space): 68 | if action is None: 69 | warnings.warn( 70 | "[WARNING]: Step received an None action {}. Environment is {}. Taking the all zeroes action.".format( 71 | action, env 72 | ) 73 | ) 74 | return None 75 | if np.isnan(action).any(): 76 | warnings.warn( 77 | "[WARNING]: Step received an NaN action {}. Environment is {}.
Taking the all zeroes action.".format( 78 | action, env 79 | ) 80 | ) 81 | return np.zeros_like(action) 82 | return action 83 | 84 | return action_lambda_v1(env, on_action, lambda act_space: act_space) 85 | -------------------------------------------------------------------------------- /test/test_utils/test_agent_indicator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from gymnasium.spaces import Box, Discrete 4 | 5 | from supersuit.utils.agent_indicator import ( 6 | change_obs_space, 7 | change_observation, 8 | get_indicator_map, 9 | ) 10 | 11 | 12 | obs_space_3d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 4, 3)) 13 | obs_space_2d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(4, 3)) 14 | obs_space_1d = Box(low=np.float32(0.0), high=np.float32(1.0), shape=(3,)) 15 | 16 | discrete_space = Discrete(3) 17 | 18 | NUM_INDICATORS = 11 19 | 20 | 21 | def test_obs_space(): 22 | assert change_obs_space(obs_space_1d, NUM_INDICATORS).shape == (3 + NUM_INDICATORS,) 23 | assert change_obs_space(obs_space_2d, NUM_INDICATORS).shape == ( 24 | 4, 25 | 3, 26 | 1 + NUM_INDICATORS, 27 | ) 28 | assert change_obs_space(obs_space_3d, NUM_INDICATORS).shape == ( 29 | 4, 30 | 4, 31 | 3 + NUM_INDICATORS, 32 | ) 33 | assert change_obs_space(discrete_space, NUM_INDICATORS).n == 3 * NUM_INDICATORS 34 | 35 | 36 | def test_change_observation(): 37 | assert change_observation( 38 | np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS) 39 | ).shape == (4, 4, 3 + NUM_INDICATORS) 40 | assert change_observation( 41 | np.ones((4, 3)), obs_space_2d, (4, NUM_INDICATORS) 42 | ).shape == (4, 3, 1 + NUM_INDICATORS) 43 | assert change_observation(np.ones(41), obs_space_1d, (4, NUM_INDICATORS)).shape == ( 44 | 41 + NUM_INDICATORS, 45 | ) 46 | 47 | assert ( 48 | change_observation(np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS))[ 49 | 0, 0, 0 50 | ] 51 | == 1.0 52 | ) 53 | assert ( 54 |
change_observation(np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS))[ 55 | 0, 0, 4 56 | ] 57 | == 0.0 58 | ) 59 | assert ( 60 | change_observation(np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS))[ 61 | 0, 1, 7 62 | ] 63 | == 1.0 64 | ) 65 | assert ( 66 | change_observation(np.ones((4, 4, 3)), obs_space_3d, (4, NUM_INDICATORS))[ 67 | 0, 0, 8 68 | ] 69 | == 0.0 70 | ) 71 | assert ( 72 | change_observation(np.ones((3,)), obs_space_1d, (4, NUM_INDICATORS))[2] == 1.0 73 | ) 74 | assert ( 75 | change_observation(np.ones((3,)), obs_space_1d, (4, NUM_INDICATORS))[6] == 0.0 76 | ) 77 | assert ( 78 | change_observation(np.ones((3,)), obs_space_1d, (4, NUM_INDICATORS))[7] == 1.0 79 | ) 80 | assert ( 81 | change_observation(np.ones((3,)), obs_space_1d, (4, NUM_INDICATORS))[8] == 0.0 82 | ) 83 | assert ( 84 | change_observation(2, discrete_space, (4, NUM_INDICATORS)) 85 | == 2 * NUM_INDICATORS + 4 86 | ) 87 | 88 | 89 | def test_get_indicator_map(): 90 | assert len(get_indicator_map(["bob", "joe", "fren"], False)) == 3 91 | 92 | with pytest.raises(AssertionError): 93 | get_indicator_map(["bob", "joe", "fren"], True) 94 | 95 | assert ( 96 | len( 97 | set(get_indicator_map(["bob_0", "joe_1", "fren_2", "joe_3"], True).values()) 98 | ) 99 | == 3 100 | ) 101 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Package ###################################################################### 2 | 3 | [build-system] 4 | requires = ["setuptools >= 61.0.0"] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [project] 8 | name="SuperSuit" 9 | description="Wrappers for Gymnasium and PettingZoo" 10 | readme="README.md" 11 | requires-python = ">= 3.9" 12 | authors = [{ name = "Farama Foundation", email = "contact@farama.org" }] 13 | license = { text = "MIT License" } 14 | keywords=["Reinforcement Learning", "game", "RL", "AI", "gymnasium"] 15 |
classifiers=[ 16 | "Programming Language :: Python :: 3.12", 17 | "Programming Language :: Python :: 3.11", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.9", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | ] 23 | dependencies = ["numpy>=1.19.0", "gymnasium>=1.0.0", "tinyscaler>=1.2.6"] 24 | dynamic = ["version"] 25 | 26 | [project.optional-dependencies] 27 | # Update dependencies in `all` if any are added or removed 28 | 29 | # Remove "pettingzoo[all,atari]" from dependencies for now 30 | # as there's a pygame version mismatch issue for pettingzoo version 31 | # 1.24.3. Once pettingzoo gets an updated version, add it back. 32 | testing = ["AutoROM", "pytest", "pytest-xdist", "stable-baselines3>=2.0.0", "moviepy >=1.0.0"] 33 | 34 | [project.urls] 35 | Homepage = "https://farama.org" 36 | Repository = "https://github.com/Farama-Foundation/SuperSuit" 37 | "Bug Report" = "https://github.com/Farama-Foundation/SuperSuit/issues" 38 | 39 | 40 | [tool.setuptools] 41 | include-package-data = true 42 | 43 | [tool.setuptools.packages.find] 44 | include = ["supersuit", "supersuit.*"] 45 | 46 | # Linters and Test tools ####################################################### 47 | 48 | [tool.black] 49 | safe = true 50 | 51 | [tool.isort] 52 | atomic = true 53 | profile = "black" 54 | src_paths = ["supersuit", "test"] 55 | extra_standard_library = ["typing_extensions"] 56 | indent = 4 57 | lines_after_imports = 2 58 | multi_line_output = 3 59 | 60 | [tool.pyright] 61 | include = ["supersuit/**", "tests/**"] 62 | exclude = ["**/node_modules", "**/__pycache__"] 63 | strict = [] 64 | 65 | typeCheckingMode = "basic" 66 | pythonVersion = "3.9" 67 | pythonPlatform = "All" 68 | typeshedPath = "typeshed" 69 | enableTypeIgnoreComments = true 70 | 71 | # This is required as the CI pre-commit does not download the module (i.e. 
numpy, pygame, box2d) 72 | # Therefore, we have to ignore missing imports 73 | reportMissingImports = "none" 74 | # Some modules are missing type stubs, which is an issue when running pyright locally 75 | reportMissingTypeStubs = false 76 | # For warning and error, will raise an error when 77 | reportInvalidTypeVarUse = "none" 78 | 79 | # reportUnknownMemberType = "warning" # -> raises 6035 warnings 80 | # reportUnknownParameterType = "warning" # -> raises 1327 warnings 81 | # reportUnknownVariableType = "warning" # -> raises 2585 warnings 82 | # reportUnknownArgumentType = "warning" # -> raises 2104 warnings 83 | reportGeneralTypeIssues = "none" # -> commented out raises 489 errors 84 | # reportUntypedFunctionDecorator = "none" # -> pytest.mark.parameterize issues 85 | 86 | reportPrivateUsage = "warning" 87 | reportUnboundVariable = "warning" 88 | 89 | [tool.pytest.ini_options] 90 | filterwarnings = ["ignore::DeprecationWarning:gymnasium.*:"] 91 | addopts = [ "-n=auto" ] 92 | -------------------------------------------------------------------------------- /supersuit/vector/vector_constructors.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cloudpickle 4 | import gymnasium 5 | from pettingzoo.utils.env import ParallelEnv 6 | 7 | from . import MakeCPUAsyncConstructor, MarkovVectorEnv 8 | 9 | 10 | def vec_env_args(env, num_envs): 11 | def env_fn(): 12 | env_copy = cloudpickle.loads(cloudpickle.dumps(env)) 13 | return env_copy 14 | 15 | return [env_fn] * num_envs, env.observation_space, env.action_space 16 | 17 | 18 | def warn_not_gym_env(env, fn_name): 19 | if not isinstance(env, gymnasium.Env): 20 | warnings.warn( 21 | f"{fn_name} took in an environment which does not inherit from gymnasium.Env. Note that gym_vec_env only takes in gymnasium-style environments, not pettingzoo environments." 
22 | ) 23 | 24 | 25 | def gym_vec_env_v0(env, num_envs, multiprocessing=False): 26 | warn_not_gym_env(env, "gym_vec_env") 27 | args = vec_env_args(env, num_envs)[:1] 28 | constructor = ( 29 | gymnasium.vector.AsyncVectorEnv 30 | if multiprocessing 31 | else gymnasium.vector.SyncVectorEnv 32 | ) 33 | return constructor(*args) 34 | 35 | 36 | def stable_baselines_vec_env_v0(env, num_envs, multiprocessing=False): 37 | import stable_baselines 38 | 39 | warn_not_gym_env(env, "stable_baselines_vec_env") 40 | args = vec_env_args(env, num_envs)[:1] 41 | constructor = ( 42 | stable_baselines.common.vec_env.SubprocVecEnv 43 | if multiprocessing 44 | else stable_baselines.common.vec_env.DummyVecEnv 45 | ) 46 | return constructor(*args) 47 | 48 | 49 | def stable_baselines3_vec_env_v0(env, num_envs, multiprocessing=False): 50 | import stable_baselines3 51 | 52 | warn_not_gym_env(env, "stable_baselines3_vec_env") 53 | args = vec_env_args(env, num_envs)[:1] 54 | constructor = ( 55 | stable_baselines3.common.vec_env.SubprocVecEnv 56 | if multiprocessing 57 | else stable_baselines3.common.vec_env.DummyVecEnv 58 | ) 59 | return constructor(*args) 60 | 61 | 62 | def concat_vec_envs_v1(vec_env, num_vec_envs, num_cpus=0, base_class="gymnasium"): 63 | num_cpus = min(num_cpus, num_vec_envs) 64 | vec_env = MakeCPUAsyncConstructor(num_cpus)(*vec_env_args(vec_env, num_vec_envs)) 65 | 66 | if base_class == "gymnasium": 67 | return vec_env 68 | elif base_class == "stable_baselines": 69 | from .sb_vector_wrapper import SBVecEnvWrapper 70 | 71 | return SBVecEnvWrapper(vec_env) 72 | elif base_class == "stable_baselines3": 73 | from .sb3_vector_wrapper import SB3VecEnvWrapper 74 | 75 | return SB3VecEnvWrapper(vec_env) 76 | else: 77 | raise ValueError( 78 | "supersuit_vec_env only supports 'gymnasium', 'stable_baselines', and 'stable_baselines3' for its base_class" 79 | ) 80 | 81 | 82 | def pettingzoo_env_to_vec_env_v1(parallel_env): 83 | assert isinstance( 84 | parallel_env, ParallelEnv 85 | ), 
"pettingzoo_env_to_vec_env takes in a pettingzoo ParallelEnv. Can create a parallel_env with pistonball.parallel_env() or convert it from an AEC env with `from pettingzoo.utils.conversions import aec_to_parallel; aec_to_parallel(env)``" 86 | assert hasattr( 87 | parallel_env, "possible_agents" 88 | ), "environment passed to pettingzoo_env_to_vec_env must have possible_agents attribute." 89 | return MarkovVectorEnv(parallel_env) 90 | -------------------------------------------------------------------------------- /test/test_vector/test_vector_dict.py: -------------------------------------------------------------------------------- 1 | from test.dummy_aec_env import DummyEnv 2 | 3 | import numpy as np 4 | import pytest 5 | from gymnasium.spaces import Box, Dict, Discrete, Tuple 6 | from gymnasium.vector.utils import concatenate, create_empty_array 7 | from pettingzoo.utils.conversions import aec_to_parallel 8 | 9 | from supersuit import concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1 10 | 11 | 12 | n_agents = 5 13 | 14 | 15 | def make_env(): 16 | test_env = DummyEnv( 17 | { 18 | str(i): { 19 | "feature": i * np.ones((5,), dtype=np.float32), 20 | "id": ( 21 | i * np.ones((7,), dtype=np.float32), 22 | i * np.ones((8,), dtype=np.float32), 23 | ), 24 | } 25 | for i in range(n_agents) 26 | }, 27 | { 28 | str(i): Dict( 29 | { 30 | "feature": Box(low=0, high=10, shape=(5,)), 31 | "id": Tuple( 32 | [ 33 | Box(low=0, high=10, shape=(7,)), 34 | Box(low=0, high=10, shape=(8,)), 35 | ] 36 | ), 37 | } 38 | ) 39 | for i in range(n_agents) 40 | }, 41 | { 42 | str(i): Dict( 43 | { 44 | "obs": Box(low=0, high=10, shape=(5,)), 45 | "mask": Discrete(10), 46 | } 47 | ) 48 | for i in range(n_agents) 49 | }, 50 | ) 51 | test_env.metadata["is_parallelizable"] = True 52 | return aec_to_parallel(test_env) 53 | 54 | 55 | def dict_vec_env_test(env): 56 | # tests that environment really is a vectorized 57 | # version of the environment returned by make_env 58 | 59 | obss, infos = env.reset() 
60 | for i in range(55): 61 | actions = [env.action_space.sample() for i in range(env.num_envs)] 62 | actions = concatenate( 63 | env.action_space, 64 | actions, 65 | create_empty_array(env.action_space, env.num_envs), 66 | ) 67 | obss, rews, terms, truncs, infos = env.step(actions) 68 | assert obss["feature"][1][0] == 1 69 | assert { 70 | "feature": obss["feature"][1][:], 71 | "id": [o[1] for o in obss["id"]], 72 | } in env.observation_space 73 | # no agent death, only env death 74 | if any(terms): 75 | assert all(terms) 76 | if any(truncs): 77 | assert all(truncs) 78 | 79 | 80 | def test_pettingzoo_vec_env(): 81 | env = make_env() 82 | env = pettingzoo_env_to_vec_env_v1(env) 83 | dict_vec_env_test(env) 84 | 85 | 86 | def test_single_threaded_concatenate(): 87 | env = make_env() 88 | env = pettingzoo_env_to_vec_env_v1(env) 89 | env = concat_vec_envs_v1(env, 2, num_cpus=1) 90 | dict_vec_env_test(env) 91 | 92 | 93 | @pytest.mark.skip( 94 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 95 | ) 96 | def test_multi_threaded_concatenate(): 97 | env = make_env() 98 | env = pettingzoo_env_to_vec_env_v1(env) 99 | env = concat_vec_envs_v1(env, 2, num_cpus=2) 100 | dict_vec_env_test(env) 101 | 102 | 103 | if __name__ == "__main__": 104 | test_multi_threaded_concatenate() 105 | exit(0) 106 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/frame_stack.py: -------------------------------------------------------------------------------- 1 | from gymnasium.spaces import Box, Discrete 2 | 3 | from supersuit.utils.frame_stack import stack_init, stack_obs, stack_obs_space 4 | 5 | from .utils.base_modifier import BaseModifier 6 | from .utils.shared_wrapper_util import shared_wrapper 7 | 8 | 9 | def frame_stack_v1(env, stack_size=4, stack_dim=-1): 10 | assert isinstance(stack_size, int), "stack size of frame_stack must be an int" 11 | 12 | class FrameStackModifier(BaseModifier): 
13 | def modify_obs_space(self, obs_space): 14 | if isinstance(obs_space, Box): 15 | assert ( 16 | 1 <= len(obs_space.shape) <= 3 17 | ), "frame_stack only works for 1, 2 or 3 dimensional observations" 18 | elif isinstance(obs_space, Discrete): 19 | pass 20 | else: 21 | assert ( 22 | False 23 | ), "Stacking is currently only allowed for Box and Discrete observation spaces. The given observation space is {}".format( 24 | obs_space 25 | ) 26 | 27 | self.old_obs_space = obs_space 28 | self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) 29 | return self.observation_space 30 | 31 | def reset(self, seed=None, options=None): 32 | self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) 33 | 34 | def modify_obs(self, obs): 35 | self.stack = stack_obs( 36 | self.stack, obs, self.old_obs_space, stack_size, stack_dim 37 | ) 38 | 39 | return self.stack 40 | 41 | def get_last_obs(self): 42 | return self.stack 43 | 44 | return shared_wrapper(env, FrameStackModifier) 45 | 46 | 47 | def frame_stack_v2(env, stack_size=4, stack_dim=-1): 48 | assert isinstance(stack_size, int), "stack size of frame_stack must be an int" 49 | assert f"stack_dim should be 0 or -1, not {stack_dim}" 50 | 51 | class FrameStackModifier(BaseModifier): 52 | def modify_obs_space(self, obs_space): 53 | if isinstance(obs_space, Box): 54 | assert ( 55 | 1 <= len(obs_space.shape) <= 3 56 | ), "frame_stack only works for 1, 2 or 3 dimensional observations" 57 | elif isinstance(obs_space, Discrete): 58 | pass 59 | else: 60 | assert ( 61 | False 62 | ), "Stacking is currently only allowed for Box and Discrete observation spaces. 
The given observation space is {}".format( 63 | obs_space 64 | ) 65 | 66 | self.old_obs_space = obs_space 67 | self.observation_space = stack_obs_space(obs_space, stack_size, stack_dim) 68 | return self.observation_space 69 | 70 | def reset(self, seed=None, options=None): 71 | self.stack = stack_init(self.old_obs_space, stack_size, stack_dim) 72 | self.reset_flag = True 73 | 74 | def modify_obs(self, obs): 75 | if self.reset_flag: 76 | for _ in range(stack_size): 77 | self.stack = stack_obs( 78 | self.stack, obs, self.old_obs_space, stack_size, stack_dim 79 | ) 80 | self.reset_flag = False 81 | else: 82 | self.stack = stack_obs( 83 | self.stack, obs, self.old_obs_space, stack_size, stack_dim 84 | ) 85 | 86 | return self.stack 87 | 88 | def get_last_obs(self): 89 | return self.stack 90 | 91 | return shared_wrapper(env, FrameStackModifier) 92 | -------------------------------------------------------------------------------- /supersuit/utils/action_transforms/homogenize_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium import spaces 3 | from gymnasium.spaces import Box, Discrete 4 | 5 | 6 | def check_homogenize_spaces(all_spaces): 7 | assert len(all_spaces) > 0 8 | space1 = all_spaces[0] 9 | assert all( 10 | isinstance(space, space1.__class__) for space in all_spaces 11 | ), "all spaces to homogenize must be of same general shape" 12 | 13 | if isinstance(space1, spaces.Box): 14 | for space in all_spaces: 15 | assert isinstance( 16 | space, spaces.Box 17 | ), "all spaces for homogenize must be either Box or Discrete, not a mix" 18 | assert len(space1.shape) == len( 19 | space.shape 20 | ), "all spaces to homogenize must be of same shape" 21 | assert ( 22 | space1.dtype == space.dtype 23 | ), "all spaces to homogenize must be of same dtype" 24 | elif isinstance(space1, spaces.Discrete): 25 | for space in all_spaces: 26 | assert isinstance( 27 | space, spaces.Discrete 28 | ), "all spaces for 
homogenize must be either Box or Discrete, not a mix" 29 | else: 30 | assert False, "homogenization only supports Discrete and Box spaces" 31 | 32 | 33 | def pad_to(arr, new_shape, pad_value): 34 | old_shape = arr.shape 35 | if old_shape == new_shape: 36 | return arr 37 | pad_size = [ns - os for ns, os in zip(new_shape, old_shape)] 38 | pad_tuples = [(0, ps) for ps in pad_size] 39 | return np.pad(arr, pad_tuples, constant_values=pad_value) 40 | 41 | 42 | def homogenize_spaces(all_spaces): 43 | space1 = all_spaces[0] 44 | if isinstance(space1, spaces.Box): 45 | all_dims = np.array([space.shape for space in all_spaces], dtype=np.int32) 46 | max_dims = np.max(all_dims, axis=0) 47 | new_shape = tuple(max_dims) 48 | all_lows = np.stack( 49 | [ 50 | pad_to(space.low, new_shape, np.minimum(0, np.min(space.low))) 51 | for space in all_spaces 52 | ] 53 | ) 54 | all_highs = np.stack( 55 | [ 56 | pad_to(space.high, new_shape, np.maximum(1e-5, np.max(space.high))) 57 | for space in all_spaces 58 | ] 59 | ) 60 | new_low = np.min(all_lows, axis=0) 61 | new_high = np.max(all_highs, axis=0) 62 | assert new_shape == new_low.shape 63 | return Box(low=new_low, high=new_high, dtype=space1.dtype) 64 | elif isinstance(space1, spaces.Discrete): 65 | max_n = max([space.n for space in all_spaces]) 66 | return Discrete(max_n) 67 | else: 68 | assert False 69 | 70 | 71 | def dehomogenize_actions(orig_action_space, action): 72 | if isinstance(orig_action_space, spaces.Box): 73 | # choose only the relevant action values 74 | cur_shape = action.shape 75 | new_shape = orig_action_space.shape 76 | if cur_shape == new_shape: 77 | return action 78 | else: 79 | assert len(cur_shape) == len(new_shape) 80 | slices = [slice(0, i) for i in new_shape] 81 | new_action = action[tuple(slices)] 82 | 83 | return new_action 84 | 85 | elif isinstance(orig_action_space, spaces.Discrete): 86 | # extra action values refer to action value 0 87 | n = orig_action_space.n 88 | if action is None: 89 | return None 90 | 
if action > n - 1: 91 | action = 0 92 | return action 93 | else: 94 | assert False 95 | 96 | 97 | def homogenize_observations(obs_space, obs): 98 | if isinstance(obs_space, spaces.Box): 99 | return pad_to(obs, obs_space.shape, 0) 100 | elif isinstance(obs_space, spaces.Discrete): 101 | return obs_space 102 | else: 103 | assert False 104 | -------------------------------------------------------------------------------- /supersuit/vector/concat_vec_env.py: -------------------------------------------------------------------------------- 1 | import gymnasium.vector 2 | import numpy as np 3 | from gymnasium.spaces import Discrete 4 | from gymnasium.vector.utils import concatenate, create_empty_array, iterate 5 | 6 | from .single_vec_env import SingleVecEnv 7 | 8 | 9 | def transpose(ll): 10 | return [[ll[i][j] for i in range(len(ll))] for j in range(len(ll[0]))] 11 | 12 | 13 | @iterate.register(Discrete) 14 | def iterate_discrete(space, items): 15 | try: 16 | return iter(items) 17 | except TypeError: 18 | raise TypeError(f"Unable to iterate over the following elements: {items}") 19 | 20 | 21 | class ConcatVecEnv(gymnasium.vector.VectorEnv): 22 | def __init__(self, vec_env_fns, obs_space=None, act_space=None): 23 | self.vec_envs = vec_envs = [vec_env_fn() for vec_env_fn in vec_env_fns] 24 | for i in range(len(vec_envs)): 25 | if not hasattr(vec_envs[i], "num_envs"): 26 | vec_envs[i] = SingleVecEnv([lambda: vec_envs[i]]) 27 | self.metadata = self.vec_envs[0].metadata 28 | self.render_mode = self.vec_envs[0].render_mode 29 | self.observation_space = vec_envs[0].observation_space 30 | self.action_space = vec_envs[0].action_space 31 | tot_num_envs = sum(env.num_envs for env in vec_envs) 32 | self.num_envs = tot_num_envs 33 | 34 | def reset(self, seed=None, options=None): 35 | _res_obs = [] 36 | _res_infos = [] 37 | 38 | if seed is not None: 39 | for i in range(len(self.vec_envs)): 40 | _obs, _info = self.vec_envs[i].reset(seed=seed + i, options=options) 41 | 
_res_obs.append(_obs) 42 | _res_infos.append(_info) 43 | else: 44 | for i in range(len(self.vec_envs)): 45 | _obs, _info = self.vec_envs[i].reset(options=options) 46 | _res_obs.append(_obs) 47 | _res_infos.append(_info) 48 | 49 | # flatten infos (also done in step function) 50 | flattened_infos = [info for sublist in _res_infos for info in sublist] 51 | 52 | return self.concat_obs(_res_obs), flattened_infos 53 | 54 | def concat_obs(self, observations): 55 | return concatenate( 56 | self.observation_space, 57 | [ 58 | item 59 | for obs in observations 60 | for item in iterate(self.observation_space, obs) 61 | ], 62 | create_empty_array(self.observation_space, n=self.num_envs), 63 | ) 64 | 65 | def concatenate_actions(self, actions, n_actions): 66 | return concatenate( 67 | self.action_space, 68 | actions, 69 | create_empty_array(self.action_space, n=n_actions), 70 | ) 71 | 72 | def step_async(self, actions): 73 | self._saved_actions = actions 74 | 75 | def step_wait(self): 76 | return self.step(self._saved_actions) 77 | 78 | def step(self, actions): 79 | data = [] 80 | idx = 0 81 | actions = list(iterate(self.action_space, actions)) 82 | for venv in self.vec_envs: 83 | data.append( 84 | venv.step( 85 | self.concatenate_actions( 86 | actions[idx : idx + venv.num_envs], venv.num_envs 87 | ) 88 | ) 89 | ) 90 | idx += venv.num_envs 91 | observations, rewards, terminations, truncations, infos = transpose(data) 92 | observations = self.concat_obs(observations) 93 | rewards = np.concatenate(rewards, axis=0) 94 | terminations = np.concatenate(terminations, axis=0) 95 | truncations = np.concatenate(truncations, axis=0) 96 | infos = [ 97 | info for sublist in infos for info in sublist 98 | ] # flatten infos from nested lists 99 | return observations, rewards, terminations, truncations, infos 100 | 101 | def render(self): 102 | return self.vec_envs[0].render() 103 | 104 | def close(self): 105 | for vec_env in self.vec_envs: 106 | vec_env.close() 107 | 108 | def 
env_is_wrapped(self, wrapper_class): 109 | return sum( 110 | [sub_venv.env_is_wrapped(wrapper_class) for sub_venv in self.vec_envs], [] 111 | ) 112 | -------------------------------------------------------------------------------- /supersuit/utils/frame_stack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box, Discrete 3 | 4 | 5 | def get_tile_shape(shape, stack_size, stack_dim=-1): 6 | obs_dim = len(shape) 7 | 8 | if stack_dim == -1: 9 | if obs_dim == 1: 10 | tile_shape = (stack_size,) 11 | new_shape = shape 12 | elif obs_dim == 3: 13 | tile_shape = (1, 1, stack_size) 14 | new_shape = shape 15 | # stack 2-D frames 16 | elif obs_dim == 2: 17 | tile_shape = (1, 1, stack_size) 18 | new_shape = shape + (1,) 19 | else: 20 | assert False, "Stacking is only available for 1,2 or 3 dimensional arrays" 21 | 22 | elif stack_dim == 0: 23 | if obs_dim == 1: 24 | tile_shape = (stack_size,) 25 | new_shape = shape 26 | elif obs_dim == 3: 27 | tile_shape = (stack_size, 1, 1) 28 | new_shape = shape 29 | # stack 2-D frames 30 | elif obs_dim == 2: 31 | tile_shape = (stack_size, 1, 1) 32 | new_shape = (1,) + shape 33 | else: 34 | assert False, "Stacking is only available for 1,2 or 3 dimensional arrays" 35 | 36 | return tile_shape, new_shape 37 | 38 | 39 | def stack_obs_space(obs_space, stack_size, stack_dim=-1): 40 | """ 41 | obs_space_dict: Dictionary of observations spaces of agents 42 | stack_size: Number of frames in the observation stack 43 | Returns: 44 | New obs_space_dict 45 | """ 46 | if isinstance(obs_space, Box): 47 | dtype = obs_space.dtype 48 | # stack 1-D frames and 3-D frames 49 | tile_shape, new_shape = get_tile_shape( 50 | obs_space.low.shape, stack_size, stack_dim 51 | ) 52 | 53 | low = np.tile(obs_space.low.reshape(new_shape), tile_shape) 54 | high = np.tile(obs_space.high.reshape(new_shape), tile_shape) 55 | new_obs_space = Box(low=low, high=high, dtype=dtype) 56 | return 
new_obs_space 57 | elif isinstance(obs_space, Discrete): 58 | return Discrete(obs_space.n**stack_size) 59 | else: 60 | assert ( 61 | False 62 | ), "Stacking is currently only allowed for Box and Discrete observation spaces. The given observation space is {}".format( 63 | obs_space 64 | ) 65 | 66 | 67 | def stack_init(obs_space, stack_size, stack_dim=-1): 68 | if isinstance(obs_space, Box): 69 | tile_shape, new_shape = get_tile_shape( 70 | obs_space.low.shape, stack_size, stack_dim 71 | ) 72 | return np.tile(np.zeros(new_shape, dtype=obs_space.dtype), tile_shape) 73 | else: 74 | return 0 75 | 76 | 77 | def stack_obs(frame_stack, obs, obs_space, stack_size, stack_dim=-1): 78 | """ 79 | Parameters 80 | ---------- 81 | frame_stack : if not None, it is the stack of frames 82 | obs : new observation 83 | Rearranges frame_stack. Appends the new observation at the end. 84 | Throws away the oldest observation. 85 | stack_size : needed for stacking reset observations 86 | """ 87 | if isinstance(obs_space, Box): 88 | obs_shape = obs.shape 89 | agent_fs = frame_stack 90 | 91 | if len(obs_shape) == 1: 92 | size = obs_shape[0] 93 | agent_fs[:-size] = agent_fs[size:] 94 | agent_fs[-size:] = obs 95 | 96 | elif len(obs_shape) == 2: 97 | if stack_dim == -1: 98 | agent_fs[:, :, :-1] = agent_fs[:, :, 1:] 99 | agent_fs[:, :, -1] = obs 100 | elif stack_dim == 0: 101 | agent_fs[:-1] = agent_fs[1:] 102 | agent_fs[:-1] = obs 103 | 104 | elif len(obs_shape) == 3: 105 | if stack_dim == -1: 106 | nchannels = obs_shape[-1] 107 | agent_fs[:, :, :-nchannels] = agent_fs[:, :, nchannels:] 108 | agent_fs[:, :, -nchannels:] = obs 109 | elif stack_dim == 0: 110 | nchannels = obs_shape[0] 111 | agent_fs[:-nchannels] = agent_fs[nchannels:] 112 | agent_fs[-nchannels:] = obs 113 | 114 | return agent_fs 115 | 116 | elif isinstance(obs_space, Discrete): 117 | return (frame_stack * obs_space.n + obs) % (obs_space.n**stack_size) 118 | 
-------------------------------------------------------------------------------- /supersuit/utils/agent_indicator.py: -------------------------------------------------------------------------------- 1 | import re 2 | import warnings 3 | 4 | import numpy as np 5 | from gymnasium.spaces import Box, Discrete 6 | 7 | 8 | def change_obs_space(space, num_indicators): 9 | if isinstance(space, Box): 10 | ndims = len(space.shape) 11 | if ndims == 1: 12 | pad_space = np.min(space.high) * np.ones( 13 | (num_indicators,), dtype=space.dtype 14 | ) 15 | new_low = np.concatenate([space.low, np.zeros_like(pad_space)], axis=0) 16 | new_high = np.concatenate([space.high, pad_space], axis=0) 17 | new_space = Box(low=new_low, high=new_high, dtype=space.dtype) 18 | return new_space 19 | elif ndims == 3 or ndims == 2: 20 | orig_low = space.low if ndims == 3 else np.expand_dims(space.low, 2) 21 | orig_high = space.high if ndims == 3 else np.expand_dims(space.high, 2) 22 | pad_space = np.min(space.high) * np.ones( 23 | orig_low.shape[:2] + (num_indicators,), dtype=space.dtype 24 | ) 25 | new_low = np.concatenate([orig_low, np.zeros_like(pad_space)], axis=2) 26 | new_high = np.concatenate([orig_high, pad_space], axis=2) 27 | new_space = Box(low=new_low, high=new_high, dtype=space.dtype) 28 | return new_space 29 | elif isinstance(space, Discrete): 30 | return Discrete(space.n * num_indicators) 31 | 32 | assert ( 33 | False 34 | ), "agent_indicator space must be 1d, 2d, or 3d Box or Discrete, was {}".format( 35 | space 36 | ) 37 | 38 | 39 | def get_indicator_map(agents, type_only): 40 | if type_only: 41 | assert all( 42 | re.match("[a-z]+_[0-9]+", agent) for agent in agents 43 | ), "when the `type_only` parameter is True to agent_indicator, the agent names must follow the `_` format" 44 | agent_id_map = {} 45 | type_idx_map = {} 46 | idx_num = 0 47 | for agent in agents: 48 | type = agent.split("_")[0] 49 | if type not in type_idx_map: 50 | type_idx_map[type] = idx_num 51 | idx_num += 1 52 | 
agent_id_map[agent] = type_idx_map[type] 53 | if idx_num == 1: 54 | warnings.warn( 55 | "agent_indicator wrapper is degenerate, only one agent type; doing nothing" 56 | ) 57 | return agent_id_map 58 | else: 59 | return {agent: i for i, agent in enumerate(agents)} 60 | 61 | 62 | def check_params(spaces): 63 | spaces = list(spaces) 64 | first_space = spaces[0] 65 | for space in spaces: 66 | assert repr(space) == repr( 67 | first_space 68 | ), "spaces need to be the same shape to add an indicator. Try using the `pad_observations` wrapper before agent_indicator." 69 | change_obs_space(space, 1) 70 | 71 | 72 | def change_observation(obs, space, indicator_data): 73 | indicator_num, num_indicators = indicator_data 74 | assert 0 <= indicator_num < num_indicators 75 | if isinstance(space, Box): 76 | ndims = len(space.shape) 77 | if ndims == 1: 78 | old_len = len(obs) 79 | new_obs = np.pad(obs, (0, num_indicators)) 80 | # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator 81 | if not np.isinf(space.high).any(): 82 | new_obs[indicator_num + old_len] = np.max(space.high) 83 | else: 84 | new_obs[indicator_num + old_len] = 1.0 85 | 86 | return new_obs 87 | elif ndims == 3 or ndims == 2: 88 | obs = obs if ndims == 3 else np.expand_dims(obs, 2) 89 | old_shaped3 = obs.shape[2] 90 | new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)]) 91 | # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator 92 | if not np.isinf(space.high).any(): 93 | new_obs[:, :, old_shaped3 + indicator_num] = np.max(space.high) 94 | else: 95 | new_obs[:, :, old_shaped3 + indicator_num] = 1.0 96 | return new_obs 97 | elif isinstance(space, Discrete): 98 | return obs * num_indicators + indicator_num 99 | 100 | assert ( 101 | False 102 | ), "agent_indicator space must be 1d, 2d, or 3d Box or Discrete, was {}".format( 103 | space 104 | ) 105 | -------------------------------------------------------------------------------- 
/test/test_vector/test_pettingzoo_to_vec.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import pytest 5 | from pettingzoo.butterfly import knights_archers_zombies_v10 6 | from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3 7 | 8 | from supersuit import black_death_v3, concat_vec_envs_v1, pettingzoo_env_to_vec_env_v1 9 | 10 | 11 | def test_good_env(): 12 | env = simple_spread_v3.parallel_env() 13 | max_num_agents = len(env.possible_agents) 14 | env = pettingzoo_env_to_vec_env_v1(env) 15 | assert env.num_envs == max_num_agents 16 | 17 | obss, infos = env.reset() 18 | for i in range(55): 19 | actions = [env.action_space.sample() for i in range(env.num_envs)] 20 | 21 | # Check we're not passing a thing that gets mutated 22 | keep_obs = copy.deepcopy(obss) 23 | new_obss, rews, terms, truncs, infos = env.step(actions) 24 | 25 | assert hash(str(keep_obs)) == hash(str(obss)) 26 | assert len(new_obss) == max_num_agents 27 | assert len(rews) == max_num_agents 28 | assert len(terms) == max_num_agents 29 | assert len(truncs) == max_num_agents 30 | assert len(infos) == max_num_agents 31 | # no agent death, only env death 32 | if any(terms): 33 | assert all(terms) 34 | if any(truncs): 35 | assert all(truncs) 36 | obss = new_obss 37 | 38 | 39 | def test_good_vecenv(): 40 | num_envs = 2 41 | env = simple_spread_v3.parallel_env() 42 | max_num_agents = len(env.possible_agents) * num_envs 43 | env = pettingzoo_env_to_vec_env_v1(env) 44 | env = concat_vec_envs_v1(env, num_envs) 45 | 46 | obss, infos = env.reset() 47 | for i in range(55): 48 | actions = [env.action_space.sample() for i in range(env.num_envs)] 49 | 50 | # Check we're not passing a thing that gets mutated 51 | keep_obs = copy.deepcopy(obss) 52 | new_obss, rews, terms, truncs, infos = env.step(actions) 53 | 54 | assert hash(str(keep_obs)) == hash(str(obss)) 55 | assert len(new_obss) == max_num_agents 56 | assert len(rews) == 
max_num_agents 57 | assert len(terms) == max_num_agents 58 | assert len(truncs) == max_num_agents 59 | assert len(infos) == max_num_agents 60 | # no agent death, only env death 61 | if any(terms): 62 | assert all(terms) 63 | if any(truncs): 64 | assert all(truncs) 65 | obss = new_obss 66 | 67 | 68 | def test_bad_action_spaces_env(): 69 | env = simple_world_comm_v3.parallel_env() 70 | with pytest.raises(AssertionError): 71 | env = pettingzoo_env_to_vec_env_v1(env) 72 | 73 | 74 | def test_env_black_death_assertion(): 75 | env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=2000) 76 | env = pettingzoo_env_to_vec_env_v1(env) 77 | with pytest.raises(AssertionError): 78 | for i in range(100): 79 | env.reset() 80 | for i in range(2000): 81 | actions = [env.action_space.sample() for i in range(env.num_envs)] 82 | obss, rews, terms, truncs, infos = env.step(actions) 83 | 84 | 85 | def test_env_black_death_wrapper(): 86 | env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=300) 87 | env = black_death_v3(env) 88 | env = pettingzoo_env_to_vec_env_v1(env) 89 | env.reset() 90 | for i in range(300): 91 | actions = [env.action_space.sample() for i in range(env.num_envs)] 92 | obss, rews, terms, truncs, infos = env.step(actions) 93 | 94 | 95 | def test_terminal_obs_are_returned(): 96 | """ 97 | If we reach (and pass) the end of the episode, the last observation is returned in the info dict. 
98 | """ 99 | max_cycles = 300 100 | env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=300) 101 | env = black_death_v3(env) 102 | env = pettingzoo_env_to_vec_env_v1(env) 103 | env.reset(seed=42) 104 | 105 | # run past max_cycles or until terminated - causing the env to reset and continue 106 | for _ in range(0, max_cycles + 10): 107 | actions = [env.action_space.sample() for i in range(env.num_envs)] 108 | _, _, terms, truncs, infos = env.step(actions) 109 | 110 | env_done = (np.array(terms) | np.array(truncs)).all() 111 | 112 | if env_done: 113 | # check we have infos for all agents 114 | assert len(infos) == len(env.par_env.possible_agents) 115 | # check infos contain terminal_observation 116 | for info in infos: 117 | assert "terminal_observation" in info 118 | -------------------------------------------------------------------------------- /test/gym_mock_test.py: -------------------------------------------------------------------------------- 1 | from test.dummy_gym_env import DummyEnv 2 | 3 | import numpy as np 4 | import pytest 5 | from gymnasium.spaces import Box, Discrete 6 | 7 | import supersuit 8 | from supersuit import action_lambda_v1, dtype_v0, observation_lambda_v0, reshape_v0 9 | 10 | 11 | base_obs = (np.zeros([8, 8, 3]) + np.arange(3)).astype(np.float32) 12 | base_obs_space = Box(low=np.float32(0.0), high=np.float32(10.0), shape=[8, 8, 3]) 13 | base_act_spaces = Discrete(5) 14 | 15 | 16 | def test_reshape(): 17 | base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) 18 | env = reshape_v0(base_env, (64, 3)) 19 | obs, info = env.reset() 20 | assert obs.shape == (64, 3) 21 | first_obs, _, _, _, _ = env.step(5) 22 | assert np.all(np.equal(first_obs, base_obs.reshape([64, 3]))) 23 | 24 | 25 | def new_continuous_dummy(): 26 | base_act_spaces = Box(low=np.float32(0.0), high=np.float32(10.0), shape=[3]) 27 | return DummyEnv(base_obs, base_obs_space, base_act_spaces) 28 | 29 | 30 | def new_dummy(): 31 | return 
DummyEnv(base_obs, base_obs_space, base_act_spaces) 32 | 33 | 34 | wrappers = [ 35 | supersuit.color_reduction_v0(new_dummy(), "R"), 36 | supersuit.resize_v1(dtype_v0(new_dummy(), np.uint8), x_size=5, y_size=10), 37 | supersuit.resize_v1( 38 | dtype_v0(new_dummy(), np.uint8), x_size=5, y_size=10, linear_interp=True 39 | ), 40 | supersuit.dtype_v0(new_dummy(), np.int32), 41 | supersuit.flatten_v0(new_dummy()), 42 | supersuit.reshape_v0(new_dummy(), (64, 3)), 43 | supersuit.normalize_obs_v0(new_dummy(), env_min=-1, env_max=5.0), 44 | supersuit.frame_stack_v1(new_dummy(), 8), 45 | supersuit.reward_lambda_v0(new_dummy(), lambda x: x / 10), 46 | supersuit.clip_reward_v0(new_dummy()), 47 | supersuit.clip_actions_v0(new_continuous_dummy()), 48 | supersuit.frame_skip_v0(new_dummy(), 4), 49 | supersuit.frame_skip_v0(new_dummy(), (4, 6)), 50 | supersuit.sticky_actions_v0(new_dummy(), 0.75), 51 | supersuit.delay_observations_v0(new_dummy(), 1), 52 | supersuit.max_observation_v0(new_dummy(), 3), 53 | supersuit.nan_noop_v0(new_dummy(), 0), 54 | supersuit.nan_zeros_v0(new_dummy()), 55 | supersuit.nan_random_v0(new_dummy()), 56 | supersuit.scale_actions_v0(new_continuous_dummy(), 0.5), 57 | ] 58 | 59 | 60 | @pytest.mark.parametrize("env", wrappers) 61 | def test_basic_wrappers(env): 62 | obs, info = env.reset(seed=5) 63 | act_space = env.action_space 64 | obs_space = env.observation_space 65 | assert obs_space.contains(obs) 66 | assert obs.dtype == obs_space.dtype 67 | for i in range(10): 68 | env.step(act_space.sample()) 69 | 70 | 71 | def test_lambda(): 72 | def add1(obs, obs_space): 73 | return obs + 1 74 | 75 | base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) 76 | env = observation_lambda_v0(base_env, add1) 77 | obs0, info0 = env.reset() 78 | assert int(obs0[0][0][0]) == 1 79 | env = observation_lambda_v0(env, add1) 80 | obs0, info0 = env.reset() 81 | assert int(obs0[0][0][0]) == 2 82 | 83 | def tile_obs(obs, obs_space): 84 | shape_size = len(obs.shape) 85 | 
tile_shape = [1] * shape_size 86 | tile_shape[0] *= 2 87 | return np.tile(obs, tile_shape) 88 | 89 | env = observation_lambda_v0(env, tile_obs) 90 | obs0, info0 = env.reset() 91 | assert env.observation_space.shape == (16, 8, 3) 92 | 93 | def change_shape_fn(obs_space): 94 | return Box(low=0, high=1, shape=(32, 8, 3)) 95 | 96 | env = observation_lambda_v0(env, tile_obs) 97 | obs0, info0 = env.reset() 98 | assert env.observation_space.shape == (32, 8, 3) 99 | assert obs0.shape == (32, 8, 3) 100 | 101 | 102 | def test_action_lambda(): 103 | def inc1(x, space): 104 | return x + 1 105 | 106 | def change_space_fn(space): 107 | return Discrete(space.n + 1) 108 | 109 | base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) 110 | env = action_lambda_v1(base_env, inc1, change_space_fn) 111 | assert env.action_space.n == base_env.action_space.n + 1 112 | env.reset() 113 | env.step(5) 114 | 115 | def one_hot(x, n): 116 | v = np.zeros(n) 117 | v[x] = 1 118 | return v 119 | 120 | act_spaces = Box(low=0, high=1, shape=(15,)) 121 | base_env = DummyEnv(base_obs, base_obs_space, act_spaces) 122 | env = action_lambda_v1( 123 | base_env, 124 | lambda action, act_space: one_hot(action, act_space.shape[0]), 125 | lambda act_space: Discrete(act_space.shape[0]), 126 | ) 127 | 128 | env.reset() 129 | env.step(2) 130 | 131 | 132 | def test_rew_lambda(): 133 | env = supersuit.reward_lambda_v0(new_dummy(), lambda x: x / 10) 134 | env.reset() 135 | obs, rew, termination, truncation, info = env.step(0) 136 | assert rew == 1.0 / 10 137 | -------------------------------------------------------------------------------- /test/test_vector/test_gym_vector.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import gymnasium 4 | import numpy as np 5 | import pytest 6 | from pettingzoo.mpe import simple_spread_v3 7 | 8 | from supersuit import concat_vec_envs_v1, gym_vec_env_v0, pettingzoo_env_to_vec_env_v1 9 | 10 | 11 | def 
recursive_equal(info1, info2): 12 | try: 13 | if info1 == info2: 14 | return True 15 | except ValueError: 16 | if isinstance(info1, np.ndarray) and isinstance(info2, np.ndarray): 17 | return np.all(np.equal(info1, info2)) 18 | elif isinstance(info1, dict) and isinstance(info2, dict): 19 | return all( 20 | ( 21 | set(info1.keys()) == set(info2.keys()) 22 | and recursive_equal(info1[i], info2[i]) 23 | ) 24 | for i in info1.keys() 25 | ) 26 | elif isinstance(info1, list) and isinstance(info2, list): 27 | return all(recursive_equal(i1, i2) for i1, i2 in zip(info1, info2)) 28 | return False 29 | 30 | 31 | def check_vec_env_equivalency(venv1, venv2, check_info=True): 32 | # assert venv1.observation_space == venv2.observation_space 33 | # assert venv1.action_space == venv2.action_space 34 | 35 | obs1 = venv1.reset(seed=51) 36 | obs2 = venv2.reset(seed=51) 37 | 38 | for i in range(400): 39 | action = [venv1.action_space.sample() for env in range(venv1.num_envs)] 40 | assert np.all(np.equal(obs1, obs2)) 41 | 42 | obs1, rew1, term1, trunc1, info1 = venv1.step(action) 43 | obs2, rew2, term2, trunc2, info2 = venv2.step(action) 44 | 45 | # uses close rather than equal due to inconsistency in reporting rewards as float32 or float64 46 | assert np.allclose(rew1, rew2) 47 | assert np.all(np.equal(term1, term2)) 48 | assert np.all(np.equal(trunc1, trunc2)) 49 | assert recursive_equal(info1, info2) or not check_info 50 | 51 | 52 | @pytest.mark.skip( 53 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 54 | ) 55 | def test_gym_supersuit_equivalency(): 56 | env = gymnasium.make("MountainCarContinuous-v0") 57 | num_envs = 3 58 | venv1 = concat_vec_envs_v1(env, num_envs) 59 | venv2 = gym_vec_env_v0(env, num_envs) 60 | check_vec_env_equivalency(venv1, venv2) 61 | 62 | 63 | @pytest.mark.skip( 64 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 65 | ) 66 | def test_inital_state_dissimilarity(): 67 | 
env = gymnasium.make("CartPole-v1") 68 | venv = concat_vec_envs_v1(env, 2) 69 | observations = venv.reset() 70 | assert not np.equal(observations[0], observations[1]).all() 71 | 72 | 73 | # we really don't want to have a stable baselines dependency even in tests 74 | # def test_stable_baselines_supersuit_equivalency(): 75 | # env = gymnasium.make("MountainCarContinuous-v0") 76 | # num_envs = 3 77 | # venv1 = supersuit_vec_env(env, num_envs, base_class='stable_baselines3') 78 | # venv2 = stable_baselines3_vec_env(env, num_envs) 79 | # check_vec_env_equivalency(venv1, venv2, check_info=False) # stable baselines does not implement info correctly 80 | 81 | 82 | @pytest.mark.skip( 83 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 84 | ) 85 | def test_mutliproc_single_proc_equivalency(): 86 | env = gymnasium.make("CartPole-v1") 87 | num_envs = 3 88 | # uses single threaded vector environment 89 | venv1 = concat_vec_envs_v1(env, num_envs, num_cpus=0) 90 | # uses multiprocessing vector environment 91 | venv2 = concat_vec_envs_v1(env, num_envs, num_cpus=4) 92 | check_vec_env_equivalency(venv1, venv2) 93 | 94 | 95 | @pytest.mark.skip( 96 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 97 | ) 98 | def test_multiagent_mutliproc_single_proc_equivalency(): 99 | env = simple_spread_v3.parallel_env(max_cycles=10) 100 | env = pettingzoo_env_to_vec_env_v1(env) 101 | num_envs = 3 102 | # uses single threaded vector environment 103 | venv1 = concat_vec_envs_v1(env, num_envs, num_cpus=0) 104 | # uses multiprocessing vector environment 105 | venv2 = concat_vec_envs_v1(env, num_envs, num_cpus=4) 106 | check_vec_env_equivalency(venv1, venv2) 107 | 108 | 109 | @pytest.mark.skip( 110 | reason="Wrapper depreciated, see https://github.com/Farama-Foundation/SuperSuit/issues/188" 111 | ) 112 | def test_multiproc_buffer(): 113 | num_envs = 2 114 | env = gymnasium.make("CartPole-v1") 115 | env = 
concat_vec_envs_v1(env, num_envs, num_cpus=2) 116 | 117 | obss = env.reset() 118 | for i in range(55): 119 | actions = [env.action_space.sample() for i in range(env.num_envs)] 120 | 121 | # Check we're not passing a thing that gets mutated 122 | keep_obs = copy.deepcopy(obss) 123 | new_obss, rews, terms, truncs, infos = env.step(actions) 124 | 125 | assert hash(str(keep_obs)) == hash(str(obss)) 126 | 127 | obss = new_obss 128 | -------------------------------------------------------------------------------- /test/aec_unwrapped_test.py: -------------------------------------------------------------------------------- 1 | from test.dummy_aec_env import DummyEnv 2 | 3 | import numpy as np 4 | from gymnasium import spaces 5 | 6 | from supersuit import ( 7 | agent_indicator_v0, 8 | black_death_v3, 9 | clip_actions_v0, 10 | clip_reward_v0, 11 | color_reduction_v0, 12 | delay_observations_v0, 13 | dtype_v0, 14 | flatten_v0, 15 | frame_skip_v0, 16 | frame_stack_v1, 17 | max_observation_v0, 18 | nan_random_v0, 19 | nan_zeros_v0, 20 | normalize_obs_v0, 21 | pad_action_space_v0, 22 | pad_observations_v0, 23 | scale_actions_v0, 24 | sticky_actions_v0, 25 | ) 26 | 27 | 28 | def observation_homogenizable(env, agents): 29 | homogenizable = True 30 | for agent in agents: 31 | homogenizable = homogenizable and ( 32 | isinstance(env.observation_space(agent), spaces.Box) 33 | or isinstance(env.observation_space(agent), spaces.Discrete) 34 | ) 35 | return homogenizable 36 | 37 | 38 | def action_homogenizable(env, agents): 39 | homogenizable = True 40 | for agent in agents: 41 | homogenizable = homogenizable and ( 42 | isinstance(env.action_space(agent), spaces.Box) 43 | or isinstance(env.action_space(agent), spaces.Discrete) 44 | ) 45 | return homogenizable 46 | 47 | 48 | def image_observation(env, agents): 49 | imagable = True 50 | for agent in agents: 51 | if isinstance(env.observation_space(agent), spaces.Box): 52 | imagable = imagable and (env.observation_space(agent).low.shape == 
3) 53 | imagable = imagable and (len(env.observation_space(agent).shape[2]) == 3) 54 | imagable = imagable and (env.observation_space(agent).low == 0).all() 55 | imagable = imagable and (env.observation_space(agent).high == 255).all() 56 | else: 57 | return False 58 | return imagable 59 | 60 | 61 | def box_action(env, agents): 62 | boxable = True 63 | for agent in agents: 64 | boxable = boxable and isinstance(env.action_space(agent), spaces.Box) 65 | return boxable 66 | 67 | 68 | def not_dict_observation(env, agents): 69 | is_dict = True 70 | for agent in agents: 71 | is_dict = is_dict and (isinstance(env.observation_space(agent), spaces.Dict)) 72 | return not is_dict 73 | 74 | 75 | def not_discrete_observation(env, agents): 76 | is_discrete = True 77 | for agent in agents: 78 | is_discrete = is_discrete and ( 79 | isinstance(env.observation_space(agent), spaces.Discrete) 80 | ) 81 | return not is_discrete 82 | 83 | 84 | def not_multibinary_observation(env, agents): 85 | is_discrete = True 86 | for agent in agents: 87 | is_discrete = is_discrete and ( 88 | isinstance(env.observation_space(agent), spaces.MultiBinary) 89 | ) 90 | return not is_discrete 91 | 92 | 93 | def unwrapped_check(env): 94 | env.reset() 95 | agents = env.agents 96 | 97 | if image_observation(env, agents): 98 | env = max_observation_v0(env, 2) 99 | env = color_reduction_v0(env, mode="full") 100 | env = normalize_obs_v0(env) 101 | 102 | if box_action(env, agents): 103 | env = clip_actions_v0(env) 104 | env = scale_actions_v0(env, 0.5) 105 | 106 | if observation_homogenizable(env, agents): 107 | env = pad_observations_v0(env) 108 | env = frame_stack_v1(env, 2) 109 | env = agent_indicator_v0(env) 110 | env = black_death_v3(env) 111 | 112 | if ( 113 | not_dict_observation(env, agents) 114 | and not_discrete_observation(env, agents) 115 | and not_multibinary_observation(env, agents) 116 | ): 117 | env = dtype_v0(env, np.float16) 118 | env = flatten_v0(env) 119 | env = frame_skip_v0(env, 2) 120 | 121 
| if action_homogenizable(env, agents): 122 | env = pad_action_space_v0(env) 123 | 124 | env = clip_reward_v0(env, lower_bound=-1, upper_bound=1) 125 | env = delay_observations_v0(env, 2) 126 | env = sticky_actions_v0(env, 0.5) 127 | env = nan_random_v0(env) 128 | env = nan_zeros_v0(env) 129 | 130 | assert env.unwrapped.__class__ == DummyEnv, f"Failed to unwrap {env}" 131 | 132 | 133 | def test_unwrapped(): 134 | observation_spaces = [] 135 | base = spaces.Box(low=-1.0, high=1.0, shape=[2], dtype=np.float32) 136 | observation_spaces.append({f"a{i}": base for i in range(2)}) 137 | base = spaces.Box(low=0, high=255, shape=[64, 64, 3], dtype=np.int16) 138 | observation_spaces.append({f"a{i}": base for i in range(2)}) 139 | base = spaces.Discrete(5) 140 | observation_spaces.append({f"a{i}": base for i in range(2)}) 141 | base = spaces.MultiBinary([3, 4]) 142 | observation_spaces.append({f"a{i}": base for i in range(2)}) 143 | 144 | action_spaces = [] 145 | base = spaces.Box(-3.0, 3.0, [3], np.float32) 146 | action_spaces.append({f"a{i}": base for i in range(2)}) 147 | base = spaces.Discrete(5) 148 | action_spaces.append({f"a{i}": base for i in range(2)}) 149 | base = spaces.MultiDiscrete([4, 5]) 150 | action_spaces.append({f"a{i}": base for i in range(2)}) 151 | 152 | for obs_space in observation_spaces: 153 | for act_space in action_spaces: 154 | base_obs = {a: obs_space[a].sample() for a in obs_space} 155 | env = DummyEnv(base_obs, obs_space, act_space) 156 | unwrapped_check(env) 157 | -------------------------------------------------------------------------------- /supersuit/vector/markov_vector_wrapper.py: -------------------------------------------------------------------------------- 1 | import gymnasium.vector 2 | import numpy as np 3 | from gymnasium.vector.utils import concatenate, create_empty_array, iterate 4 | 5 | 6 | class MarkovVectorEnv(gymnasium.vector.VectorEnv): 7 | def __init__(self, par_env, black_death=False): 8 | """ 9 | parameters: 10 | - 
par_env: the pettingzoo Parallel environment that will be converted to a gymnasium vector environment 11 | - black_death: whether to give zero valued observations and 0 rewards when an agent is done, allowing for environments with multiple numbers of agents. 12 | Is equivalent to adding the black death wrapper, but somewhat more efficient. 13 | 14 | The resulting object will be a valid vector environment that has a num_envs 15 | parameter equal to the max number of agents, will return an array of observations, 16 | rewards, dones, etc, and will reset environment automatically when it finishes 17 | """ 18 | self.par_env = par_env 19 | self.metadata = par_env.metadata 20 | self.render_mode = par_env.unwrapped.render_mode 21 | self.observation_space = par_env.observation_space(par_env.possible_agents[0]) 22 | self.action_space = par_env.action_space(par_env.possible_agents[0]) 23 | assert all( 24 | self.observation_space == par_env.observation_space(agent) 25 | for agent in par_env.possible_agents 26 | ), "observation spaces not consistent. Perhaps you should wrap with `supersuit.multiagent_wrappers.pad_observations_v0`?" 27 | assert all( 28 | self.action_space == par_env.action_space(agent) 29 | for agent in par_env.possible_agents 30 | ), "action spaces not consistent. Perhaps you should wrap with `supersuit.multiagent_wrappers.pad_action_space_v0`?" 31 | self.num_envs = len(par_env.possible_agents) 32 | self.black_death = black_death 33 | 34 | def concat_obs(self, obs_dict): 35 | obs_list = [] 36 | for i, agent in enumerate(self.par_env.possible_agents): 37 | if agent not in obs_dict: 38 | raise AssertionError( 39 | "environment has agent death. 
Not allowed for pettingzoo_env_to_vec_env_v1 unless black_death is True" 40 | ) 41 | obs_list.append(obs_dict[agent]) 42 | 43 | return concatenate( 44 | self.observation_space, 45 | obs_list, 46 | create_empty_array(self.observation_space, self.num_envs), 47 | ) 48 | 49 | def step_async(self, actions): 50 | self._saved_actions = actions 51 | 52 | def step_wait(self): 53 | return self.step(self._saved_actions) 54 | 55 | def reset(self, seed=None, options=None): 56 | # TODO: should this be changed to infos? 57 | _observations, infos = self.par_env.reset(seed=seed, options=options) 58 | observations = self.concat_obs(_observations) 59 | infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents] 60 | return observations, infs 61 | 62 | def step(self, actions): 63 | actions = list(iterate(self.action_space, actions)) 64 | agent_set = set(self.par_env.agents) 65 | act_dict = { 66 | agent: actions[i] 67 | for i, agent in enumerate(self.par_env.possible_agents) 68 | if agent in agent_set 69 | } 70 | observations, rewards, terms, truncs, infos = self.par_env.step(act_dict) 71 | 72 | # adds last observation to info where user can get it 73 | terminations = np.fromiter(terms.values(), dtype=bool) 74 | truncations = np.fromiter(truncs.values(), dtype=bool) 75 | env_done = (terminations | truncations).all() 76 | if env_done: 77 | for agent, obs in observations.items(): 78 | infos[agent]["terminal_observation"] = obs 79 | 80 | rews = np.array( 81 | [rewards.get(agent, 0) for agent in self.par_env.possible_agents], 82 | dtype=np.float32, 83 | ) 84 | tms = np.array( 85 | [terms.get(agent, False) for agent in self.par_env.possible_agents], 86 | dtype=np.uint8, 87 | ) 88 | tcs = np.array( 89 | [truncs.get(agent, False) for agent in self.par_env.possible_agents], 90 | dtype=np.uint8, 91 | ) 92 | infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents] 93 | 94 | if env_done: 95 | observations, reset_infs = self.reset() 96 | else: 97 | observations = 
self.concat_obs(observations) 98 | # empty infos for reset infs 99 | reset_infs = [{} for _ in range(len(self.par_env.possible_agents))] 100 | # combine standard infos and reset infos 101 | infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)] 102 | 103 | assert ( 104 | self.black_death or self.par_env.agents == self.par_env.possible_agents 105 | ), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True" 106 | return observations, rews, tms, tcs, infs 107 | 108 | def render(self): 109 | return self.par_env.render() 110 | 111 | def close(self): 112 | return self.par_env.close() 113 | 114 | def env_is_wrapped(self, wrapper_class): 115 | """ 116 | env_is_wrapped only suppors vector and gymnasium environments 117 | currently, not pettingzoo environments 118 | """ 119 | return [False] * self.num_envs 120 | -------------------------------------------------------------------------------- /supersuit/lambda_wrappers/observation_lambda.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import gymnasium 4 | import numpy as np 5 | from gymnasium.spaces import Box, Space 6 | 7 | from supersuit.utils.base_aec_wrapper import BaseWrapper 8 | from supersuit.utils.wrapper_chooser import WrapperChooser 9 | 10 | 11 | class aec_observation_lambda(BaseWrapper): 12 | def __init__(self, env, change_observation_fn, change_obs_space_fn=None): 13 | assert callable( 14 | change_observation_fn 15 | ), "change_observation_fn needs to be a function. It is {}".format( 16 | change_observation_fn 17 | ) 18 | assert change_obs_space_fn is None or callable( 19 | change_obs_space_fn 20 | ), "change_obs_space_fn needs to be a function. 
It is {}".format( 21 | change_obs_space_fn 22 | ) 23 | 24 | self.change_observation_fn = change_observation_fn 25 | self.change_obs_space_fn = change_obs_space_fn 26 | 27 | super().__init__(env) 28 | 29 | if hasattr(self, "possible_agents"): 30 | for agent in self.possible_agents: 31 | # call any validation logic in this function 32 | self.observation_space(agent) 33 | 34 | def _modify_action(self, agent, action): 35 | return action 36 | 37 | def _check_wrapper_params(self): 38 | if self.change_obs_space_fn is None and hasattr(self, "possible_agents"): 39 | for agent in self.possible_agents: 40 | assert isinstance( 41 | self.observation_space(agent), Box 42 | ), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces" 43 | 44 | @functools.lru_cache(maxsize=None) 45 | def observation_space(self, agent): 46 | if self.change_obs_space_fn is None: 47 | space = self.env.observation_space(agent) 48 | try: 49 | trans_low = self.change_observation_fn(space.low, space, agent) 50 | trans_high = self.change_observation_fn(space.high, space, agent) 51 | except TypeError: 52 | trans_low = self.change_observation_fn(space.low, space) 53 | trans_high = self.change_observation_fn(space.high, space) 54 | new_low = np.minimum(trans_low, trans_high) 55 | new_high = np.maximum(trans_low, trans_high) 56 | 57 | return Box(low=new_low, high=new_high, dtype=new_low.dtype) 58 | else: 59 | old_obs_space = self.env.observation_space(agent) 60 | try: 61 | return self.change_obs_space_fn(old_obs_space, agent) 62 | except TypeError: 63 | return self.change_obs_space_fn(old_obs_space) 64 | 65 | def _modify_observation(self, agent, observation): 66 | old_obs_space = self.env.observation_space(agent) 67 | try: 68 | return self.change_observation_fn(observation, old_obs_space, agent) 69 | except TypeError: 70 | return self.change_observation_fn(observation, old_obs_space) 71 | 72 | 73 | class gym_observation_lambda(gymnasium.Wrapper): 
74 | def __init__(self, env, change_observation_fn, change_obs_space_fn=None): 75 | assert callable( 76 | change_observation_fn 77 | ), "change_observation_fn needs to be a function. It is {}".format( 78 | change_observation_fn 79 | ) 80 | assert change_obs_space_fn is None or callable( 81 | change_obs_space_fn 82 | ), "change_obs_space_fn needs to be a function. It is {}".format( 83 | change_obs_space_fn 84 | ) 85 | self.change_observation_fn = change_observation_fn 86 | self.change_obs_space_fn = change_obs_space_fn 87 | 88 | super().__init__(env) 89 | self._check_wrapper_params() 90 | self._modify_spaces() 91 | 92 | def _check_wrapper_params(self): 93 | if self.change_obs_space_fn is None: 94 | space = self.observation_space 95 | assert isinstance( 96 | space, Box 97 | ), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces" 98 | 99 | def _modify_spaces(self): 100 | space = self.observation_space 101 | 102 | if self.change_obs_space_fn is None: 103 | new_low = self.change_observation_fn(space.low, space) 104 | new_high = self.change_observation_fn(space.high, space) 105 | new_space = Box(low=new_low, high=new_high, dtype=new_low.dtype) 106 | else: 107 | new_space = self.change_obs_space_fn(space) 108 | assert isinstance( 109 | new_space, Space 110 | ), "output of change_obs_space_fn to observation_lambda_wrapper must be a gymnasium space" 111 | self.observation_space = new_space 112 | 113 | def _modify_observation(self, observation): 114 | return self.change_observation_fn(observation, self.env.observation_space) 115 | 116 | def step(self, action): 117 | observation, rew, termination, truncation, info = self.env.step(action) 118 | observation = self._modify_observation(observation) 119 | return observation, rew, termination, truncation, info 120 | 121 | def reset(self, seed=None, options=None): 122 | observation, infos = self.env.reset(seed=seed, options=options) 123 | observation = 
self._modify_observation(observation) 124 | return observation, infos 125 | 126 | 127 | observation_lambda_v0 = WrapperChooser( 128 | aec_wrapper=aec_observation_lambda, gym_wrapper=gym_observation_lambda 129 | ) 130 | -------------------------------------------------------------------------------- /test/generated_agents_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pettingzoo.test import api_test, parallel_test 4 | from pettingzoo.test.example_envs import ( 5 | generated_agents_env_v0, 6 | generated_agents_parallel_v0, 7 | ) 8 | 9 | import supersuit 10 | from supersuit import dtype_v0 11 | 12 | 13 | wrappers = [ 14 | supersuit.dtype_v0(generated_agents_parallel_v0.env(), np.int32), 15 | supersuit.flatten_v0(generated_agents_parallel_v0.env()), 16 | supersuit.normalize_obs_v0( 17 | dtype_v0(generated_agents_parallel_v0.env(), np.float32), 18 | env_min=-1, 19 | env_max=5.0, 20 | ), 21 | supersuit.frame_stack_v1(generated_agents_parallel_v0.env(), 8), 22 | supersuit.reward_lambda_v0(generated_agents_parallel_v0.env(), lambda x: x / 10), 23 | supersuit.clip_reward_v0(generated_agents_parallel_v0.env()), 24 | supersuit.nan_noop_v0(generated_agents_parallel_v0.env(), 0), 25 | supersuit.nan_zeros_v0(generated_agents_parallel_v0.env()), 26 | supersuit.nan_random_v0(generated_agents_parallel_v0.env()), 27 | supersuit.frame_skip_v0(generated_agents_parallel_v0.env(), 4), 28 | supersuit.sticky_actions_v0(generated_agents_parallel_v0.env(), 0.75), 29 | supersuit.delay_observations_v0(generated_agents_parallel_v0.env(), 3), 30 | supersuit.max_observation_v0(generated_agents_parallel_v0.env(), 3), 31 | ] 32 | 33 | 34 | # TODO: fix errors: AssertionError: action is not in action space 35 | @pytest.mark.skip( 36 | reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)" 37 | ) 38 | @pytest.mark.parametrize("env", wrappers) 39 | 
def test_pettingzoo_aec_api_par_gen(env): 40 | api_test(env, num_cycles=50) 41 | 42 | 43 | wrappers = [ 44 | supersuit.dtype_v0(generated_agents_env_v0.env(), np.int32), 45 | supersuit.flatten_v0(generated_agents_env_v0.env()), 46 | supersuit.normalize_obs_v0( 47 | dtype_v0(generated_agents_env_v0.env(), np.float32), env_min=-1, env_max=5.0 48 | ), 49 | supersuit.frame_stack_v1(generated_agents_env_v0.env(), 8), 50 | supersuit.reward_lambda_v0(generated_agents_env_v0.env(), lambda x: x / 10), 51 | supersuit.clip_reward_v0(generated_agents_env_v0.env()), 52 | supersuit.nan_noop_v0(generated_agents_env_v0.env(), 0), 53 | supersuit.nan_zeros_v0(generated_agents_env_v0.env()), 54 | supersuit.nan_random_v0(generated_agents_env_v0.env()), 55 | supersuit.frame_skip_v0(generated_agents_env_v0.env(), 4), 56 | supersuit.sticky_actions_v0(generated_agents_env_v0.env(), 0.75), 57 | supersuit.delay_observations_v0(generated_agents_env_v0.env(), 3), 58 | supersuit.max_observation_v0(generated_agents_env_v0.env(), 3), 59 | ] 60 | 61 | 62 | # TODO fix error: ValueError: operands could not be broadcast together with shapes (42,) (10,) 63 | @pytest.mark.skip( 64 | reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)" 65 | ) 66 | @pytest.mark.parametrize("env", wrappers) 67 | def test_pettingzoo_aec_api_aec_gen(env): 68 | api_test(env, num_cycles=50) 69 | 70 | 71 | parallel_wrappers = wrappers = [ 72 | supersuit.dtype_v0(generated_agents_parallel_v0.parallel_env(), np.int32), 73 | supersuit.flatten_v0(generated_agents_parallel_v0.parallel_env()), 74 | supersuit.normalize_obs_v0( 75 | dtype_v0(generated_agents_parallel_v0.parallel_env(), np.float32), 76 | env_min=-1, 77 | env_max=5.0, 78 | ), 79 | supersuit.frame_stack_v1(generated_agents_parallel_v0.parallel_env(), 8), 80 | supersuit.reward_lambda_v0( 81 | generated_agents_parallel_v0.parallel_env(), lambda x: x / 10 82 | ), 83 | 
supersuit.clip_reward_v0(generated_agents_parallel_v0.parallel_env()), 84 | supersuit.nan_noop_v0(generated_agents_parallel_v0.parallel_env(), 0), 85 | supersuit.nan_zeros_v0(generated_agents_parallel_v0.parallel_env()), 86 | supersuit.nan_random_v0(generated_agents_parallel_v0.parallel_env()), 87 | supersuit.frame_skip_v0(generated_agents_parallel_v0.parallel_env(), 4, 0), 88 | supersuit.sticky_actions_v0(generated_agents_parallel_v0.parallel_env(), 0.75), 89 | supersuit.delay_observations_v0(generated_agents_parallel_v0.parallel_env(), 3), 90 | supersuit.max_observation_v0(generated_agents_parallel_v0.parallel_env(), 3), 91 | ] 92 | 93 | 94 | # TODO: fix normalizing obs issue: ValueError: operands could not be broadcast together with shapes (48,) (20,) 95 | @pytest.mark.skip( 96 | reason="skipped: unknown bug, most likely due to converting to AEC env (e.g., obs_lambda has no parallel wrapper)" 97 | ) 98 | @pytest.mark.parametrize("env", parallel_wrappers) 99 | def test_pettingzoo_parallel_api_gen(env): 100 | parallel_test.parallel_api_test(env, num_cycles=50) 101 | 102 | 103 | wrapper_fns = [ 104 | lambda: supersuit.pad_action_space_v0(generated_agents_parallel_v0.env()), 105 | lambda: supersuit.pad_observations_v0(generated_agents_parallel_v0.env()), 106 | lambda: supersuit.agent_indicator_v0(generated_agents_parallel_v0.env()), 107 | lambda: supersuit.vectorize_aec_env_v0(generated_agents_parallel_v0.env(), 2), 108 | lambda: supersuit.pad_action_space_v0(generated_agents_parallel_v0.parallel_env()), 109 | lambda: supersuit.pad_observations_v0(generated_agents_parallel_v0.parallel_env()), 110 | lambda: supersuit.agent_indicator_v0(generated_agents_parallel_v0.parallel_env()), 111 | lambda: supersuit.pettingzoo_env_to_vec_env_v1( 112 | generated_agents_parallel_v0.parallel_env() 113 | ), 114 | ] 115 | 116 | 117 | @pytest.mark.parametrize("wrapper_fn", wrapper_fns) 118 | def test_pettingzoo_missing_optional_error_message(wrapper_fn): 119 | with 
pytest.raises(AssertionError, match=" must have "): 120 | wrapper_fn() 121 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/utils/shared_wrapper_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import gymnasium 4 | from pettingzoo.utils import BaseParallelWrapper 5 | from pettingzoo.utils.wrappers import OrderEnforcingWrapper as BaseWrapper 6 | 7 | from supersuit.utils.wrapper_chooser import WrapperChooser 8 | 9 | 10 | class shared_wrapper_aec(BaseWrapper): 11 | def __init__(self, env, modifier_class): 12 | super().__init__(env) 13 | 14 | self.modifier_class = modifier_class 15 | self.modifiers = {} 16 | self._cur_seed = None 17 | self._cur_options = None 18 | 19 | if hasattr(self.env, "possible_agents"): 20 | self.add_modifiers(self.env.possible_agents) 21 | 22 | @functools.lru_cache(maxsize=None) 23 | def observation_space(self, agent): 24 | return self.modifiers[agent].modify_obs_space(self.env.observation_space(agent)) 25 | 26 | @functools.lru_cache(maxsize=None) 27 | def action_space(self, agent): 28 | return self.modifiers[agent].modify_action_space(self.env.action_space(agent)) 29 | 30 | def add_modifiers(self, agents_list): 31 | for agent in agents_list: 32 | if agent not in self.modifiers: 33 | # populate modifier spaces 34 | self.modifiers[agent] = self.modifier_class() 35 | self.observation_space(agent) 36 | self.action_space(agent) 37 | self.modifiers[agent].reset( 38 | seed=self._cur_seed, options=self._cur_options 39 | ) 40 | 41 | # modifiers for each agent has a different seed 42 | if self._cur_seed is not None: 43 | self._cur_seed += 1 44 | 45 | def reset(self, seed=None, options=None): 46 | self._cur_seed = seed 47 | self._cur_options = options 48 | 49 | for mod in self.modifiers.values(): 50 | mod.reset(seed=seed, options=options) 51 | super().reset(seed=seed, options=options) 52 | 53 | self.add_modifiers(self.agents) 
54 | self.modifiers[self.agent_selection].modify_obs( 55 | super().observe(self.agent_selection) 56 | ) 57 | 58 | def step(self, action): 59 | mod = self.modifiers[self.agent_selection] 60 | action = mod.modify_action(action) 61 | if ( 62 | self.terminations[self.agent_selection] 63 | or self.truncations[self.agent_selection] 64 | ): 65 | action = None 66 | super().step(action) 67 | self.add_modifiers(self.agents) 68 | self.modifiers[self.agent_selection].modify_obs( 69 | super().observe(self.agent_selection) 70 | ) 71 | 72 | def observe(self, agent): 73 | return self.modifiers[agent].get_last_obs() 74 | 75 | 76 | class shared_wrapper_parr(BaseParallelWrapper): 77 | def __init__(self, env, modifier_class): 78 | super().__init__(env) 79 | 80 | self.modifier_class = modifier_class 81 | self.modifiers = {} 82 | self._cur_seed = None 83 | self._cur_options = None 84 | 85 | if hasattr(self.env, "possible_agents"): 86 | self.add_modifiers(self.env.possible_agents) 87 | 88 | @functools.lru_cache(maxsize=None) 89 | def observation_space(self, agent): 90 | return self.modifiers[agent].modify_obs_space(self.env.observation_space(agent)) 91 | 92 | @functools.lru_cache(maxsize=None) 93 | def action_space(self, agent): 94 | return self.modifiers[agent].modify_action_space(self.env.action_space(agent)) 95 | 96 | def add_modifiers(self, agents_list): 97 | for agent in agents_list: 98 | if agent not in self.modifiers: 99 | # populate modifier spaces 100 | self.modifiers[agent] = self.modifier_class() 101 | self.observation_space(agent) 102 | self.action_space(agent) 103 | self.modifiers[agent].reset( 104 | seed=self._cur_seed, options=self._cur_options 105 | ) 106 | 107 | # modifiers for each agent has a different seed 108 | if self._cur_seed is not None: 109 | self._cur_seed += 1 110 | 111 | def reset(self, seed=None, options=None): 112 | self._cur_seed = seed 113 | self._cur_options = options 114 | 115 | observations, infos = super().reset(seed=seed, options=options) 116 | 
self.add_modifiers(self.agents) 117 | for agent, mod in self.modifiers.items(): 118 | mod.reset(seed=seed, options=options) 119 | observations = { 120 | agent: self.modifiers[agent].modify_obs(obs) 121 | for agent, obs in observations.items() 122 | } 123 | return observations, infos 124 | 125 | def step(self, actions): 126 | actions = { 127 | agent: self.modifiers[agent].modify_action(action) 128 | for agent, action in actions.items() 129 | } 130 | observations, rewards, terminations, truncations, infos = super().step(actions) 131 | self.add_modifiers(self.agents) 132 | observations = { 133 | agent: self.modifiers[agent].modify_obs(obs) 134 | for agent, obs in observations.items() 135 | } 136 | return observations, rewards, terminations, truncations, infos 137 | 138 | 139 | class shared_wrapper_gym(gymnasium.Wrapper): 140 | def __init__(self, env, modifier_class): 141 | super().__init__(env) 142 | self.modifier = modifier_class() 143 | self.observation_space = self.modifier.modify_obs_space(self.observation_space) 144 | self.action_space = self.modifier.modify_action_space(self.action_space) 145 | 146 | def reset(self, seed=None, options=None): 147 | self.modifier.reset(seed=seed, options=options) 148 | obs, info = super().reset(seed=seed, options=options) 149 | obs = self.modifier.modify_obs(obs) 150 | return obs, info 151 | 152 | def step(self, action): 153 | obs, rew, term, trunc, info = super().step(self.modifier.modify_action(action)) 154 | obs = self.modifier.modify_obs(obs) 155 | return obs, rew, term, trunc, info 156 | 157 | 158 | shared_wrapper = WrapperChooser( 159 | aec_wrapper=shared_wrapper_aec, 160 | gym_wrapper=shared_wrapper_gym, 161 | parallel_wrapper=shared_wrapper_parr, 162 | ) 163 | -------------------------------------------------------------------------------- /supersuit/aec_vector/vector_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pettingzoo.utils.agent_selector 
import agent_selector 3 | 4 | from .base_aec_vec_env import VectorAECEnv 5 | 6 | 7 | class SyncAECVectorEnv(VectorAECEnv): 8 | def __init__(self, env_constructors): 9 | assert len(env_constructors) >= 1 10 | assert callable( 11 | env_constructors[0] 12 | ), "env_constructor must be a callable object (i.e function) that create an environment" 13 | 14 | self.envs = [env_constructor() for env_constructor in env_constructors] 15 | self.num_envs = len(env_constructors) 16 | self.env = self.envs[0] 17 | self.max_num_agents = self.env.max_num_agents 18 | self.possible_agents = self.env.possible_agents 19 | self._agent_selector = agent_selector(self.possible_agents) 20 | 21 | def action_space(self, agent): 22 | return self.env.action_space(agent) 23 | 24 | def observation_space(self, agent): 25 | return self.env.observation_space(agent) 26 | 27 | def _find_active_agent(self): 28 | cur_selection = self.agent_selection 29 | while not any(cur_selection == env.agent_selection for env in self.envs): 30 | cur_selection = self._agent_selector.next() 31 | return cur_selection 32 | 33 | def _collect_dicts(self): 34 | self.rewards = { 35 | agent: np.array( 36 | [ 37 | env.rewards[agent] if agent in env.rewards else 0 38 | for env in self.envs 39 | ], 40 | dtype=np.float32, 41 | ) 42 | for agent in self.possible_agents 43 | } 44 | self._cumulative_rewards = { 45 | agent: np.array( 46 | [ 47 | env._cumulative_rewards[agent] 48 | if agent in env._cumulative_rewards 49 | else 0 50 | for env in self.envs 51 | ], 52 | dtype=np.float32, 53 | ) 54 | for agent in self.possible_agents 55 | } 56 | self.terminations = { 57 | agent: np.array( 58 | [ 59 | env.terminations[agent] if agent in env.terminations else True 60 | for env in self.envs 61 | ], 62 | dtype=np.uint8, 63 | ) 64 | for agent in self.possible_agents 65 | } 66 | self.truncations = { 67 | agent: np.array( 68 | [ 69 | env.truncations[agent] if agent in env.truncations else True 70 | for env in self.envs 71 | ], 72 | dtype=np.uint8, 
73 | ) 74 | for agent in self.possible_agents 75 | } 76 | self.infos = { 77 | agent: [env.infos[agent] if agent in env.infos else {} for env in self.envs] 78 | for agent in self.possible_agents 79 | } 80 | 81 | def reset(self, seed=None, options=None): 82 | """ 83 | returns: list of observations 84 | """ 85 | if seed is not None: 86 | for i, env in enumerate(self.envs): 87 | env.reset(seed=seed + i, options=options) 88 | else: 89 | for i, env in enumerate(self.envs): 90 | env.reset(seed=None, options=options) 91 | 92 | self.agent_selection = self._agent_selector.reset() 93 | self.agent_selection = self._find_active_agent() 94 | 95 | self._collect_dicts() 96 | self.envs_terminations = np.zeros(self.num_envs) 97 | self.envs_truncations = np.zeros(self.num_envs) 98 | 99 | def observe(self, agent): 100 | observations = [] 101 | for env in self.envs: 102 | obs = ( 103 | env.observe(agent) 104 | if (agent in env.terminations) or (agent in env.truncations) 105 | else np.zeros_like(env.observation_space(agent).low) 106 | ) 107 | observations.append(obs) 108 | return np.stack(observations) 109 | 110 | def last(self, observe=True): 111 | passes = np.array( 112 | [env.agent_selection != self.agent_selection for env in self.envs], 113 | dtype=np.uint8, 114 | ) 115 | last_agent = self.agent_selection 116 | obs = self.observe(last_agent) if observe else None 117 | return ( 118 | obs, 119 | self._cumulative_rewards[last_agent], 120 | self.terminations[last_agent], 121 | self.truncations[last_agent], 122 | self.envs_terminations, 123 | self.envs_truncations, 124 | passes, 125 | self.infos[last_agent], 126 | ) 127 | 128 | def step(self, actions, observe=True): 129 | assert len(actions) == len( 130 | self.envs 131 | ), f"{len(actions)} actions given, but there are {len(self.envs)} environments!" 
132 | old_agent = self.agent_selection 133 | 134 | envs_dones = [] 135 | for i, (act, env) in enumerate(zip(actions, self.envs)): 136 | # Prior to the truncation API update, the env was reset if env.agents was an empty list 137 | # After the truncation API update, the env needs to be reset if every agent is terminated OR truncated 138 | terminations = np.fromiter(env.terminations.values(), dtype=bool) 139 | truncations = np.fromiter(env.truncations.values(), dtype=bool) 140 | env_done = (terminations | truncations).all() 141 | envs_dones.append(env_done) 142 | 143 | if env_done: 144 | env.reset() 145 | elif env.agent_selection == old_agent: 146 | if isinstance(type(act), np.ndarray): 147 | act = np.array(act) 148 | act = ( 149 | act 150 | if not ( 151 | self.terminations[old_agent][i] 152 | or self.truncations[old_agent][i] 153 | ) 154 | else None 155 | ) # if the agent is dead, set action to None 156 | env.step(act) 157 | 158 | self.agent_selection = self._agent_selector.next() 159 | self.agent_selection = self._find_active_agent() 160 | 161 | self.envs_dones = np.array(envs_dones) 162 | self._collect_dicts() 163 | -------------------------------------------------------------------------------- /supersuit/generic_wrappers/frame_skip.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | from pettingzoo.utils.wrappers import BaseParallelWrapper, BaseWrapper 4 | 5 | from supersuit.utils.frame_skip import check_transform_frameskip 6 | from supersuit.utils.make_defaultdict import make_defaultdict 7 | from supersuit.utils.wrapper_chooser import WrapperChooser 8 | 9 | 10 | class frame_skip_gym(gymnasium.Wrapper): 11 | def __init__(self, env, num_frames): 12 | super().__init__(env) 13 | self.num_frames = check_transform_frameskip(num_frames) 14 | 15 | def step(self, action): 16 | low, high = self.num_frames 17 | num_skips = int(self.env.unwrapped.np_random.integers(low, high + 1)) 18 | 
total_reward = 0.0 19 | 20 | for x in range(num_skips): 21 | obs, rew, term, trunc, info = super().step(action) 22 | total_reward += rew 23 | if term or trunc: 24 | break 25 | 26 | return obs, total_reward, term, trunc, info 27 | 28 | 29 | class StepAltWrapper(BaseWrapper): 30 | def _modify_action(self, agent, action): 31 | return action 32 | 33 | def _modify_observation(self, agent, observation): 34 | return observation 35 | 36 | 37 | class frame_skip_aec(StepAltWrapper): 38 | def __init__(self, env, num_frames): 39 | super().__init__(env) 40 | assert isinstance( 41 | num_frames, int 42 | ), "multi-agent frame skip only takes in an integer" 43 | assert num_frames > 0 44 | check_transform_frameskip(num_frames) 45 | self.num_frames = num_frames 46 | 47 | def reset(self, seed=None, options=None): 48 | super().reset(seed=seed, options=options) 49 | self.agents = self.env.agents[:] 50 | self.terminations = make_defaultdict({agent: False for agent in self.agents}) 51 | self.truncations = make_defaultdict({agent: False for agent in self.agents}) 52 | self.rewards = make_defaultdict({agent: 0.0 for agent in self.agents}) 53 | self._cumulative_rewards = make_defaultdict( 54 | {agent: 0.0 for agent in self.agents} 55 | ) 56 | self.infos = make_defaultdict({agent: {} for agent in self.agents}) 57 | self.skip_num = make_defaultdict({agent: 0 for agent in self.agents}) 58 | self.old_actions = make_defaultdict({agent: None for agent in self.agents}) 59 | self._final_observations = make_defaultdict( 60 | {agent: None for agent in self.agents} 61 | ) 62 | 63 | def observe(self, agent): 64 | fin_observe = self._final_observations[agent] 65 | return fin_observe if fin_observe is not None else super().observe(agent) 66 | 67 | def step(self, action): 68 | self._has_updated = True 69 | 70 | # if agent is dead, perform a None step 71 | if ( 72 | self.terminations[self.agent_selection] 73 | or self.truncations[self.agent_selection] 74 | ): 75 | if self.env.agents and 
self.agent_selection == self.env.agent_selection: 76 | self.env.step(None) 77 | self._was_dead_step(action) 78 | return 79 | 80 | cur_agent = self.agent_selection 81 | self._cumulative_rewards[cur_agent] = 0 82 | self.rewards = make_defaultdict({a: 0.0 for a in self.agents}) 83 | self.skip_num[ 84 | cur_agent 85 | ] = ( 86 | self.num_frames 87 | ) # set the skip num to the param inputted in the frame_skip wrapper 88 | self.old_actions[cur_agent] = action 89 | 90 | while ( 91 | self.old_actions[self.env.agent_selection] is not None 92 | ): # this is like `for x in range(num_skips):` (L18) 93 | step_agent = self.env.agent_selection 94 | 95 | # if agent is dead, perform a None step 96 | if (step_agent in self.env.terminations) or ( 97 | step_agent in self.env.truncations 98 | ): 99 | # reward = self.env.rewards[step_agent] 100 | # done = self.env.dones[step_agent] 101 | # info = self.env.infos[step_agent] 102 | observe, reward, termination, truncation, info = self.env.last( 103 | observe=False 104 | ) 105 | action = self.old_actions[step_agent] 106 | self.env.step(action) 107 | for agent in self.env.agents: 108 | self.rewards[agent] += self.env.rewards[agent] 109 | self.infos[self.env.agent_selection] = info 110 | 111 | while self.env.agents and ( 112 | self.env.terminations[self.env.agent_selection] 113 | or self.env.truncations[self.env.agent_selection] 114 | ): 115 | dead_agent = self.env.agent_selection 116 | self.terminations[dead_agent] = self.env.terminations[dead_agent] 117 | self.truncations[dead_agent] = self.env.truncations[dead_agent] 118 | self._final_observations[dead_agent] = self.env.observe(dead_agent) 119 | self.env.step(None) 120 | step_agent = self.env.agent_selection 121 | 122 | self.skip_num[step_agent] -= 1 123 | if self.skip_num[step_agent] == 0: 124 | self.old_actions[ 125 | step_agent 126 | ] = None # if it is time to skip, set action to None, effectively breaking the while loop 127 | 128 | my_agent_set = set(self.agents) 129 | for agent in 
self.env.agents: 130 | self.terminations[agent] = self.env.terminations[agent] 131 | self.truncations[agent] = self.env.truncations[agent] 132 | self.infos[agent] = self.env.infos[agent] 133 | if agent not in my_agent_set: 134 | self.agents.append(agent) 135 | self.agent_selection = self.env.agent_selection 136 | self._accumulate_rewards() 137 | self._deads_step_first() 138 | 139 | 140 | class frame_skip_par(BaseParallelWrapper): 141 | def __init__(self, env, num_frames, default_action=None): 142 | super().__init__(env) 143 | self.num_frames = check_transform_frameskip(num_frames) 144 | self.default_action = default_action 145 | 146 | def step(self, action): 147 | action = {**action} 148 | low, high = self.num_frames 149 | num_skips = int(self.env.unwrapped.np_random.integers(low, high + 1)) 150 | orig_agents = set(action.keys()) 151 | 152 | total_reward = make_defaultdict({agent: 0.0 for agent in self.agents}) 153 | total_terminations = {} 154 | total_truncations = {} 155 | total_infos = {} 156 | total_obs = {} 157 | 158 | for x in range(num_skips): 159 | obs, rews, term, trunc, info = super().step(action) 160 | 161 | for agent, rew in rews.items(): 162 | total_reward[agent] += rew 163 | total_terminations[agent] = term[agent] 164 | total_truncations[agent] = trunc[agent] 165 | total_infos[agent] = info[agent] 166 | total_obs[agent] = obs[agent] 167 | 168 | for agent in self.env.agents: 169 | if agent not in action: 170 | assert ( 171 | self.default_action is not None 172 | ), "parallel environments that use frame_skip_v0 must provide a `default_action` argument for steps between an agent being generated and an agent taking its first step" 173 | action[agent] = self.default_action 174 | 175 | if ( 176 | np.fromiter(term.values(), dtype=bool) 177 | | np.fromiter(trunc.values(), dtype=bool) 178 | ).all(): 179 | break 180 | 181 | # delete any values created by agents which were 182 | # generated and deleted before they took any actions 183 | final_agents = 
set(self.agents) 184 | for agent in list(total_reward): 185 | if agent not in final_agents and agent not in orig_agents: 186 | del total_reward[agent] 187 | del total_terminations[agent] 188 | del total_truncations[agent] 189 | del total_infos[agent] 190 | del total_obs[agent] 191 | 192 | return ( 193 | total_obs, 194 | total_reward, 195 | total_terminations, 196 | total_truncations, 197 | total_infos, 198 | ) 199 | 200 | 201 | frame_skip_v0 = WrapperChooser( 202 | aec_wrapper=frame_skip_aec, 203 | gym_wrapper=frame_skip_gym, 204 | parallel_wrapper=frame_skip_par, 205 | ) 206 | -------------------------------------------------------------------------------- /supersuit/vector/multiproc_vec.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import multiprocessing as mp 3 | import time 4 | import traceback 5 | 6 | import gymnasium.vector 7 | import numpy as np 8 | from gymnasium.vector.utils import ( 9 | concatenate, 10 | create_empty_array, 11 | create_shared_memory, 12 | iterate, 13 | read_from_shared_memory, 14 | write_to_shared_memory, 15 | ) 16 | 17 | from .utils.shared_array import SharedArray 18 | 19 | 20 | def compress_info(infos): 21 | non_empty_infs = [(i, info) for i, info in enumerate(infos) if info] 22 | return non_empty_infs 23 | 24 | 25 | def decompress_info(num_envs, idx_starts, comp_infos): 26 | all_info = [{}] * num_envs 27 | for idx_start, comp_infos in zip(idx_starts, comp_infos): 28 | for i, info in comp_infos: 29 | all_info[idx_start + i] = info 30 | return all_info 31 | 32 | 33 | def write_observations(vec_env, env_start_idx, shared_obs, obs): 34 | obs = list(iterate(vec_env.observation_space, obs)) 35 | for i in range(vec_env.num_envs): 36 | write_to_shared_memory( 37 | vec_env.observation_space, 38 | env_start_idx + i, 39 | obs[i], 40 | shared_obs, 41 | ) 42 | 43 | 44 | def numpy_deepcopy(buf): 45 | if isinstance(buf, dict): 46 | return {name: numpy_deepcopy(v) for name, v in buf.items()} 
47 | elif isinstance(buf, tuple): 48 | return tuple(numpy_deepcopy(v) for v in buf) 49 | elif isinstance(buf, np.ndarray): 50 | return buf.copy() 51 | else: 52 | raise ValueError("numpy_deepcopy ") 53 | 54 | 55 | def async_loop( 56 | vec_env_constr, inpt_p, pipe, shared_obs, shared_rews, shared_terms, shared_truncs 57 | ): 58 | inpt_p.close() 59 | try: 60 | vec_env = vec_env_constr() 61 | 62 | pipe.send(vec_env.num_envs) 63 | env_start_idx = pipe.recv() 64 | env_end_idx = env_start_idx + vec_env.num_envs 65 | while True: 66 | instr = pipe.recv() 67 | comp_infos = [] 68 | 69 | if instr == "close": 70 | vec_env.close() 71 | 72 | elif isinstance(instr, tuple): 73 | name, data = instr 74 | 75 | if name == "reset": 76 | observations, infos = vec_env.reset(seed=data[0], options=data[1]) 77 | comp_infos = compress_info(infos) 78 | 79 | write_observations(vec_env, env_start_idx, shared_obs, observations) 80 | shared_terms.np_arr[env_start_idx:env_end_idx] = False 81 | shared_truncs.np_arr[env_start_idx:env_end_idx] = False 82 | shared_rews.np_arr[env_start_idx:env_end_idx] = 0.0 83 | 84 | elif name == "step": 85 | actions = data 86 | actions = concatenate( 87 | vec_env.action_space, 88 | actions, 89 | create_empty_array(vec_env.action_space, n=len(actions)), 90 | ) 91 | observations, rewards, terms, truncs, infos = vec_env.step(actions) 92 | write_observations(vec_env, env_start_idx, shared_obs, observations) 93 | shared_terms.np_arr[env_start_idx:env_end_idx] = terms 94 | shared_truncs.np_arr[env_start_idx:env_end_idx] = truncs 95 | shared_rews.np_arr[env_start_idx:env_end_idx] = rewards 96 | comp_infos = compress_info(infos) 97 | 98 | elif name == "env_is_wrapped": 99 | comp_infos = vec_env.env_is_wrapped(data) 100 | 101 | else: 102 | raise AssertionError("bad tuple instruction name: " + name) 103 | elif instr == "render": 104 | render_result = vec_env.render() 105 | if vec_env.render_mode == "rgb_array": 106 | comp_infos = render_result 107 | elif instr == "terminate": 
108 | return 109 | else: 110 | raise AssertionError("bad instruction: " + instr) 111 | pipe.send(comp_infos) 112 | except BaseException as e: 113 | tb = traceback.format_exc() 114 | pipe.send((e, tb)) 115 | 116 | 117 | class ProcConcatVec(gymnasium.vector.VectorEnv): 118 | def __init__( 119 | self, vec_env_constrs, observation_space, action_space, tot_num_envs, metadata 120 | ): 121 | self.observation_space = observation_space 122 | self.action_space = action_space 123 | self.num_envs = num_envs = tot_num_envs 124 | self.metadata = metadata 125 | 126 | self.shared_obs = create_shared_memory(self.observation_space, n=self.num_envs) 127 | self.shared_act = create_shared_memory(self.action_space, n=self.num_envs) 128 | self.shared_rews = SharedArray((num_envs,), dtype=np.float32) 129 | self.shared_terms = SharedArray((num_envs,), dtype=np.uint8) 130 | self.shared_truncs = SharedArray((num_envs,), dtype=np.uint8) 131 | 132 | self.observations_buffers = read_from_shared_memory( 133 | self.observation_space, self.shared_obs, n=self.num_envs 134 | ) 135 | 136 | self.graceful_shutdown_timeout = 10 137 | 138 | pipes = [] 139 | procs = [] 140 | for constr in vec_env_constrs: 141 | inpt, outpt = mp.Pipe() 142 | constr = gymnasium.vector.async_vector_env.CloudpickleWrapper(constr) 143 | proc = mp.Process( 144 | target=async_loop, 145 | args=( 146 | constr, 147 | inpt, 148 | outpt, 149 | self.shared_obs, 150 | self.shared_rews, 151 | self.shared_terms, 152 | self.shared_truncs, 153 | ), 154 | ) 155 | proc.start() 156 | outpt.close() 157 | pipes.append(inpt) 158 | procs.append(proc) 159 | 160 | self.pipes = pipes 161 | self.procs = procs 162 | 163 | num_envs = 0 164 | env_nums = self._receive_info() 165 | idx_starts = [] 166 | for pipe, cnum_env in zip(self.pipes, env_nums): 167 | cur_env_idx = num_envs 168 | num_envs += cnum_env 169 | pipe.send(cur_env_idx) 170 | idx_starts.append(cur_env_idx) 171 | idx_starts.append(num_envs) 172 | 173 | assert num_envs == tot_num_envs 174 | 
self.idx_starts = idx_starts 175 | 176 | def reset(self, seed=None, options=None): 177 | for i, pipe in enumerate(self.pipes): 178 | if seed is not None: 179 | pipe.send(("reset", (seed + i, options))) 180 | else: 181 | pipe.send(("reset", (seed, options))) 182 | 183 | info = self._receive_info() 184 | 185 | return numpy_deepcopy(self.observations_buffers), copy.deepcopy(info) 186 | 187 | def step_async(self, actions): 188 | actions = list(iterate(self.action_space, actions)) 189 | for i, pipe in enumerate(self.pipes): 190 | start, end = self.idx_starts[i : i + 2] 191 | pipe.send(("step", actions[start:end])) 192 | 193 | def _receive_info(self): 194 | all_data = [] 195 | for cin in self.pipes: 196 | data = cin.recv() 197 | if isinstance(data, tuple): 198 | e, tb = data 199 | print(tb) 200 | raise e 201 | all_data.append(data) 202 | return all_data 203 | 204 | def step_wait(self): 205 | compressed_infos = self._receive_info() 206 | infos = decompress_info(self.num_envs, self.idx_starts, compressed_infos) 207 | rewards = self.shared_rews.np_arr 208 | terms = self.shared_terms.np_arr 209 | truncs = self.shared_truncs.np_arr 210 | return ( 211 | numpy_deepcopy(self.observations_buffers), 212 | rewards.copy(), 213 | terms.astype(bool).copy(), 214 | truncs.astype(bool).copy(), 215 | copy.deepcopy(infos), 216 | ) 217 | 218 | def step(self, actions): 219 | self.step_async(actions) 220 | return self.step_wait() 221 | 222 | def __del__(self): 223 | self.close() 224 | 225 | def render(self): 226 | self.pipes[0].send("render") 227 | render_result = self.pipes[0].recv() 228 | 229 | if isinstance(render_result, tuple): 230 | e, tb = render_result 231 | print(tb) 232 | raise e 233 | 234 | return render_result 235 | 236 | def close(self): 237 | try: 238 | for pipe, proc in zip(self.pipes, self.procs): 239 | if proc.is_alive(): 240 | pipe.send(("close", None)) 241 | except OSError: 242 | pass 243 | else: 244 | deadline = ( 245 | None 246 | if self.graceful_shutdown_timeout is None 
247 | else time.monotonic() + self.graceful_shutdown_timeout 248 | ) 249 | for proc in self.procs: 250 | timeout = None if deadline is None else deadline - time.monotonic() 251 | if timeout is not None and timeout <= 0: 252 | break 253 | proc.join(timeout) 254 | for pipe, proc in zip(self.pipes, self.procs): 255 | if proc.is_alive(): 256 | proc.kill() 257 | pipe.close() 258 | 259 | def env_is_wrapped(self, wrapper_class, indices=None): 260 | for i, pipe in enumerate(self.pipes): 261 | pipe.send(("env_is_wrapped", wrapper_class)) 262 | 263 | results = self._receive_info() 264 | return sum(results, []) 265 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This repository is licensed as follows: 2 | All assets in this repository are the copyright of the Farama Foundation, except 3 | where prohibited. Contributors to the repository transfer copyright of their work 4 | to the Farama Foundation. 5 | 6 | Some code in this repository has been taken from other open source projects 7 | and was originally released under the MIT or Apache 2.0 licenses, with 8 | copyright held by another party. We've attributed these authors and they 9 | retain their copyright to the extent required by law. Everything else 10 | is owned by the Farama Foundation. The Secret Code font was also released under 11 | the MIT license by Matthew Welch (http://www.squaregear.net/fonts/). 12 | The MIT and Apache 2.0 licenses are included below. 13 | 14 | The Farama Foundation releases the elements of this repository they copyright to 15 | under the MIT license. 
16 | 17 | -------------------------------------------------------------------------------- 18 | 19 | MIT License 20 | 21 | Permission is hereby granted, free of charge, to any person obtaining a copy 22 | of this software and associated documentation files (the "Software"), to deal 23 | in the Software without restriction, including without limitation the rights 24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | copies of the Software, and to permit persons to whom the Software is 26 | furnished to do so, subject to the following conditions: 27 | 28 | The above copyright notice and this permission notice shall be included in all 29 | copies or substantial portions of the Software. 30 | 31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | SOFTWARE. 38 | 39 | -------------------------------------------------------------------------------- 40 | 41 | Apache License 42 | Version 2.0, January 2004 43 | http://www.apache.org/licenses/ 44 | 45 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 46 | 47 | 1. Definitions. 48 | 49 | "License" shall mean the terms and conditions for use, reproduction, 50 | and distribution as defined by Sections 1 through 9 of this document. 51 | 52 | "Licensor" shall mean the copyright owner or entity authorized by 53 | the copyright owner that is granting the License. 54 | 55 | "Legal Entity" shall mean the union of the acting entity and all 56 | other entities that control, are controlled by, or are under common 57 | control with that entity. 
For the purposes of this definition, 58 | "control" means (i) the power, direct or indirect, to cause the 59 | direction or management of such entity, whether by contract or 60 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 61 | outstanding shares, or (iii) beneficial ownership of such entity. 62 | 63 | "You" (or "Your") shall mean an individual or Legal Entity 64 | exercising permissions granted by this License. 65 | 66 | "Source" form shall mean the preferred form for making modifications, 67 | including but not limited to software source code, documentation 68 | source, and configuration files. 69 | 70 | "Object" form shall mean any form resulting from mechanical 71 | transformation or translation of a Source form, including but 72 | not limited to compiled object code, generated documentation, 73 | and conversions to other media types. 74 | 75 | "Work" shall mean the work of authorship, whether in Source or 76 | Object form, made available under the License, as indicated by a 77 | copyright notice that is included in or attached to the work 78 | (an example is provided in the Appendix below). 79 | 80 | "Derivative Works" shall mean any work, whether in Source or Object 81 | form, that is based on (or derived from) the Work and for which the 82 | editorial revisions, annotations, elaborations, or other modifications 83 | represent, as a whole, an original work of authorship. For the purposes 84 | of this License, Derivative Works shall not include works that remain 85 | separable from, or merely link (or bind by name) to the interfaces of, 86 | the Work and Derivative Works thereof. 
87 | 88 | "Contribution" shall mean any work of authorship, including 89 | the original version of the Work and any modifications or additions 90 | to that Work or Derivative Works thereof, that is intentionally 91 | submitted to Licensor for inclusion in the Work by the copyright owner 92 | or by an individual or Legal Entity authorized to submit on behalf of 93 | the copyright owner. For the purposes of this definition, "submitted" 94 | means any form of electronic, verbal, or written communication sent 95 | to the Licensor or its representatives, including but not limited to 96 | communication on electronic mailing lists, source code control systems, 97 | and issue tracking systems that are managed by, or on behalf of, the 98 | Licensor for the purpose of discussing and improving the Work, but 99 | excluding communication that is conspicuously marked or otherwise 100 | designated in writing by the copyright owner as "Not a Contribution." 101 | 102 | "Contributor" shall mean Licensor and any individual or Legal Entity 103 | on behalf of whom a Contribution has been received by Licensor and 104 | subsequently incorporated within the Work. 105 | 106 | 2. Grant of Copyright License. Subject to the terms and conditions of 107 | this License, each Contributor hereby grants to You a perpetual, 108 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 109 | copyright license to reproduce, prepare Derivative Works of, 110 | publicly display, publicly perform, sublicense, and distribute the 111 | Work and such Derivative Works in Source or Object form. 112 | 113 | 3. Grant of Patent License. 
Subject to the terms and conditions of 114 | this License, each Contributor hereby grants to You a perpetual, 115 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 116 | (except as stated in this section) patent license to make, have made, 117 | use, offer to sell, sell, import, and otherwise transfer the Work, 118 | where such license applies only to those patent claims licensable 119 | by such Contributor that are necessarily infringed by their 120 | Contribution(s) alone or by combination of their Contribution(s) 121 | with the Work to which such Contribution(s) was submitted. If You 122 | institute patent litigation against any entity (including a 123 | cross-claim or counterclaim in a lawsuit) alleging that the Work 124 | or a Contribution incorporated within the Work constitutes direct 125 | or contributory patent infringement, then any patent licenses 126 | granted to You under this License for that Work shall terminate 127 | as of the date such litigation is filed. 128 | 129 | 4. Redistribution. 
You may reproduce and distribute copies of the 130 | Work or Derivative Works thereof in any medium, with or without 131 | modifications, and in Source or Object form, provided that You 132 | meet the following conditions: 133 | 134 | (a) You must give any other recipients of the Work or 135 | Derivative Works a copy of this License; and 136 | 137 | (b) You must cause any modified files to carry prominent notices 138 | stating that You changed the files; and 139 | 140 | (c) You must retain, in the Source form of any Derivative Works 141 | that You distribute, all copyright, patent, trademark, and 142 | attribution notices from the Source form of the Work, 143 | excluding those notices that do not pertain to any part of 144 | the Derivative Works; and 145 | 146 | (d) If the Work includes a "NOTICE" text file as part of its 147 | distribution, then any Derivative Works that You distribute must 148 | include a readable copy of the attribution notices contained 149 | within such NOTICE file, excluding those notices that do not 150 | pertain to any part of the Derivative Works, in at least one 151 | of the following places: within a NOTICE text file distributed 152 | as part of the Derivative Works; within the Source form or 153 | documentation, if provided along with the Derivative Works; or, 154 | within a display generated by the Derivative Works, if and 155 | wherever such third-party notices normally appear. The contents 156 | of the NOTICE file are for informational purposes only and 157 | do not modify the License. You may add Your own attribution 158 | notices within Derivative Works that You distribute, alongside 159 | or as an addendum to the NOTICE text from the Work, provided 160 | that such additional attribution notices cannot be construed 161 | as modifying the License. 
162 | 163 | You may add Your own copyright statement to Your modifications and 164 | may provide additional or different license terms and conditions 165 | for use, reproduction, or distribution of Your modifications, or 166 | for any such Derivative Works as a whole, provided Your use, 167 | reproduction, and distribution of the Work otherwise complies with 168 | the conditions stated in this License. 169 | 170 | 5. Submission of Contributions. Unless You explicitly state otherwise, 171 | any Contribution intentionally submitted for inclusion in the Work 172 | by You to the Licensor shall be under the terms and conditions of 173 | this License, without any additional terms or conditions. 174 | Notwithstanding the above, nothing herein shall supersede or modify 175 | the terms of any separate license agreement you may have executed 176 | with Licensor regarding such Contributions. 177 | 178 | 6. Trademarks. This License does not grant permission to use the trade 179 | names, trademarks, service marks, or product names of the Licensor, 180 | except as required for reasonable and customary use in describing the 181 | origin of the Work and reproducing the content of the NOTICE file. 182 | 183 | 7. Disclaimer of Warranty. Unless required by applicable law or 184 | agreed to in writing, Licensor provides the Work (and each 185 | Contributor provides its Contributions) on an "AS IS" BASIS, 186 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 187 | implied, including, without limitation, any warranties or conditions 188 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 189 | PARTICULAR PURPOSE. You are solely responsible for determining the 190 | appropriateness of using or redistributing the Work and assume any 191 | risks associated with Your exercise of permissions under this License. 192 | 193 | 8. Limitation of Liability. 
In no event and under no legal theory, 194 | whether in tort (including negligence), contract, or otherwise, 195 | unless required by applicable law (such as deliberate and grossly 196 | negligent acts) or agreed to in writing, shall any Contributor be 197 | liable to You for damages, including any direct, indirect, special, 198 | incidental, or consequential damages of any character arising as a 199 | result of this License or out of the use or inability to use the 200 | Work (including but not limited to damages for loss of goodwill, 201 | work stoppage, computer failure or malfunction, or any and all 202 | other commercial damages or losses), even if such Contributor 203 | has been advised of the possibility of such damages. 204 | 205 | 9. Accepting Warranty or Additional Liability. While redistributing 206 | the Work or Derivative Works thereof, You may choose to offer, 207 | and charge a fee for, acceptance of support, warranty, indemnity, 208 | or other liability obligations and/or rights consistent with this 209 | License. However, in accepting such obligations, You may act only 210 | on Your own behalf and on Your sole responsibility, not on behalf 211 | of any other Contributor, and only if You agree to indemnify, 212 | defend, and hold each Contributor harmless for any liability 213 | incurred by, or claims asserted against, such Contributor by reason 214 | of your accepting any such warranty or additional liability. 
215 | 216 | END OF TERMS AND CONDITIONS 217 | -------------------------------------------------------------------------------- /test/pettingzoo_api_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pettingzoo.butterfly import ( 4 | cooperative_pong_v5, 5 | knights_archers_zombies_v10, 6 | pistonball_v6, 7 | ) 8 | from pettingzoo.classic import connect_four_v3 9 | from pettingzoo.mpe import simple_push_v3, simple_spread_v3, simple_world_comm_v3 10 | from pettingzoo.sisl import pursuit_v4 11 | from pettingzoo.test import api_test, parallel_api_test, seed_test 12 | from pettingzoo.utils.all_modules import ( 13 | atari_environments, 14 | butterfly_environments, 15 | classic_environments, 16 | mpe_environments, 17 | sisl_environments, 18 | ) 19 | 20 | import supersuit 21 | from supersuit import ( 22 | dtype_v0, 23 | frame_skip_v0, 24 | frame_stack_v2, 25 | pad_action_space_v0, 26 | sticky_actions_v0, 27 | ) 28 | from supersuit.utils.convert_box import convert_box 29 | 30 | 31 | atari = list(atari_environments.values()) 32 | butterfly = list(butterfly_environments.values()) 33 | classic = list(classic_environments.values()) 34 | mpe = list(mpe_environments.values()) 35 | sisl = list(sisl_environments.values()) 36 | all = atari + butterfly + classic + mpe + sisl 37 | 38 | BUTTERFLY_MPE_CLASSIC = [ 39 | knights_archers_zombies_v10, 40 | simple_push_v3, 41 | connect_four_v3, 42 | simple_spread_v3, 43 | ] 44 | BUTTERFLY_MPE = [knights_archers_zombies_v10, simple_push_v3, simple_spread_v3] 45 | 46 | 47 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 48 | def test_frame_stack(env_fn): 49 | _env = env_fn.env() 50 | wrapped_env = frame_stack_v2(_env) 51 | api_test(wrapped_env) 52 | 53 | 54 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 55 | def test_frame_stack_parallel(env_fn): 56 | _env = env_fn.parallel_env() 57 | wrapped_env = 
frame_stack_v2(_env) 58 | parallel_api_test(wrapped_env) 59 | 60 | 61 | @pytest.mark.parametrize("env_fn", [simple_push_v3]) 62 | def test_frame_skip(env_fn): 63 | env = env_fn.raw_env(max_cycles=100) 64 | env = frame_skip_v0(env, 3) 65 | env.reset() 66 | x = 0 67 | for _ in env.agent_iter(25): 68 | assert env.unwrapped.steps == (x // 2) * 3 69 | action = env.action_space(env.agent_selection).sample() 70 | env.step(action) 71 | x += 1 72 | 73 | 74 | @pytest.mark.parametrize("env_fn", [simple_push_v3]) 75 | def test_frame_skip_parallel(env_fn): 76 | env = env_fn.parallel_env(max_cycles=100) 77 | env = frame_skip_v0(env, 3) 78 | env.reset() 79 | x = 0 80 | while env.agents: 81 | assert env.unwrapped.steps == (x // 2) * 3 82 | actions = {agent: env.action_space(agent).sample() for agent in env.agents} 83 | env.step(actions) 84 | x += env.num_agents 85 | 86 | 87 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 88 | def test_pad_action_space(env_fn): 89 | _env = env_fn.env() 90 | wrapped_env = pad_action_space_v0(_env) 91 | api_test(wrapped_env) 92 | seed_test(lambda: sticky_actions_v0(simple_world_comm_v3.env(), 0.5), 100) 93 | 94 | 95 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 96 | def test_pad_action_space_parallel(env_fn): 97 | _env = env_fn.parallel_env() 98 | wrapped_env = pad_action_space_v0(_env) 99 | parallel_api_test(wrapped_env) 100 | 101 | 102 | @pytest.mark.parametrize( 103 | "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4] 104 | ) 105 | def test_color_reduction(env_fn): 106 | env = supersuit.color_reduction_v0(env_fn.env(), "R") 107 | api_test(env) 108 | 109 | 110 | @pytest.mark.parametrize( 111 | "env_fn", atari + [pistonball_v6, cooperative_pong_v5, pursuit_v4] 112 | ) 113 | def test_color_reduction_parallel(env_fn): 114 | env = supersuit.color_reduction_v0(env_fn.parallel_env(), "R") 115 | parallel_api_test(env) 116 | 117 | 118 | @pytest.mark.parametrize("env_fn", 
[knights_archers_zombies_v10]) 119 | @pytest.mark.parametrize( 120 | "wrapper_kwargs", 121 | [dict(x_size=5, y_size=10), dict(x_size=5, y_size=10, linear_interp=True)], 122 | ) 123 | def test_resize_dtype(env_fn, wrapper_kwargs): 124 | env = supersuit.resize_v1( 125 | dtype_v0(env_fn.env(vector_state=False), np.uint8), **wrapper_kwargs 126 | ) 127 | api_test(env) 128 | 129 | 130 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 131 | @pytest.mark.parametrize( 132 | "wrapper_kwargs", 133 | [dict(x_size=5, y_size=10), dict(x_size=5, y_size=10, linear_interp=True)], 134 | ) 135 | def test_resize_dtype_parallel(env_fn, wrapper_kwargs): 136 | env = supersuit.resize_v1( 137 | dtype_v0(env_fn.parallel_env(vector_state=False), np.uint8), **wrapper_kwargs 138 | ) 139 | parallel_api_test(env) 140 | 141 | 142 | @pytest.mark.parametrize( 143 | "env_fn", 144 | atari 145 | + butterfly 146 | + [v for k, v in sisl_environments.items() if k != "sisl/multiwalker_v9"], 147 | ) 148 | def test_dtype(env_fn): 149 | env = supersuit.dtype_v0(env_fn.env(), np.int32) 150 | api_test(env) 151 | 152 | 153 | @pytest.mark.parametrize("env_fn", atari + butterfly + sisl) 154 | def test_dtype_parallel(env_fn): 155 | env = supersuit.dtype_v0(env_fn.parallel_env(), np.int32) 156 | parallel_api_test(env) 157 | 158 | 159 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 160 | def test_flatten(env_fn): 161 | env = supersuit.flatten_v0(knights_archers_zombies_v10.env()) 162 | api_test(env) 163 | 164 | 165 | # Classic environments don't have parallel envs so this doesn't apply 166 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 167 | def test_flatten_parallel(env_fn): 168 | env = supersuit.flatten_v0(env_fn.parallel_env()) 169 | parallel_api_test(env) 170 | 171 | 172 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 173 | def test_reshape(env_fn): 174 | env = supersuit.reshape_v0(env_fn.env(vector_state=False), (512 * 
512, 3)) 175 | api_test(env) 176 | 177 | 178 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 179 | def test_reshape_parallel(env_fn): 180 | env = supersuit.reshape_v0(env_fn.parallel_env(vector_state=False), (512 * 512, 3)) 181 | parallel_api_test(env) 182 | 183 | 184 | # MPE environment has infinite bounds for observation space (only environments with finite bounds can be passed to normalize_obs) 185 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 186 | def test_normalize_obs(env_fn): 187 | env = supersuit.normalize_obs_v0( 188 | dtype_v0(env_fn.env(), np.float32), env_min=-1, env_max=5.0 189 | ) 190 | api_test(env) 191 | 192 | 193 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 194 | def test_normalize_obs_parallel(env_fn): 195 | env = supersuit.normalize_obs_v0( 196 | dtype_v0(env_fn.parallel_env(), np.float32), env_min=-1, env_max=5.0 197 | ) 198 | parallel_api_test(env) 199 | 200 | 201 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 202 | def test_pad_observations(env_fn): 203 | env = supersuit.pad_observations_v0(env_fn.env()) 204 | api_test(env) 205 | 206 | 207 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 208 | def test_pad_observations_parallel(env_fn): 209 | env = supersuit.pad_observations_v0(env_fn.parallel_env()) 210 | parallel_api_test(env) 211 | 212 | 213 | @pytest.mark.skip( 214 | reason="Black death wrapper is only designed for parallel envs, AEC envs should simply skip the agent by setting env.agent_selection manually" 215 | ) 216 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 217 | def test_black_death(env_fn): 218 | env = supersuit.black_death_v3(env_fn.env()) 219 | api_test(env) 220 | 221 | 222 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 223 | def test_black_death_parallel(env_fn): 224 | env = supersuit.black_death_v3(env_fn.parallel_env()) 225 | parallel_api_test(env) 226 | 227 | 228 | 
@pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 229 | @pytest.mark.parametrize("env_kwargs", [dict(type_only=True), dict(type_only=False)]) 230 | def test_agent_indicator(env_fn, env_kwargs): 231 | env = supersuit.agent_indicator_v0(env_fn.env(), **env_kwargs) 232 | api_test(env) 233 | 234 | 235 | @pytest.mark.parametrize("env_fn", [knights_archers_zombies_v10]) 236 | @pytest.mark.parametrize("env_kwargs", [dict(type_only=True), dict(type_only=False)]) 237 | def test_agent_indicator_parallel(env_fn, env_kwargs): 238 | env = supersuit.agent_indicator_v0(env_fn.parallel_env(), **env_kwargs) 239 | parallel_api_test(env) 240 | 241 | 242 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 243 | def test_reward_lambda(env_fn): 244 | env = supersuit.reward_lambda_v0(env_fn.env(), lambda x: x / 10) 245 | api_test(env) 246 | 247 | 248 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 249 | def test_reward_lambda_parallel(env_fn): 250 | env = supersuit.reward_lambda_v0(env_fn.parallel_env(), lambda x: x / 10) 251 | parallel_api_test(env) 252 | 253 | 254 | @pytest.mark.parametrize( 255 | "env_fn", 256 | [v for k, v in butterfly_environments.items() if k != "butterfly/pistonball_v6"] 257 | + mpe 258 | + sisl, 259 | ) 260 | def test_observation_lambda(env_fn): 261 | env = supersuit.observation_lambda_v0(env_fn.env(), lambda obs, obs_space: obs - 1) 262 | api_test(env) 263 | 264 | 265 | # Example using observation lambda with an action masked environment (flattening the obs space while keeping action mask) 266 | @pytest.mark.parametrize("env_fn", [connect_four_v3]) 267 | def test_observation_lambda_action_mask(env_fn): 268 | env = env_fn.env() 269 | env.reset() 270 | obs = env.observe(env.possible_agents[0]) 271 | 272 | # Example: reshape the observation to flatten the first two dimensions: (6, 7, 2) -> (42, 2) 273 | newshape = obs["observation"].reshape((-1, 2)).shape 274 | 275 | def change_obs_space_fn(obs_space): 
276 | obs_space["observation"] = convert_box( 277 | lambda obs: obs.reshape(newshape), old_box=obs_space["observation"] 278 | ) 279 | return obs_space 280 | 281 | def change_observation_fn(observation, old_obs_space): 282 | # Reshape observation 283 | observation["observation"] = observation["observation"].reshape(newshape) 284 | # Invert the action mask (make illegal actions legal, and vice versa) 285 | observation["action_mask"] = 1 - observation["action_mask"] 286 | return observation 287 | 288 | env = supersuit.observation_lambda_v0( 289 | env_fn.env(), 290 | change_obs_space_fn=change_obs_space_fn, 291 | change_observation_fn=change_observation_fn, 292 | ) 293 | 294 | env.reset() 295 | obs = env.observe(env.possible_agents[0]) 296 | assert obs["observation"].shape == ( 297 | 42, 298 | 2, 299 | ), "New observation should be shape (42, 2)" 300 | assert np.array_equal( 301 | obs["action_mask"], np.zeros(7) 302 | ), "New action mask should be all zeros" 303 | assert env.observation_space(env.possible_agents[0])["observation"].shape == ( 304 | 42, 305 | 2, 306 | ), "Observation space should be (42, 2)" 307 | 308 | api_test(env) 309 | 310 | 311 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 312 | def test_observation_lambda_parallel(env_fn): 313 | pass 314 | 315 | 316 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 317 | def test_clip_reward(env_fn): 318 | env = supersuit.clip_reward_v0(env_fn.env()) 319 | api_test(env) 320 | 321 | 322 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 323 | def test_clip_reward_parallel(env_fn): 324 | env = supersuit.clip_reward_v0(env_fn.parallel_env()) 325 | parallel_api_test(env) 326 | 327 | 328 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 329 | def test_nan_noop(env_fn): 330 | env = supersuit.nan_noop_v0(env_fn.env(), 0) 331 | api_test(env) 332 | 333 | 334 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 
335 | def test_nan_noop_parallel(env_fn): 336 | env = supersuit.nan_noop_v0(env_fn.parallel_env(), 0) 337 | parallel_api_test(env) 338 | 339 | 340 | @pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 341 | def test_nan_zeros(env_fn): 342 | env = supersuit.nan_zeros_v0(env_fn.env()) 343 | api_test(env) 344 | 345 | 346 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 347 | def test_nan_zeros_parallel(env_fn): 348 | env = supersuit.nan_zeros_v0(env_fn.parallel_env()) 349 | parallel_api_test(env) 350 | 351 | 352 | # Note: hanabi v5 fails here 353 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 354 | def test_nan_random(env_fn): 355 | env = supersuit.nan_random_v0(env_fn.env()) 356 | api_test(env) 357 | 358 | 359 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 360 | def test_nan_random_parallel(env_fn): 361 | env = supersuit.nan_random_v0(env_fn.parallel_env()) 362 | parallel_api_test(env) 363 | 364 | 365 | # Note: hanabi v5 fails here 366 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 367 | def test_sticky_actions(env_fn): 368 | env = supersuit.sticky_actions_v0(env_fn.env(), 0.75) 369 | api_test(env) 370 | 371 | 372 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 373 | def test_sticky_actions_parallel(env_fn): 374 | env = supersuit.sticky_actions_v0(env_fn.parallel_env(), 0.75) 375 | parallel_api_test(env) 376 | 377 | 378 | # Note: hanabi_v5 and texas_holdem_v4 fail here 379 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 380 | def test_delay_observations(env_fn): 381 | env = supersuit.delay_observations_v0(env_fn.env(), 3) 382 | api_test(env) 383 | 384 | 385 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 386 | def test_delay_observations_parallel(env_fn): 387 | env = supersuit.delay_observations_v0(env_fn.parallel_env(), 3) 388 | parallel_api_test(env) 389 | 390 | 391 | 
@pytest.mark.parametrize("env_fn", atari + butterfly + classic + mpe + sisl) 392 | def test_max_observation(env_fn): 393 | env = supersuit.max_observation_v0(knights_archers_zombies_v10.env(), 3) 394 | api_test(env) 395 | 396 | 397 | @pytest.mark.parametrize("env_fn", atari + butterfly + mpe + sisl) 398 | def test_max_observation_parallel(env_fn): 399 | env = supersuit.max_observation_v0(knights_archers_zombies_v10.parallel_env(), 3) 400 | parallel_api_test(env) 401 | --------------------------------------------------------------------------------