├── config ├── offline_baselines_jax ├── py.typed ├── version.txt ├── common │ ├── __init__.py │ ├── vec_env │ │ ├── vec_extract_dict_obs.py │ │ ├── vec_frame_stack.py │ │ ├── __init__.py │ │ ├── util.py │ │ ├── vec_check_nan.py │ │ ├── vec_monitor.py │ │ ├── vec_video_recorder.py │ │ ├── vec_transpose.py │ │ ├── dummy_vec_env.py │ │ ├── subproc_vec_env.py │ │ ├── stacked_observations.py │ │ ├── vec_normalize.py │ │ └── base_vec_env.py │ ├── type_aliases.py │ ├── policies.py │ ├── preprocessing.py │ ├── jax_layers.py │ └── utils.py ├── cql │ ├── __init__.py │ ├── core.py │ └── cql.py ├── sac │ ├── __init__.py │ ├── core.py │ ├── sac.py │ └── policies.py ├── td3 │ ├── __init__.py │ ├── core.py │ ├── td3.py │ └── policies.py └── __init__.py ├── .coveragerc ├── .gitignore ├── LICENSE ├── NOTICE ├── setup.py └── README.md /config: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offline_baselines_jax/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offline_baselines_jax/version.txt: -------------------------------------------------------------------------------- 1 | 1.0.0 -------------------------------------------------------------------------------- /offline_baselines_jax/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /offline_baselines_jax/cql/__init__.py: -------------------------------------------------------------------------------- 1 | from offline_baselines_jax.cql.cql import CQL 2 | -------------------------------------------------------------------------------- /offline_baselines_jax/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from offline_baselines_jax.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from .sac import SAC 3 | -------------------------------------------------------------------------------- /offline_baselines_jax/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from offline_baselines_jax.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy 2 | from offline_baselines_jax.td3.td3 import TD3 3 | -------------------------------------------------------------------------------- /offline_baselines_jax/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from offline_baselines_jax.sac import SAC 4 | from offline_baselines_jax.cql import CQL 5 | from offline_baselines_jax.td3 import TD3 6 | 7 | # Read version from file 8 | version_file = os.path.join(os.path.dirname(__file__), "version.txt") 9 | with open(version_file, "r") as file_handler: 10 | __version__ = file_handler.read().strip() -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = False 3 | omit = 4 | tests/* 5 | setup.py 6 | # Require graphical interface 7 | stable_baselines3/common/results_plotter.py 8 | # Require ffmpeg 9 | stable_baselines3/common/vec_env/vec_video_recorder.py 10 | 11 | [report] 12 | exclude_lines = 13 | pragma: no cover 14 | raise NotImplementedError() 15 | if typing.TYPE_CHECKING: 16 | 
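As the directory tree and the package `__init__.py` above show, the library's top-level API is just the three algorithm classes plus a version string read from `version.txt`. A minimal sketch of that top-level usage (illustrative only; it assumes the package and its dependencies are installed as described in the README below):

```python
import offline_baselines_jax
from offline_baselines_jax import SAC, CQL, TD3  # re-exported by offline_baselines_jax/__init__.py

# __version__ is read from offline_baselines_jax/version.txt at import time.
print(offline_baselines_jax.__version__)  # "1.0.0"
```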
-------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_extract_dict_obs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 4 | 5 | 6 | class VecExtractDictObs(VecEnvWrapper): 7 | """ 8 | A vectorized wrapper for extracting dictionary observations. 9 | 10 | :param venv: The vectorized environment 11 | :param key: The key of the dictionary observation 12 | """ 13 | 14 | def __init__(self, venv: VecEnv, key: str): 15 | self.key = key 16 | super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) 17 | 18 | def reset(self) -> np.ndarray: 19 | obs = self.venv.reset() 20 | return obs[self.key] 21 | 22 | def step_wait(self) -> VecEnvStepReturn: 23 | obs, reward, done, info = self.venv.step_wait() 24 | return obs[self.key], reward, done, info 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.pkl 4 | *.py~ 5 | *.bak 6 | .pytest_cache 7 | .DS_Store 8 | .idea 9 | .vscode 10 | .coverage 11 | .coverage.* 12 | __pycache__/ 13 | _build/ 14 | *.npz 15 | *.pth 16 | .pytype/ 17 | git_rewrite_commit_history.sh 18 | 19 | # Setuptools distribution and build folders. 20 | /dist/ 21 | /build 22 | keys/ 23 | 24 | # Virtualenv 25 | /env 26 | /venv 27 | 28 | 29 | *.sublime-project 30 | *.sublime-workspace 31 | 32 | .idea 33 | 34 | logs/ 35 | 36 | .ipynb_checkpoints 37 | ghostdriver.log 38 | 39 | htmlcov 40 | 41 | junk 42 | src 43 | 44 | *.egg-info 45 | .cache 46 | *.lprof 47 | *.prof 48 | 49 | MUJOCO_LOG.TXT 50 | 2dnavi.py 51 | 2dnavioff.py 52 | foo* 53 | online.py 54 | server.py 55 | client.py 56 | gail/ 57 | goal* 58 | rndtest.py 59 | tests/ 60 | _navigation_2d/ 61 | envs/ 62 | multiworld/ 63 | backup/ 64 | runs/ 65 | .hydra/* 66 | *.yaml 67 | 68 | d4rl_maze/ 69 | outputs/ 70 | offline_data/ 71 | tests/ 72 | prior_sopt_train.py 73 | online_test.py 74 | metla_train.py 75 | *train.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2022 Minjong Yoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Large portion of the code of offline_baselines_jax were ported from Stable-Baselines3, a fork of OpenAI Baselines, 2 | both licensed under the MIT License: 3 | 4 | before the fork (June 2018): 5 | Copyright (c) 2017 OpenAI (http://openai.com) 6 | 7 | after the fork (June 2018): 8 | Copyright (c) 2018-2019 Stable-Baselines Team 9 | 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in 19 | all copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | THE SOFTWARE. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open(os.path.join("offline_baselines_jax", "version.txt"), "r") as file_handler: 6 | __version__ = file_handler.read().strip() 7 | 8 | setup( 9 | name="offline_baselines_jax", 10 | packages=[package for package in find_packages() if package.startswith("offline_baselines_jax")], 11 | package_data={"offline_baselines_jax": ["py.typed", "version.txt"]}, 12 | install_requires=[ 13 | "stable_baselines3==1.4.0", 14 | # "jax==0.3.4", 15 | # "jaxlib==0.3.2", 16 | "flax==0.4.0", 17 | "tensorflow_probability", 18 | 'optax==0.1.1' 19 | ], 20 | description="Jax version of implementations of offline reinforcement learning algorithms.", 21 | author="Minjong Yoo", 22 | url="https://github.com/mjyoo2/offline_baselines_jax", 23 | author_email="mjyoo222@gmail.com", 24 | license="MIT", 25 | version=__version__, 26 | python_requires=">=3.7", 27 | # PyPI package information. 28 | classifiers=[ 29 | "Programming Language :: Python :: 3", 30 | "Programming Language :: Python :: 3.7", 31 | "Programming Language :: Python :: 3.8", 32 | "Programming Language :: Python :: 3.9", 33 | ], 34 | ) 35 | 36 | # python setup.py sdist 37 | # python setup.py bdist_wheel 38 | # twine upload --repository-url https://test.pypi.org/legacy/ dist/* 39 | # twine upload dist/* 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Offline Baselines (JAX) 2 | 3 | Offline Baselines with JAX is a set of implementations of reinforcement learning algorithms in JAX. 
4 | 
5 | This library is based on Stable Baselines 3 (https://github.com/DLR-RM/stable-baselines3) and JAXRL (https://github.com/ikostrikov/jaxrl).
6 | 
7 | ### Windows 10
8 | 
9 | Offline Baselines does not support the Windows OS.
10 | 
11 | ### Install
12 | 
13 | ```
14 | git clone https://github.com/mjyoo2/offline_baselines_jax.git
15 | python setup.py install
16 | ```
17 | 
18 | ### Install using pip
19 | Install the offline_baselines_jax package directly from GitHub:
20 | ```
21 | pip install git+https://github.com/mjyoo2/offline_baselines_jax
22 | ```
23 | 
24 | ## Performance
25 | We benchmarked the training speed of the SAC and TD3 algorithms on an RTX 3090 GPU and an Intel i9-10940 CPU. The training environment is HalfCheetah-v2.
26 | 
27 | | **Algorithm** | **Stable Baselines (PyTorch)** | **Offline Baselines (JAX)** |
28 | |---------------|--------------------------------|-----------------------------|
29 | | SAC | 125 steps / second | 570 steps / second |
30 | | TD3 | 240 steps / second | 800 steps / second |
31 | 
32 | ## Example
33 | ```python
34 | from offline_baselines_jax import SAC
35 | from offline_baselines_jax.sac.policies import SACPolicy
36 | 
37 | import gym
38 | 
39 | train_env = gym.make('HalfCheetah-v2')
40 | 
41 | model = SAC(SACPolicy, train_env, seed=777, verbose=1, batch_size=1024, buffer_size=50000, train_freq=1)
42 | 
43 | model.learn(total_timesteps=10000)
44 | model.save('./model.zip')
45 | model = SAC.load('./model.zip', train_env)
46 | 
47 | model.learn(total_timesteps=10000)
48 | 
49 | ```
50 | 
--------------------------------------------------------------------------------
/offline_baselines_jax/common/type_aliases.py:
--------------------------------------------------------------------------------
1 | """Common aliases for type hints"""
2 | 
3 | from enum import Enum
4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
5 | 
6 | import gym
7 | import numpy as np
8 | import jax.numpy as jnp
9 | import flax
10 | 
11 | from stable_baselines3.common import callbacks
12 | from stable_baselines3.common import vec_env
13 | 
14 | 
15 | GymEnv = Union[gym.Env, vec_env.VecEnv]
16 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
17 | GymStepReturn = Tuple[GymObs, float, bool, Dict]
18 | TensorDict = Dict[Union[str, int], jnp.ndarray]
19 | OptimizerStateDict = Dict[str, Any]
20 | MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
21 | 
22 | # A schedule takes the remaining progress as input
23 | # and outputs a scalar (e.g. learning rate, clip range, ...)
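# e.g. (illustrative) a linear learning-rate schedule: lambda progress_remaining: progress_remaining * 3e-4, where progress_remaining anneals from 1 to 0 over training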
24 | Schedule = Callable[[float], float] 25 | Params = flax.core.FrozenDict[str, Any] 26 | InfoDict = Dict[str, float] 27 | 28 | 29 | class ReplayBufferSamples(NamedTuple): 30 | observations: jnp.ndarray 31 | actions: jnp.ndarray 32 | next_observations: jnp.ndarray 33 | dones: jnp.ndarray 34 | rewards: jnp.ndarray 35 | 36 | 37 | class DictReplayBufferSamples(ReplayBufferSamples): 38 | observations: TensorDict 39 | actions: jnp.ndarray 40 | next_observations: jnp.ndarray 41 | dones: jnp.ndarray 42 | rewards: jnp.ndarray 43 | 44 | 45 | class RolloutReturn(NamedTuple): 46 | episode_timesteps: int 47 | n_episodes: int 48 | continue_training: bool 49 | 50 | 51 | class TrainFrequencyUnit(Enum): 52 | STEP = "step" 53 | EPISODE = "episode" 54 | 55 | 56 | class TrainFreq(NamedTuple): 57 | frequency: int 58 | unit: TrainFrequencyUnit # either "step" or "episode" 59 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 8 | 9 | 10 | class VecFrameStack(VecEnvWrapper): 11 | """ 12 | Frame stacking wrapper for vectorized environment. Designed for image observations. 13 | 14 | Uses the StackedObservations class, or StackedDictObservations depending on the observations space 15 | 16 | :param venv: the vectorized environment to wrap 17 | :param n_stack: Number of frames to stack 18 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 19 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
20 | Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces 21 | """ 22 | 23 | def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): 24 | self.venv = venv 25 | self.n_stack = n_stack 26 | 27 | wrapped_obs_space = venv.observation_space 28 | 29 | if isinstance(wrapped_obs_space, spaces.Box): 30 | assert not isinstance( 31 | channels_order, dict 32 | ), f"Expected None or string for channels_order but received {channels_order}" 33 | self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 34 | 35 | elif isinstance(wrapped_obs_space, spaces.Dict): 36 | self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) 37 | 38 | else: 39 | raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") 40 | 41 | observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) 42 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 43 | 44 | def step_wait( 45 | self, 46 | ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: 47 | 48 | observations, rewards, dones, infos = self.venv.step_wait() 49 | 50 | observations, infos = self.stackedobs.update(observations, dones, infos) 51 | 52 | return observations, rewards, dones, infos 53 | 54 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 55 | """ 56 | Reset all environments 57 | """ 58 | observation = self.venv.reset() # pytype:disable=annotation-type-mismatch 59 | 60 | observation = self.stackedobs.reset(observation) 61 | return observation 62 | 63 | def close(self) -> None: 64 | self.venv.close() 65 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | import typing 3 | from copy import deepcopy 4 | from typing import Optional, Type, Union 5 | 6 | from offline_baselines_jax.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper 7 | from offline_baselines_jax.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from offline_baselines_jax.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations 9 | from offline_baselines_jax.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from offline_baselines_jax.common.vec_env.vec_check_nan import VecCheckNan 11 | from offline_baselines_jax.common.vec_env.vec_extract_dict_obs import VecExtractDictObs 12 | from offline_baselines_jax.common.vec_env.vec_frame_stack import VecFrameStack 13 | from offline_baselines_jax.common.vec_env.vec_monitor import VecMonitor 14 | from offline_baselines_jax.common.vec_env.vec_normalize import VecNormalize 15 | from offline_baselines_jax.common.vec_env.vec_transpose import VecTransposeImage 16 | from offline_baselines_jax.common.vec_env.vec_video_recorder import VecVideoRecorder 17 | 18 | # Avoid circular import 19 | if typing.TYPE_CHECKING: 20 | from offline_baselines_jax.common.type_aliases import GymEnv 21 | 22 | 23 | def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: 24 | """ 25 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 
26 | 27 | :param env: 28 | :param vec_wrapper_class: 29 | :return: 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, vec_wrapper_class): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: 40 | """ 41 | :param env: 42 | :return: 43 | """ 44 | return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type 45 | 46 | 47 | def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: 48 | """ 49 | Check if an environment is already wrapped by a given ``VecEnvWrapper``. 50 | 51 | :param env: 52 | :param vec_wrapper_class: 53 | :return: 54 | """ 55 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None 56 | 57 | 58 | # Define here to avoid circular import 59 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: 60 | """ 61 | Sync eval env and train env when using VecNormalize 62 | 63 | :param env: 64 | :param eval_env: 65 | """ 66 | env_tmp, eval_env_tmp = env, eval_env 67 | while isinstance(env_tmp, VecEnvWrapper): 68 | if isinstance(env_tmp, VecNormalize): 69 | # Only synchronize if observation normalization exists 70 | if hasattr(env_tmp, "obs_rms"): 71 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) 72 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) 73 | env_tmp = env_tmp.venv 74 | eval_env_tmp = eval_env_tmp.venv 75 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | from collections import OrderedDict 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import gym 8 | import numpy as np 9 | 10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs 12 | 13 | 14 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 15 | """ 16 | Deep-copy a dict of numpy arrays. 17 | 18 | :param obs: a dict of numpy arrays. 19 | :return: a dict of copied numpy arrays. 20 | """ 21 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 22 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 23 | 24 | 25 | def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: 26 | """ 27 | Convert an internal representation raw_obs into the appropriate type 28 | specified by space. 29 | 30 | :param obs_space: an observation space. 31 | :param obs_dict: a dict of numpy arrays. 32 | :return: returns an observation of the same type as space. 33 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; 34 | otherwise, space is unstructured and returns the value raw_obs[None]. 
35 | """ 36 | if isinstance(obs_space, gym.spaces.Dict): 37 | return obs_dict 38 | elif isinstance(obs_space, gym.spaces.Tuple): 39 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" 40 | return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) 41 | else: 42 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 43 | return obs_dict[None] 44 | 45 | 46 | def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: 47 | """ 48 | Get dict-structured information about a gym.Space. 49 | 50 | Dict spaces are represented directly by their dict of subspaces. 51 | Tuple spaces are converted into a dict with keys indexing into the tuple. 52 | Unstructured spaces are represented by {None: obs_space}. 53 | 54 | :param obs_space: an observation space 55 | :return: A tuple (keys, shapes, dtypes): 56 | keys: a list of dict keys. 57 | shapes: a dict mapping keys to shapes. 58 | dtypes: a dict mapping keys to dtypes. 59 | """ 60 | check_for_nested_spaces(obs_space) 61 | if isinstance(obs_space, gym.spaces.Dict): 62 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 63 | subspaces = obs_space.spaces 64 | elif isinstance(obs_space, gym.spaces.Tuple): 65 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 66 | else: 67 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 68 | subspaces = {None: obs_space} 69 | keys = [] 70 | shapes = {} 71 | dtypes = {} 72 | for key, box in subspaces.items(): 73 | keys.append(key) 74 | shapes[key] = box.shape 75 | dtypes[key] = box.dtype 76 | return keys, shapes, dtypes 77 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: the vectorized environment to wrap 14 | :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: Whether or not to only warn once. 
16 | :param check_inf: Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions: np.ndarray) -> None: 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self) -> VecEnvStepReturn: 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self) -> VecEnvObs: 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step: bool, **kwargs) -> None: 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += f"found {type_val} in {name}" 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += f"environment, Last given value was: \r\n\taction={self._actions}" 80 | else: 81 | msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/policies.py: -------------------------------------------------------------------------------- 1 | """Policies: abstract base class and concrete implementations.""" 2 | 3 | import os 4 | from typing import Any, Optional, Tuple, Union, Callable, Sequence 5 | 6 | import flax 7 | import flax.linen as nn 8 | import jax 9 | import jax.numpy as jnp 10 | import optax 11 | 12 | from offline_baselines_jax.common.type_aliases import Params 13 | 14 | 15 | @flax.struct.dataclass 16 | class Model: 17 | step: int 18 | apply_fn: Callable[..., Any] = flax.struct.field(pytree_node=False) 19 | params: Params 20 | batch_stats: Union[Params] 21 | tx: Optional[optax.GradientTransformation] = flax.struct.field(pytree_node=False) 22 | opt_state: Optional[optax.OptState] = None 23 | 24 | @classmethod 25 | def create( 26 | cls, 27 | model_def: nn.Module, 28 | inputs: Sequence[jnp.ndarray], 29 | tx: Optional[optax.GradientTransformation] = None, 30 | **kwargs 31 | ) -> 'Model': 32 | 33 | variables = model_def.init(*inputs) 34 | 35 | _, params = variables.pop('params') 36 | 37 | """ 38 | NOTE: 39 | Here we unfreeze the parameter. 
40 | This is because some optimizer classes in optax must receive a dict, not a frozendict, which is annoying. 41 | https://github.com/deepmind/optax/issues/160 42 | """ 43 | params = params.unfreeze() 44 | 45 | # Frozendict's 'pop' method does not support default value. So we use get method instead. 46 | batch_stats = variables.get("batch_stats", None) 47 | 48 | if tx is not None: opt_state = tx.init(params) 49 | else: opt_state = None 50 | 51 | return cls( 52 | step=1, 53 | apply_fn=model_def.apply, 54 | params=params, 55 | batch_stats=batch_stats, 56 | tx=tx, 57 | opt_state=opt_state, 58 | **kwargs 59 | ) 60 | 61 | def __call__(self, *args, **kwargs): 62 | return self.apply_fn({"params": self.params}, *args, **kwargs) 63 | 64 | def apply_gradient( 65 | self, 66 | loss_fn: Optional[Callable[[Params], Any]] = None, 67 | grads: Optional[Any] = None, 68 | has_aux: bool = True 69 | ) -> Union[Tuple['Model', Any], 'Model']: 70 | 71 | assert (loss_fn is not None or grads is not None, 'Either a loss function or grads must be specified.') 72 | 73 | if grads is None: 74 | grad_fn = jax.grad(loss_fn, has_aux=has_aux) 75 | if has_aux: grads, aux = grad_fn(self.params) 76 | else: grads = grad_fn(self.params) 77 | else: 78 | assert (has_aux, 'When grads are provided, expects no aux outputs.') 79 | 80 | updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) 81 | new_params = optax.apply_updates(self.params, updates) 82 | new_model = self.replace(step=self.step + 1, params=new_params, opt_state=new_opt_state) 83 | 84 | if has_aux: 85 | return new_model, aux 86 | else: 87 | return new_model 88 | 89 | def save_dict(self, save_path: str) -> Params: 90 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 91 | with open(save_path, 'wb') as f: 92 | f.write(flax.serialization.to_bytes(self.params)) 93 | return self.params 94 | 95 | def load_dict(self, load_path: str) -> "Model": 96 | with open(load_path, 'rb') as f: 97 | params = flax.serialization.from_bytes(self.params, f.read()) 98 | return self.replace(params=params) 99 | 100 | def save_batch_stats(self, save_path: str) -> Params: 101 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 102 | with open(save_path, 'wb') as f: 103 | f.write(flax.serialization.to_bytes(self.batch_stats)) 104 | return self.batch_stats 105 | 106 | def load_batch_stats(self, load_path: str) -> "Model": 107 | with open(load_path, 'rb') as f: 108 | batch_stats = flax.serialization.from_bytes(self.batch_stats, f.read()) 109 | return self.replace(batch_stats=batch_stats) -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from typing import Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 8 | 9 | 10 | class VecMonitor(VecEnvWrapper): 11 | """ 12 | A vectorized monitor wrapper for *vectorized* Gym environments, 13 | it is used to record the episode reward, length, time and other data. 14 | 15 | Some environments like `openai/procgen `_ 16 | or `gym3 `_ directly initialize the 17 | vectorized environments, without giving us a chance to use the ``Monitor`` 18 | wrapper. So this class simply does the job of the ``Monitor`` wrapper on 19 | a vectorized level. 
20 | 21 | :param venv: The vectorized environment 22 | :param filename: the location to save a log file, can be None for no log 23 | :param info_keywords: extra information to log, from the information return of env.step() 24 | """ 25 | 26 | def __init__( 27 | self, 28 | venv: VecEnv, 29 | filename: Optional[str] = None, 30 | info_keywords: Tuple[str, ...] = (), 31 | ): 32 | # Avoid circular import 33 | from stable_baselines3.common.monitor import Monitor, ResultsWriter 34 | 35 | # This check is not valid for special `VecEnv` 36 | # like the ones created by Procgen, that does follow completely 37 | # the `VecEnv` interface 38 | try: 39 | is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] 40 | except AttributeError: 41 | is_wrapped_with_monitor = False 42 | 43 | if is_wrapped_with_monitor: 44 | warnings.warn( 45 | "The environment is already wrapped with a `Monitor` wrapper" 46 | "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be" 47 | "overwritten by the `VecMonitor` ones.", 48 | UserWarning, 49 | ) 50 | 51 | VecEnvWrapper.__init__(self, venv) 52 | self.episode_returns = None 53 | self.episode_lengths = None 54 | self.episode_count = 0 55 | self.t_start = time.time() 56 | 57 | env_id = None 58 | if hasattr(venv, "spec") and venv.spec is not None: 59 | env_id = venv.spec.id 60 | 61 | if filename: 62 | self.results_writer = ResultsWriter( 63 | filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords 64 | ) 65 | else: 66 | self.results_writer = None 67 | self.info_keywords = info_keywords 68 | 69 | def reset(self) -> VecEnvObs: 70 | obs = self.venv.reset() 71 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 72 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 73 | return obs 74 | 75 | def step_wait(self) -> VecEnvStepReturn: 76 | obs, rewards, dones, infos = self.venv.step_wait() 77 | self.episode_returns += rewards 78 | self.episode_lengths += 1 79 | new_infos = list(infos[:]) 80 | for i in range(len(dones)): 81 | if dones[i]: 82 | info = infos[i].copy() 83 | episode_return = self.episode_returns[i] 84 | episode_length = self.episode_lengths[i] 85 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} 86 | for key in self.info_keywords: 87 | episode_info[key] = info[key] 88 | info["episode"] = episode_info 89 | self.episode_count += 1 90 | self.episode_returns[i] = 0 91 | self.episode_lengths[i] = 0 92 | if self.results_writer: 93 | self.results_writer.write_row(episode_info) 94 | new_infos[i] = info 95 | return obs, rewards, dones, new_infos 96 | 97 | def close(self) -> None: 98 | if self.results_writer: 99 | self.results_writer.close() 100 | return self.venv.close() 101 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | from gym.wrappers.monitoring import video_recorder 5 | 6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper 7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 9 | 10 | 11 | class VecVideoRecorder(VecEnvWrapper): 12 | """ 13 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 
14 | It requires ffmpeg or avconv to be installed on the machine. 15 | 16 | :param venv: 17 | :param video_folder: Where to save videos 18 | :param record_video_trigger: Function that defines when to start recording. 19 | The function takes the current number of step, 20 | and returns whether we should start recording or not. 21 | :param video_length: Length of recorded videos 22 | :param name_prefix: Prefix to the video name 23 | """ 24 | 25 | def __init__( 26 | self, 27 | venv: VecEnv, 28 | video_folder: str, 29 | record_video_trigger: Callable[[int], bool], 30 | video_length: int = 200, 31 | name_prefix: str = "rl-video", 32 | ): 33 | 34 | VecEnvWrapper.__init__(self, venv) 35 | 36 | self.env = venv 37 | # Temp variable to retrieve metadata 38 | temp_env = venv 39 | 40 | # Unwrap to retrieve metadata dict 41 | # that will be used by gym recorder 42 | while isinstance(temp_env, VecEnvWrapper): 43 | temp_env = temp_env.venv 44 | 45 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 46 | metadata = temp_env.get_attr("metadata")[0] 47 | else: 48 | metadata = temp_env.metadata 49 | 50 | self.env.metadata = metadata 51 | 52 | self.record_video_trigger = record_video_trigger 53 | self.video_recorder = None 54 | 55 | self.video_folder = os.path.abspath(video_folder) 56 | # Create output folder if needed 57 | os.makedirs(self.video_folder, exist_ok=True) 58 | 59 | self.name_prefix = name_prefix 60 | self.step_id = 0 61 | self.video_length = video_length 62 | 63 | self.recording = False 64 | self.recorded_frames = 0 65 | 66 | def reset(self) -> VecEnvObs: 67 | obs = self.venv.reset() 68 | self.start_video_recorder() 69 | return obs 70 | 71 | def start_video_recorder(self) -> None: 72 | self.close_video_recorder() 73 | 74 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" 75 | base_path = os.path.join(self.video_folder, video_name) 76 | self.video_recorder = video_recorder.VideoRecorder( 77 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id} 78 | ) 79 | 80 | self.video_recorder.capture_frame() 81 | self.recorded_frames = 1 82 | self.recording = True 83 | 84 | def _video_enabled(self) -> bool: 85 | return self.record_video_trigger(self.step_id) 86 | 87 | def step_wait(self) -> VecEnvStepReturn: 88 | obs, rews, dones, infos = self.venv.step_wait() 89 | 90 | self.step_id += 1 91 | if self.recording: 92 | self.video_recorder.capture_frame() 93 | self.recorded_frames += 1 94 | if self.recorded_frames > self.video_length: 95 | print(f"Saving video to {self.video_recorder.path}") 96 | self.close_video_recorder() 97 | elif self._video_enabled(): 98 | self.start_video_recorder() 99 | 100 | return obs, rews, dones, infos 101 | 102 | def close_video_recorder(self) -> None: 103 | if self.recording: 104 | self.video_recorder.close() 105 | self.recording = False 106 | self.recorded_frames = 1 107 | 108 | def close(self) -> None: 109 | VecEnvWrapper.close(self) 110 | self.close_video_recorder() 111 | 112 | def __del__(self): 113 | self.close() 114 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_transpose.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first 8 | from 
stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 9 | 10 | 11 | class VecTransposeImage(VecEnvWrapper): 12 | """ 13 | Re-order channels, from HxWxC to CxHxW. 14 | It is required for PyTorch convolution layers. 15 | 16 | :param venv: 17 | :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, 18 | which may result in unwanted behavior, see GH issue #671. 19 | """ 20 | 21 | def __init__(self, venv: VecEnv, skip: bool = False): 22 | assert is_image_space(venv.observation_space) or isinstance( 23 | venv.observation_space, spaces.dict.Dict 24 | ), "The observation space must be an image or dictionary observation space" 25 | 26 | self.skip = skip 27 | # Do nothing 28 | if skip: 29 | super().__init__(venv) 30 | return 31 | 32 | if isinstance(venv.observation_space, spaces.dict.Dict): 33 | self.image_space_keys = [] 34 | observation_space = deepcopy(venv.observation_space) 35 | for key, space in observation_space.spaces.items(): 36 | if is_image_space(space): 37 | # Keep track of which keys should be transposed later 38 | self.image_space_keys.append(key) 39 | observation_space.spaces[key] = self.transpose_space(space, key) 40 | else: 41 | observation_space = self.transpose_space(venv.observation_space) 42 | super().__init__(venv, observation_space=observation_space) 43 | 44 | @staticmethod 45 | def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: 46 | """ 47 | Transpose an observation space (re-order channels). 48 | 49 | :param observation_space: 50 | :param key: In case of dictionary space, the key of the observation space. 51 | :return: 52 | """ 53 | # Sanity checks 54 | assert is_image_space(observation_space), "The observation space must be an image" 55 | assert is_image_space_channels_first( 56 | observation_space 57 | ), f"The observation space {key} must follow the channel last convention" 58 | channels, height, width = observation_space.shape 59 | new_shape = (height, width, channels) 60 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) 61 | 62 | @staticmethod 63 | def transpose_image(image: np.ndarray) -> np.ndarray: 64 | """ 65 | Transpose an image or batch of images (re-order channels). 66 | 67 | :param image: 68 | :return: 69 | """ 70 | if len(image.shape) == 3: 71 | return np.transpose(image, (1, 2, 0)) 72 | return np.transpose(image, (0, 2, 3, 1)) 73 | 74 | def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: 75 | """ 76 | Transpose (if needed) and return new observations. 
77 | 78 | :param observations: 79 | :return: Transposed observations 80 | """ 81 | # Do nothing 82 | if self.skip: 83 | return observations 84 | 85 | if isinstance(observations, dict): 86 | # Avoid modifying the original object in place 87 | observations = deepcopy(observations) 88 | for k in self.image_space_keys: 89 | observations[k] = self.transpose_image(observations[k]) 90 | else: 91 | observations = self.transpose_image(observations) 92 | return observations 93 | 94 | def step_wait(self) -> VecEnvStepReturn: 95 | observations, rewards, dones, infos = self.venv.step_wait() 96 | 97 | # Transpose the terminal observations 98 | for idx, done in enumerate(dones): 99 | if not done: 100 | continue 101 | if "terminal_observation" in infos[idx]: 102 | infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) 103 | 104 | return self.transpose_observations(observations), rewards, dones, infos 105 | 106 | def reset(self) -> Union[np.ndarray, Dict]: 107 | """ 108 | Reset all environments 109 | """ 110 | return self.transpose_observations(self.venv.reset()) 111 | 112 | def close(self) -> None: 113 | self.venv.close() 114 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Any, Callable, List, Optional, Sequence, Type, Union 4 | 5 | import gym 6 | import numpy as np 7 | 8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn 9 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 10 | 11 | 12 | class DummyVecEnv(VecEnv): 13 | """ 14 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 15 | Python process. This is useful for computationally simple environment such as ``cartpole-v1``, 16 | as the overhead of multiprocess or multithread outweighs the environment computation time. 17 | This can also be used for RL methods that 18 | require a vectorized environment, but that you want a single environments to train with. 
19 | 20 | :param env_fns: a list of functions 21 | that return environments to vectorize 22 | """ 23 | 24 | def __init__(self, env_fns: List[Callable[[], gym.Env]]): 25 | self.envs = [fn() for fn in env_fns] 26 | env = self.envs[0] 27 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 28 | obs_space = env.observation_space 29 | self.keys, shapes, dtypes = obs_space_info(obs_space) 30 | 31 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) 32 | self.buf_dones = np.zeros((self.num_envs,), dtype=bool) 33 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 34 | self.buf_infos = [{} for _ in range(self.num_envs)] 35 | self.actions = None 36 | self.metadata = env.metadata 37 | 38 | def step_async(self, actions: np.ndarray) -> None: 39 | self.actions = actions 40 | 41 | def step_wait(self) -> VecEnvStepReturn: 42 | for env_idx in range(self.num_envs): 43 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 44 | self.actions[env_idx] 45 | ) 46 | if self.buf_dones[env_idx]: 47 | # save final observation where user can get it, then reset 48 | self.buf_infos[env_idx]["terminal_observation"] = obs 49 | obs = self.envs[env_idx].reset() 50 | self._save_obs(env_idx, obs) 51 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 52 | 53 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 54 | if seed is None: 55 | seed = np.random.randint(0, 2**32 - 1) 56 | seeds = [] 57 | for idx, env in enumerate(self.envs): 58 | seeds.append(env.seed(seed + idx)) 59 | return seeds 60 | 61 | def reset(self) -> VecEnvObs: 62 | for env_idx in range(self.num_envs): 63 | obs = self.envs[env_idx].reset() 64 | self._save_obs(env_idx, obs) 65 | return self._obs_from_buf() 66 | 67 | def close(self) -> None: 68 | for env in self.envs: 69 | env.close() 70 | 71 | def get_images(self) -> Sequence[np.ndarray]: 72 | return [env.render(mode="rgb_array") for env in self.envs] 73 | 74 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 75 | """ 76 | Gym environment rendering. If there are multiple environments then 77 | they are tiled together in one image via ``BaseVecEnv.render()``. 78 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the 79 | underlying environment. 80 | 81 | Therefore, some arguments such as ``mode`` will have values that are valid 82 | only when ``num_envs == 1``. 83 | 84 | :param mode: The rendering type. 
85 | """ 86 | if self.num_envs == 1: 87 | return self.envs[0].render(mode=mode) 88 | else: 89 | return super().render(mode=mode) 90 | 91 | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: 92 | for key in self.keys: 93 | if key is None: 94 | self.buf_obs[key][env_idx] = obs 95 | else: 96 | self.buf_obs[key][env_idx] = obs[key] 97 | 98 | def _obs_from_buf(self) -> VecEnvObs: 99 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 100 | 101 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 102 | """Return attribute from vectorized environment (see base class).""" 103 | target_envs = self._get_target_envs(indices) 104 | return [getattr(env_i, attr_name) for env_i in target_envs] 105 | 106 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 107 | """Set attribute inside vectorized environments (see base class).""" 108 | target_envs = self._get_target_envs(indices) 109 | for env_i in target_envs: 110 | setattr(env_i, attr_name, value) 111 | 112 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 113 | """Call instance methods of vectorized environments.""" 114 | target_envs = self._get_target_envs(indices) 115 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 116 | 117 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 118 | """Check if worker environments are wrapped with a given wrapper""" 119 | target_envs = self._get_target_envs(indices) 120 | # Import here to avoid a circular import 121 | from stable_baselines3.common import env_util 122 | 123 | return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] 124 | 125 | def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: 126 | indices = self._get_indices(indices) 127 | return [self.envs[i] for i in indices] 128 | -------------------------------------------------------------------------------- /offline_baselines_jax/td3/core.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from offline_baselines_jax.common.jax_layers import polyak_update 8 | from offline_baselines_jax.common.policies import Model 9 | from offline_baselines_jax.common.type_aliases import InfoDict, Params 10 | 11 | STATIC_ARGNAMES = ( 12 | "gamma", 13 | "tau", 14 | "target_policy_noise", 15 | "target_noise_clip", 16 | "alpha", 17 | "without_exploration", 18 | "actor_update_cond" 19 | ) 20 | 21 | 22 | def td3_critic_update( 23 | rng:Any, 24 | critic: Model, 25 | critic_target: Model, 26 | actor_target: Model, 27 | 28 | observations: jnp.ndarray, 29 | actions: jnp.ndarray, 30 | next_observations: jnp.ndarray, 31 | rewards: jnp.ndarray, 32 | dones: jnp.ndarray, 33 | 34 | gamma:float, 35 | target_policy_noise: float, 36 | target_noise_clip: float 37 | ): 38 | 39 | dropout_key, _ = jax.random.split(rng) 40 | 41 | # Select action according to policy and add clipped noise 42 | noise = jax.random.normal(rng, shape=actions.shape) * jnp.sqrt(target_policy_noise) 43 | noise = jnp.clip(noise, -target_noise_clip, target_noise_clip) 44 | next_actions = actor_target(next_observations, rngs={"dropout": dropout_key}) 45 | next_actions = jnp.clip(next_actions + noise, -1.0, 1.0) 46 | 47 | # Compute the next Q-values: min over all critics targets 48 | 
next_q_values = critic_target(next_observations, next_actions, deterministic=False, rngs={"dropout": dropout_key}) 49 | next_q_values = jnp.min(next_q_values, axis=1) 50 | target_q_values = rewards + (1 - dones) * gamma * next_q_values 51 | 52 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 53 | # Get current Q-values estimates for each critic network 54 | q_values = critic.apply_fn( 55 | {"params": critic_params}, 56 | observations, actions, 57 | deterministic=False, 58 | rngs={"dropout": dropout_key} 59 | ) 60 | 61 | # Compute critic loss 62 | n_qs = q_values.shape[1] 63 | 64 | critic_loss = sum(jnp.mean((target_q_values - q_values[:, i, ...]) ** 2) for i in range(n_qs)) 65 | # critic_loss = critic_loss / n_qs 66 | return critic_loss, {'critic_loss': critic_loss, 'current_q': q_values.mean()} 67 | 68 | new_critic, info = critic.apply_gradient(critic_loss_fn) 69 | return new_critic, info 70 | 71 | 72 | def td3_actor_update( 73 | rng: jnp.ndarray, 74 | actor: Model, 75 | critic: Model, 76 | 77 | observations: jnp.ndarray, 78 | actions: jnp.ndarray, 79 | 80 | alpha: float, 81 | without_exploration: bool 82 | ): 83 | dropout_key, _ = jax.random.split(rng) 84 | 85 | if without_exploration: 86 | _actions_pi = actor(observations, deterministic=False, rngs={"dropout": dropout_key}) 87 | q1 = critic(observations, _actions_pi, deterministic=False, rngs={"dropout": dropout_key})[0] 88 | coef_lambda = alpha / (jnp.mean(jnp.abs(q1))) 89 | else: 90 | coef_lambda = 1 91 | 92 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 93 | # Compute actor loss 94 | actions_pi = actor.apply_fn( 95 | {"params": actor_params}, 96 | observations, 97 | deterministic=False, 98 | rngs={"dropout": dropout_key} 99 | ) 100 | q_value = critic( 101 | observations, 102 | actions_pi, 103 | deterministic=False, 104 | rngs={"dropout": dropout_key} 105 | )[0].mean() 106 | 107 | actor_loss = - q_value 108 | 109 | if without_exploration: 110 | bc_loss = jnp.mean(jnp.square(actions_pi - actions)) 111 | actor_loss = coef_lambda * actor_loss + bc_loss 112 | 113 | return actor_loss, {'actor_loss': actor_loss, 'q_value': q_value, 'coef_lambda': coef_lambda} 114 | 115 | new_actor, info = actor.apply_gradient(actor_loss_fn) 116 | return new_actor, info 117 | 118 | 119 | @functools.partial(jax.jit, static_argnames=tuple(STATIC_ARGNAMES)) 120 | def update_td3( 121 | rng: int, 122 | actor: Model, 123 | critic: Model, 124 | actor_target: Model, 125 | critic_target: Model, 126 | 127 | observations: jnp.ndarray, 128 | actions: jnp.ndarray, 129 | next_observations: jnp.ndarray, 130 | rewards: jnp.ndarray, 131 | dones: jnp.ndarray, 132 | 133 | actor_update_cond: bool, 134 | tau: float, 135 | target_policy_noise: float, 136 | target_noise_clip: float, 137 | gamma: float, 138 | alpha: float, 139 | without_exploration: bool 140 | ): 141 | 142 | rng, key = jax.random.split(rng, 2) 143 | new_critic, critic_info = td3_critic_update( 144 | rng=rng, 145 | critic=critic, 146 | critic_target=critic_target, 147 | actor_target=actor_target, 148 | observations=observations, 149 | actions=actions, 150 | next_observations=next_observations, 151 | rewards=rewards, 152 | dones=dones, 153 | gamma=gamma, 154 | target_policy_noise=target_policy_noise, 155 | target_noise_clip=target_noise_clip 156 | ) 157 | 158 | if actor_update_cond: 159 | new_actor, actor_info = td3_actor_update( 160 | rng=rng, 161 | actor=actor, 162 | critic=new_critic, 163 | observations=observations, 164 | actions=actions, 165 | alpha=alpha, 
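                # When without_exploration=True (offline training), td3_actor_update above adds a behavior-cloning
                # MSE term and rescales the Q term by alpha / mean(|Q|), i.e. a TD3+BC-style objective.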
166 | without_exploration=without_exploration 167 | ) 168 | new_actor_target = polyak_update(new_actor, actor_target, tau) 169 | new_critic_target = polyak_update(new_critic, critic_target, tau) 170 | else: 171 | new_actor, actor_info = actor, {'actor_loss': 0, 'q_value': 0, 'coef_lambda': 1} 172 | new_actor_target = actor_target 173 | new_critic_target = critic_target 174 | 175 | new_models = { 176 | "critic": new_critic, 177 | "critic_target": new_critic_target, 178 | "actor": new_actor, 179 | "actor_target": new_actor_target 180 | } 181 | 182 | return rng, new_models, {**critic_info, **actor_info, "actor_update_cond": actor_update_cond} -------------------------------------------------------------------------------- /offline_baselines_jax/sac/core.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from offline_baselines_jax.common.policies import Model 8 | from offline_baselines_jax.common.type_aliases import ( 9 | InfoDict, 10 | Params 11 | ) 12 | from offline_baselines_jax.common.jax_layers import polyak_update 13 | 14 | 15 | def log_ent_coef_update( 16 | rng:Any, 17 | log_ent_coef: Model, 18 | actor: Model, 19 | observations: jnp.ndarray, 20 | target_entropy: float, 21 | ) -> Tuple[Model, InfoDict]: 22 | dropout_key, _ = jax.random.split(rng) 23 | dist = actor(observations, deterministic=False, rngs={"dropout": dropout_key}) 24 | actions_pi = dist.sample(seed=rng) 25 | log_prob = dist.log_prob(actions_pi) 26 | 27 | def temperature_loss_fn(ent_params: Params): 28 | ent_coef = log_ent_coef.apply_fn({'params': ent_params}) 29 | ent_coef_loss = -(ent_coef * (target_entropy + log_prob)).mean() 30 | 31 | return ent_coef_loss, {'ent_coef': ent_coef, 'ent_coef_loss': ent_coef_loss} 32 | 33 | new_ent_coef, info = log_ent_coef.apply_gradient(temperature_loss_fn) 34 | return new_ent_coef, info 35 | 36 | 37 | def sac_actor_update( 38 | rng: int, 39 | actor: Model, 40 | critic: Model, 41 | log_ent_coef: Model, 42 | 43 | observations: jnp.ndarray, 44 | ): 45 | ent_coef = jnp.exp(log_ent_coef()) 46 | dropout_key, _ = jax.random.split(rng) 47 | 48 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 49 | dist = actor.apply_fn( 50 | {'params': actor_params}, 51 | observations, 52 | deterministic=False, 53 | rngs={"dropout": dropout_key} 54 | ) 55 | actions_pi = dist.sample(seed=rng) 56 | log_prob = dist.log_prob(actions_pi) 57 | 58 | q_values_pi = critic(observations, actions_pi, deterministic=False, rngs={"dropout": dropout_key}) 59 | min_qf_pi = jnp.min(q_values_pi, axis=1) 60 | 61 | actor_loss = (ent_coef * log_prob - min_qf_pi).mean() 62 | return actor_loss, {'actor_loss': actor_loss, 'entropy': -log_prob} 63 | 64 | new_actor, info = actor.apply_gradient(actor_loss_fn) 65 | return new_actor, info 66 | 67 | 68 | def sac_critic_update( 69 | rng:Any, 70 | actor: Model, 71 | critic: Model, 72 | critic_target: Model, 73 | log_ent_coef: Model, 74 | 75 | observations: jnp.ndarray, 76 | actions: jnp.ndarray, 77 | next_observations: jnp.ndarray, 78 | rewards: jnp.ndarray, 79 | dones: jnp.ndarray, 80 | 81 | gamma:float 82 | ): 83 | dropout_key, _ = jax.random.split(rng) 84 | 85 | dist = actor(next_observations, deterministic=False, rngs={"dropout": dropout_key}) 86 | next_actions = dist.sample(seed=rng) 87 | next_log_prob = dist.log_prob(next_actions) 88 | 89 | # Compute the next Q values: min over all critics targets 90 | next_q_values = 
critic_target(next_observations, next_actions, deterministic=False, rngs={"dropout": dropout_key}) 91 | next_q_values = jnp.min(next_q_values, axis=1) 92 | 93 | ent_coef = jnp.exp(log_ent_coef()) 94 | 95 | # add entropy term 96 | next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) 97 | # td error + entropy term 98 | target_q_values = rewards + (1 - dones) * gamma * next_q_values 99 | 100 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 101 | # Get current Q-values estimates for each critic network 102 | # using action from the replay buffer 103 | q_values = critic.apply_fn( 104 | {'params': critic_params}, 105 | observations, 106 | actions, 107 | deterministic=False, 108 | rngs={"dropout": dropout_key} 109 | ) 110 | 111 | # Compute critic loss 112 | n_qs = q_values.shape[1] 113 | 114 | critic_loss = sum([jnp.mean((target_q_values - q_values[:, i, ...]) ** 2) for i in range(n_qs)]) 115 | critic_loss = critic_loss / n_qs 116 | 117 | return critic_loss, {'critic_loss': critic_loss, 'current_q': q_values.mean(), "n_qs": n_qs} 118 | 119 | new_critic, info = critic.apply_gradient(critic_loss_fn) 120 | return new_critic, info 121 | 122 | 123 | @functools.partial(jax.jit, static_argnames=('gamma', 'target_entropy', 'tau', 'target_update_cond', 'entropy_update')) 124 | def sac_update( 125 | rng: int, 126 | actor: Model, 127 | critic: Model, 128 | critic_target: Model, 129 | log_ent_coef: Model, 130 | 131 | observations: jnp.ndarray, 132 | actions: jnp.ndarray, 133 | rewards: jnp.ndarray, 134 | next_observations: jnp.ndarray, 135 | dones: jnp.ndarray, 136 | 137 | gamma: float, 138 | target_entropy: float, 139 | tau: float, 140 | target_update_cond: bool, 141 | entropy_update: bool 142 | ): 143 | 144 | rng, key = jax.random.split(rng, 2) 145 | new_critic, critic_info = sac_critic_update( 146 | rng=rng, 147 | actor=actor, 148 | critic=critic, 149 | critic_target=critic_target, 150 | log_ent_coef=log_ent_coef, 151 | observations=observations, 152 | actions=actions, 153 | next_observations=next_observations, 154 | rewards=rewards, 155 | dones=dones, 156 | gamma=gamma 157 | ) 158 | 159 | if target_update_cond: 160 | new_critic_target = polyak_update(new_critic, critic_target, tau) 161 | else: 162 | new_critic_target = critic_target 163 | 164 | rng, key = jax.random.split(rng, 2) 165 | new_actor, actor_info = sac_actor_update( 166 | rng=rng, 167 | actor=actor, 168 | critic=critic, 169 | log_ent_coef=log_ent_coef, 170 | observations=observations 171 | ) 172 | 173 | rng, key = jax.random.split(rng, 2) 174 | if entropy_update: 175 | new_temp, ent_info = log_ent_coef_update( 176 | rng=rng, 177 | log_ent_coef=log_ent_coef, 178 | actor=actor, 179 | observations=observations, 180 | target_entropy=target_entropy 181 | ) 182 | else: 183 | new_temp, ent_info = log_ent_coef, {'ent_coef': jnp.exp(log_ent_coef()), 'ent_coef_loss': 0} 184 | 185 | new_models = { 186 | "critic": new_critic, 187 | "critic_target": new_critic_target, 188 | "actor": new_actor, 189 | "log_ent_coef": new_temp 190 | } 191 | return rng, new_models, {**critic_info, **actor_info, **ent_info} 192 | -------------------------------------------------------------------------------- /offline_baselines_jax/td3/td3.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Type, Union 2 | 3 | import gym 4 | import jax 5 | import numpy as np 6 | from stable_baselines3.common.noise import ActionNoise 7 | 8 | from 
offline_baselines_jax.common.buffers import ReplayBuffer 9 | from offline_baselines_jax.common.off_policy_algorithm import OffPolicyAlgorithm 10 | from offline_baselines_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule, Params 11 | from offline_baselines_jax.td3.policies import TD3Policy 12 | from .core import update_td3 13 | 14 | 15 | class TD3(OffPolicyAlgorithm): 16 | 17 | def __init__( 18 | self, 19 | env: Union[GymEnv, str], 20 | policy: Union[str, Type[TD3Policy]] = TD3Policy, 21 | learning_rate: Union[float, Schedule] = 1e-3, 22 | buffer_size: int = 1_000_000, # 1e6 23 | learning_starts: int = 100, 24 | batch_size: int = 100, 25 | tau: float = 0.005, 26 | gamma: float = 0.99, 27 | train_freq: Union[int, Tuple[int, str]] = (1, 'episode'), 28 | gradient_steps: int = -1, 29 | action_noise: Optional[ActionNoise] = None, 30 | replay_buffer_class: Optional[ReplayBuffer] = None, 31 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 32 | optimize_memory_usage: bool = False, 33 | policy_delay: int = 2, 34 | target_policy_noise: float = 0.2, 35 | target_noise_clip: float = 0.5, 36 | tensorboard_log: Optional[str] = None, 37 | create_eval_env: bool = False, 38 | policy_kwargs: Optional[Dict[str, Any]] = None, 39 | verbose: int = 0, 40 | seed: int = 0, 41 | alpha: int = 2.5, 42 | _init_setup_model: bool = True, 43 | without_exploration: bool = False, 44 | ): 45 | 46 | super(TD3, self).__init__( 47 | policy, 48 | env, 49 | learning_rate, 50 | buffer_size, 51 | learning_starts, 52 | batch_size, 53 | tau, 54 | gamma, 55 | train_freq, 56 | gradient_steps, 57 | action_noise=action_noise, 58 | replay_buffer_class=replay_buffer_class, 59 | replay_buffer_kwargs=replay_buffer_kwargs, 60 | policy_kwargs=policy_kwargs, 61 | tensorboard_log=tensorboard_log, 62 | verbose=verbose, 63 | create_eval_env=create_eval_env, 64 | seed=seed, 65 | optimize_memory_usage=optimize_memory_usage, 66 | supported_action_spaces=(gym.spaces.Box), 67 | support_multi_env=True, 68 | without_exploration=without_exploration, 69 | ) 70 | if without_exploration and gradient_steps == -1: 71 | self.gradient_steps = policy_delay 72 | 73 | self.alpha = alpha 74 | self.policy_delay = policy_delay 75 | self.target_noise_clip = target_noise_clip 76 | self.target_policy_noise = target_policy_noise 77 | 78 | if _init_setup_model: 79 | self._setup_model() 80 | 81 | def _setup_model(self) -> None: 82 | super(TD3, self)._setup_model() 83 | self._create_aliases() 84 | 85 | def _create_aliases(self) -> None: 86 | self.actor = self.policy.actor 87 | self.actor_target = self.policy.actor_target 88 | self.critic = self.policy.critic 89 | self.critic_target = self.policy.critic_target 90 | 91 | def train(self, gradient_steps: int, batch_size: int = 100) -> None: 92 | actor_losses, critic_losses, coef_lambda = [], [], [] 93 | for _ in range(gradient_steps): 94 | self._n_updates += 1 95 | # Sample replay buffer 96 | replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) 97 | self.rng, key = jax.random.split(self.rng, 2) 98 | 99 | actor_update_cond = (self._n_updates % self.policy_delay == 0) 100 | self.rng, new_models, info = \ 101 | update_td3( 102 | key, 103 | actor=self.actor, 104 | actor_target=self.actor_target, 105 | critic=self.critic, 106 | critic_target=self.critic_target, 107 | 108 | observations=replay_data.observations, 109 | actions=replay_data.actions, 110 | next_observations=replay_data.next_observations, 111 | rewards=replay_data.rewards, 112 | dones=replay_data.dones, 113 | 114 | 
actor_update_cond=actor_update_cond, 115 | tau=self.tau, 116 | target_policy_noise=self.target_policy_noise, 117 | target_noise_clip=self.target_noise_clip, 118 | gamma=self.gamma, 119 | alpha=self.alpha, 120 | without_exploration=self.without_exploration 121 | ) 122 | 123 | self.apply_new_models(new_models) 124 | self.actor_target = new_models["actor_target"] 125 | self.policy.actor_target = new_models["actor_target"] 126 | 127 | actor_losses.append(info['actor_loss']) 128 | critic_losses.append(info['critic_loss']) 129 | coef_lambda.append(info['coef_lambda']) 130 | 131 | 132 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 133 | if len(actor_losses) > 0: 134 | self.logger.record("train/actor_loss", np.mean(actor_losses)) 135 | self.logger.record("train/critic_loss", np.mean(critic_losses)) 136 | self.logger.record("train/coef", np.mean(coef_lambda)) 137 | 138 | def learn( 139 | self, 140 | total_timesteps: int, 141 | callback: MaybeCallback = None, 142 | log_interval: int = 4, 143 | eval_env: Optional[GymEnv] = None, 144 | eval_freq: int = -1, 145 | n_eval_episodes: int = 5, 146 | tb_log_name: str = "TD3", 147 | eval_log_path: Optional[str] = None, 148 | reset_num_timesteps: bool = True, 149 | ) -> OffPolicyAlgorithm: 150 | 151 | return super(TD3, self).learn( 152 | total_timesteps=total_timesteps, 153 | callback=callback, 154 | log_interval=log_interval, 155 | eval_env=eval_env, 156 | eval_freq=eval_freq, 157 | n_eval_episodes=n_eval_episodes, 158 | tb_log_name=tb_log_name, 159 | eval_log_path=eval_log_path, 160 | reset_num_timesteps=reset_num_timesteps, 161 | ) 162 | 163 | def _excluded_save_params(self) -> List[str]: 164 | return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] 165 | 166 | def _get_jax_save_params(self) -> Dict[str, Params]: 167 | params_dict = {} 168 | params_dict['actor'] = self.actor.params 169 | params_dict['critic'] = self.critic.params 170 | params_dict['critic_target'] = self.critic_target.params 171 | params_dict['actor_target'] = self.actor_target.params 172 | return params_dict 173 | 174 | def _get_jax_load_params(self) -> List[str]: 175 | return ['actor', 'critic', 'critic_target', 'actor_target'] 176 | 177 | def _load_policy(self) -> None: 178 | super(TD3, self)._load_policy() 179 | self.policy.actor_target = self.actor_target 180 | -------------------------------------------------------------------------------- /offline_baselines_jax/cql/core.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from offline_baselines_jax.common.policies import Model 8 | from offline_baselines_jax.common.type_aliases import InfoDict, ReplayBufferSamples, Params 9 | 10 | 11 | def log_alpha_update(log_alpha_coef: Model, conservative_loss: float) -> Tuple[Model, InfoDict]: 12 | def alpha_loss_fn(alpha_params: Params): 13 | alpha_coef = jnp.exp(log_alpha_coef.apply_fn({'params': alpha_params})) 14 | alpha_coef_loss = -alpha_coef * conservative_loss 15 | 16 | return alpha_coef_loss, {'alpha_coef': alpha_coef, 'alpha_coef_loss': alpha_coef_loss} 17 | 18 | new_alpha_coef, info = log_alpha_coef.apply_gradient(alpha_loss_fn) 19 | new_alpha_coef = param_clip(new_alpha_coef, 1e+6) 20 | return new_alpha_coef, info 21 | 22 | def log_ent_coef_update(key:Any, log_ent_coef: Model, actor:Model , target_entropy: float, replay_data:ReplayBufferSamples) -> Tuple[Model, 
InfoDict]: 23 | def temperature_loss_fn(ent_params: Params): 24 | dist = actor(replay_data.observations) 25 | actions_pi = dist.sample(seed=key) 26 | log_prob = dist.log_prob(actions_pi) 27 | 28 | ent_coef = log_ent_coef.apply_fn({'params': ent_params}) 29 | ent_coef_loss = -(ent_coef * (target_entropy + log_prob)).mean() 30 | 31 | return ent_coef_loss, {'ent_coef': ent_coef, 'ent_coef_loss': ent_coef_loss} 32 | 33 | new_ent_coef, info = log_ent_coef.apply_gradient(temperature_loss_fn) 34 | return new_ent_coef, info 35 | 36 | 37 | def sac_actor_update(key: int, actor: Model, critic:Model, log_ent_coef: Model, replay_data:ReplayBufferSamples): 38 | def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 39 | dist = actor.apply_fn({'params': actor_params}, replay_data.observations) 40 | actions_pi = dist.sample(seed=key) 41 | log_prob = dist.log_prob(actions_pi) 42 | 43 | ent_coef = jnp.exp(log_ent_coef()) 44 | 45 | q_values_pi = critic(replay_data.observations, actions_pi) 46 | min_qf_pi = jnp.min(q_values_pi, axis=0) 47 | 48 | actor_loss = (ent_coef * log_prob - min_qf_pi).mean() 49 | return actor_loss, {'actor_loss': actor_loss, 'entropy': -log_prob} 50 | 51 | new_actor, info = actor.apply_gradient(actor_loss_fn) 52 | return new_actor, info 53 | 54 | 55 | def sac_critic_update(key:Any, actor: Model, critic: Model, critic_target: Model, log_ent_coef: Model, 56 | log_alpha_coef: Model, replay_data: ReplayBufferSamples, gamma:float, conservative_weight:float, 57 | lagrange_thresh:float,): 58 | next_dist = actor(replay_data.next_observations) 59 | next_actions = next_dist.sample(seed=key) 60 | next_log_prob = next_dist.log_prob(next_actions) 61 | 62 | # Compute the next Q values: min over all critics targets 63 | next_q_values = critic_target(replay_data.next_observations, next_actions) 64 | next_q_values = jnp.min(next_q_values, axis=0) 65 | 66 | ent_coef = jnp.exp(log_ent_coef()) 67 | # add entropy term 68 | next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) 69 | # td error + entropy term 70 | target_q_values = replay_data.rewards + (1 - replay_data.dones) * gamma * next_q_values 71 | 72 | batch_size, action_dim = replay_data.actions.shape 73 | alpha_coef = jnp.exp(log_alpha_coef()) 74 | ############################### 75 | ## For CQL Conservative Loss ## 76 | ############################### 77 | 78 | cql_dist = actor(replay_data.observations) 79 | cql_actions = cql_dist.sample(seed=key) 80 | cql_log_prob = cql_dist.log_prob(cql_actions) 81 | 82 | repeated_observations = jnp.repeat(replay_data.observations, repeats=10, axis=0) 83 | key, subkey = jax.random.split(key, 2) 84 | random_actions = jax.random.uniform(subkey, minval=-1, maxval=1, shape=(batch_size * 10, action_dim)) 85 | 86 | random_density = jnp.log(0.5 ** action_dim) 87 | alpha_coef = jnp.clip(jnp.exp(log_alpha_coef()), 0, 1e+6) 88 | 89 | def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]: 90 | # Get current Q-values estimates for each critic network 91 | # using action from the replay buffer 92 | current_q = critic.apply_fn({'params': critic_params}, replay_data.observations, replay_data.actions) 93 | cql_q = critic.apply_fn({'params': critic_params}, replay_data.observations, cql_actions) 94 | random_q = critic.apply_fn({'params': critic_params}, repeated_observations, random_actions) 95 | conservative_loss = 0 96 | for idx in range(len(cql_q)): 97 | conservative_loss += jax.scipy.special.logsumexp(jnp.ndarray([jnp.repeat(cql_q[idx], repeats=10, axis=0) - 
cql_log_prob, random_q[idx] - random_density])).mean() - current_q[idx].mean() 98 | conservative_loss = (conservative_weight * ((conservative_loss) / len(cql_q)) - lagrange_thresh) 99 | # Compute critic loss 100 | critic_loss = 0 101 | for q in current_q: 102 | critic_loss = critic_loss + jnp.mean(jnp.square(q - target_q_values)) 103 | critic_loss = critic_loss / len(current_q) + alpha_coef * conservative_loss 104 | 105 | return critic_loss, {'critic_loss': critic_loss, 'current_q': current_q.mean(), 'conservative_loss': conservative_loss} 106 | 107 | new_critic, info = critic.apply_gradient(critic_loss_fn) 108 | return new_critic, info 109 | 110 | def param_clip(log_alpha_coef: Model, a_max: float) -> Model: 111 | new_log_alpha_params = jax.tree_multimap(lambda p: jnp.clip(p, a_max=jnp.log(a_max)), log_alpha_coef.params) 112 | return log_alpha_coef.replace(params=new_log_alpha_params) 113 | 114 | def target_update(critic: Model, critic_target: Model, tau: float) -> Model: 115 | new_target_params = jax.tree_multimap(lambda p, tp: p * tau + tp * (1 - tau), critic.params, critic_target.params) 116 | return critic_target.replace(params=new_target_params) 117 | 118 | 119 | @functools.partial(jax.jit, static_argnames=('gamma', 'target_entropy', 'tau', 'target_update_cond', 'entropy_update', 120 | 'alpha_update', 'conservative_weight', 'lagrange_thresh')) 121 | def update_cql( 122 | rng: int, actor: Model, critic: Model, critic_target: Model, log_ent_coef: Model, log_alpha_coef: Model, replay_data: ReplayBufferSamples, 123 | gamma: float, target_entropy: float, tau: float, target_update_cond: bool, entropy_update: bool, alpha_update: bool, 124 | conservative_weight:float, lagrange_thresh:float, 125 | ) -> Tuple[int, Model, Model, Model, Model, Model, InfoDict]: 126 | rng, key = jax.random.split(rng) 127 | new_critic, critic_info = sac_critic_update(key, actor, critic, critic_target, log_ent_coef, log_alpha_coef, replay_data, 128 | gamma, conservative_weight, lagrange_thresh) 129 | if target_update_cond: 130 | new_critic_target = target_update(new_critic, critic_target, tau) 131 | else: 132 | new_critic_target = critic_target 133 | 134 | rng, key = jax.random.split(rng) 135 | new_actor, actor_info = sac_actor_update(key, actor, new_critic, log_ent_coef, replay_data) 136 | rng, key = jax.random.split(rng) 137 | if entropy_update: 138 | new_temp, ent_info = log_ent_coef_update(key, log_ent_coef, new_actor, target_entropy, replay_data) 139 | else: 140 | new_temp, ent_info = log_ent_coef, {'ent_coef': jnp.exp(log_ent_coef()), 'ent_coef_loss': 0} 141 | 142 | if alpha_update: 143 | new_alpha, alpha_info = log_alpha_update(log_alpha_coef, critic_info['conservative_loss']) 144 | else: 145 | new_alpha, alpha_info = log_alpha_coef, {'alpha_coef': jnp.exp(log_alpha_coef()), 'alpha_coef_loss': 0} 146 | 147 | return rng, new_actor, new_critic, new_critic_target, new_temp, new_alpha, {**critic_info, **actor_info, **ent_info, **alpha_info} 148 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Tuple, Union 3 | 4 | import numpy as np 5 | import jax.numpy as jnp 6 | from gym import spaces 7 | 8 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool: 9 | """ 10 | Check if an image observation space (see ``is_image_space``) 11 | is channels-first (CxHxW, True) or channels-last (HxWxC, 
False). 12 | 13 | Use a heuristic that channel dimension is the smallest of the three. 14 | If second dimension is smallest, raise an exception (no support). 15 | 16 | :param observation_space: 17 | :return: True if observation space is channels-first image, False if channels-last. 18 | """ 19 | smallest_dimension = np.argmin(observation_space.shape).item() 20 | if smallest_dimension == 1: 21 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") 22 | return smallest_dimension == 0 23 | 24 | 25 | def is_image_space( 26 | observation_space: spaces.Space, 27 | check_channels: bool = False, 28 | ) -> bool: 29 | """ 30 | Check if a observation space has the shape, limits and dtype 31 | of a valid image. 32 | The check is conservative, so that it returns False if there is a doubt. 33 | 34 | Valid images: RGB, RGBD, GrayScale with values in [0, 255] 35 | 36 | :param observation_space: 37 | :param check_channels: Whether to do or not the check for the number of channels. 38 | e.g., with frame-stacking, the observation space may have more channels than expected. 39 | :return: 40 | """ 41 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: 42 | # Check the type 43 | if observation_space.dtype != np.uint8: 44 | return False 45 | # Check the value range 46 | if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): 47 | return False 48 | # Skip channels check 49 | if not check_channels: 50 | return True 51 | # Check the number of channels 52 | if is_image_space_channels_first(observation_space): 53 | n_channels = observation_space.shape[0] 54 | else: 55 | n_channels = observation_space.shape[-1] 56 | # RGB, RGBD, GrayScale 57 | return n_channels in [1, 3, 4] 58 | return False 59 | 60 | 61 | def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray: 62 | """ 63 | Handle the different cases for images as PyTorch use channel first format. 64 | 65 | :param observation: 66 | :param observation_space: 67 | :return: channel first observation if observation is an image 68 | """ 69 | # Avoid circular import 70 | from stable_baselines3.common.vec_env import VecTransposeImage 71 | 72 | if is_image_space(observation_space): 73 | if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape): 74 | # Try to re-order the channels 75 | transpose_obs = VecTransposeImage.transpose_image(observation) 76 | if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape: 77 | observation = transpose_obs 78 | return observation 79 | 80 | 81 | def get_obs_shape( 82 | observation_space: spaces.Space, 83 | ) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: 84 | """ 85 | Get the shape of the observation (useful for the buffers). 
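For example, a ``spaces.Box`` with shape ``(84, 84, 3)`` yields ``(84, 84, 3)``, a ``spaces.Discrete`` space yields ``(1,)``, and a ``spaces.Dict`` space yields a dict mapping each key to its sub-space shape.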
86 | 87 | :param observation_space: 88 | :return: 89 | """ 90 | if isinstance(observation_space, spaces.Box): 91 | return observation_space.shape 92 | elif isinstance(observation_space, spaces.Discrete): 93 | # Observation is an int 94 | return (1,) 95 | elif isinstance(observation_space, spaces.MultiDiscrete): 96 | # Number of discrete features 97 | return (int(len(observation_space.nvec)),) 98 | elif isinstance(observation_space, spaces.MultiBinary): 99 | # Number of binary features 100 | return (int(observation_space.n),) 101 | elif isinstance(observation_space, spaces.Dict): 102 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} 103 | 104 | else: 105 | raise NotImplementedError(f"{observation_space} observation space is not supported") 106 | 107 | 108 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int: 109 | """ 110 | Get the dimension of the observation space when flattened. 111 | It does not apply to image observation space. 112 | 113 | Used by the ``FlattenExtractor`` to compute the input shape. 114 | 115 | :param observation_space: 116 | :return: 117 | """ 118 | # See issue https://github.com/openai/gym/issues/1915 119 | # it may be a problem for Dict/Tuple spaces too... 120 | if isinstance(observation_space, spaces.MultiDiscrete): 121 | return sum(observation_space.nvec) 122 | else: 123 | # Use Gym internal method 124 | return spaces.utils.flatdim(observation_space) 125 | 126 | 127 | def get_action_dim(action_space: spaces.Space) -> int: 128 | """ 129 | Get the dimension of the action space. 130 | 131 | :param action_space: 132 | :return: 133 | """ 134 | if isinstance(action_space, spaces.Box): 135 | return int(np.prod(action_space.shape)) 136 | elif isinstance(action_space, spaces.Discrete): 137 | # Action is an int 138 | return 1 139 | elif isinstance(action_space, spaces.MultiDiscrete): 140 | # Number of discrete actions 141 | return int(len(action_space.nvec)) 142 | elif isinstance(action_space, spaces.MultiBinary): 143 | # Number of binary actions 144 | return int(action_space.n) 145 | else: 146 | raise NotImplementedError(f"{action_space} action space is not supported") 147 | 148 | 149 | def check_for_nested_spaces(obs_space: spaces.Space): 150 | """ 151 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). 152 | If so, raise an Exception informing that there is no support for this. 153 | 154 | :param obs_space: an observation space 155 | :return: 156 | """ 157 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): 158 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces 159 | for sub_space in sub_spaces: 160 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): 161 | raise NotImplementedError( 162 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." 
163 | ) 164 | 165 | 166 | def preprocess_obs( 167 | obs: jnp.ndarray, 168 | observation_space: spaces.Space, 169 | normalize_images: bool = True 170 | ): 171 | if isinstance(observation_space, spaces.Box): 172 | if is_image_space(observation_space) and normalize_images: 173 | return obs / 255.0 174 | return obs 175 | 176 | elif isinstance(observation_space, spaces.Discrete): 177 | n_classes = observation_space.n 178 | return np.eye(n_classes)[obs] 179 | 180 | elif isinstance(observation_space, spaces.MultiDiscrete): 181 | return np.concatenate( 182 | [np.eye(observation_space.nvec[idx])[obs_] for idx, obs_ in enumerate(np.split(obs, 1, axis=1))] 183 | ).view(obs.shape[0], sum(observation_space.nvec)) 184 | 185 | elif isinstance(observation_space, spaces.MultiBinary): 186 | return obs 187 | 188 | elif isinstance(observation_space, spaces.Dict): 189 | # Do not modify by reference the original observation 190 | preprocessed_obs = {} 191 | for key, _obs in obs.items(): 192 | preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) 193 | return preprocessed_obs 194 | 195 | else: 196 | raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") -------------------------------------------------------------------------------- /offline_baselines_jax/sac/sac.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Type, Union 2 | 3 | import flax.linen as nn 4 | import gym 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | from stable_baselines3.common.noise import ActionNoise 10 | 11 | from offline_baselines_jax.common.buffers import ReplayBuffer 12 | from offline_baselines_jax.common.off_policy_algorithm import OffPolicyAlgorithm 13 | from offline_baselines_jax.common.policies import Model 14 | from offline_baselines_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule, Params 15 | from offline_baselines_jax.sac.policies import SACPolicy 16 | from .core import sac_update 17 | 18 | 19 | class LogEntropyCoef(nn.Module): 20 | init_value: float = 1.0 21 | @nn.compact 22 | def __call__(self) -> jnp.ndarray: 23 | log_temp = self.param('log_temp', init_fn=lambda key: jnp.full((), jnp.log(self.init_value))) 24 | return log_temp 25 | 26 | 27 | class SAC(OffPolicyAlgorithm): 28 | def __init__( 29 | self, 30 | env: Union[GymEnv, str], 31 | policy: Union[str, Type[SACPolicy]] = SACPolicy, 32 | learning_rate: Union[float, Schedule] = 3e-4, 33 | buffer_size: int = 1_000_000, # 1e6 34 | learning_starts: int = 100, 35 | batch_size: int = 256, 36 | tau: float = 0.005, 37 | gamma: float = 0.99, 38 | train_freq: Union[int, Tuple[int, str]] = 1, 39 | gradient_steps: int = 1, 40 | action_noise: Optional[ActionNoise] = None, 41 | replay_buffer_class: Optional[ReplayBuffer] = None, 42 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 43 | optimize_memory_usage: bool = False, 44 | ent_coef: Union[str, float] = "auto", 45 | target_update_interval: int = 1, 46 | target_entropy: Union[str, float] = "auto", 47 | tensorboard_log: Optional[str] = None, 48 | create_eval_env: bool = False, 49 | policy_kwargs: Optional[Dict[str, Any]] = None, 50 | verbose: int = 0, 51 | seed: int = 0, 52 | _init_setup_model: bool = True, 53 | without_exploration: bool = False, 54 | ): 55 | 56 | super(SAC, self).__init__( 57 | policy, 58 | env, 59 | learning_rate, 60 | buffer_size, 61 | learning_starts, 62 | batch_size, 63 | tau, 64 | gamma, 
65 | train_freq, 66 | gradient_steps, 67 | action_noise, 68 | replay_buffer_class=replay_buffer_class, 69 | replay_buffer_kwargs=replay_buffer_kwargs, 70 | policy_kwargs=policy_kwargs, 71 | tensorboard_log=tensorboard_log, 72 | verbose=verbose, 73 | create_eval_env=create_eval_env, 74 | seed=seed, 75 | optimize_memory_usage=optimize_memory_usage, 76 | supported_action_spaces=(gym.spaces.Box), 77 | support_multi_env=True, 78 | without_exploration=without_exploration, 79 | ) 80 | 81 | self.target_entropy = target_entropy 82 | self.log_ent_coef = None 83 | # Entropy coefficient / Entropy temperature 84 | # Inverse of the reward scale 85 | self.ent_coef = ent_coef 86 | self.target_update_interval = target_update_interval 87 | self.entropy_update = True 88 | 89 | if _init_setup_model: 90 | self._setup_model() 91 | 92 | def _setup_model(self) -> None: 93 | super(SAC, self)._setup_model() 94 | self._create_aliases() 95 | # Target entropy is used when learning the entropy coefficient 96 | if self.target_entropy == "auto": 97 | # automatically set target entropy if needed 98 | self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) 99 | else: 100 | # Force conversion 101 | # this will also throw an error for unexpected string 102 | self.target_entropy = float(self.target_entropy) 103 | 104 | # The entropy coefficient or entropy can be learned automatically 105 | # see Automating Entropy Adjustment for Maximum Entropy RL section 106 | # of https://arxiv.org/abs/1812.05905 107 | if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"): 108 | # Default initial value of ent_coef when learned 109 | init_value = 1.0 110 | if "_" in self.ent_coef: 111 | init_value = float(self.ent_coef.split("_")[1]) 112 | assert init_value > 0.0, "The initial value of ent_coef must be greater than 0" 113 | 114 | # Note: we optimize the log of the entropy coeff which is slightly different from the paper 115 | # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 116 | log_ent_coef_def = LogEntropyCoef(init_value) 117 | self.rng, temp_key = jax.random.split(self.rng, 2) 118 | self.log_ent_coef = Model.create( 119 | log_ent_coef_def, 120 | inputs=[temp_key], 121 | tx=optax.adam(learning_rate=self.lr_schedule(1)) 122 | ) 123 | 124 | else: 125 | # Force conversion to float 126 | # this will throw an error if a malformed string (different from 'auto') 127 | # is passed 128 | log_ent_coef_def = LogEntropyCoef(self.ent_coef) 129 | self.rng, temp_key = jax.random.split(self.rng, 2) 130 | self.log_ent_coef = Model.create(log_ent_coef_def, inputs=[temp_key]) 131 | self.entropy_update = False 132 | 133 | def _create_aliases(self) -> None: 134 | self.actor = self.policy.actor 135 | self.critic = self.policy.critic 136 | self.critic_target = self.policy.critic_target 137 | 138 | def train(self, gradient_steps: int, batch_size: int = 64) -> None: 139 | ent_coef_losses, ent_coefs = [], [] 140 | actor_losses, critic_losses = [], [] 141 | 142 | for gradient_step in range(gradient_steps): 143 | # Sample replay buffer 144 | replay_data = self.replay_buffer.sample(batch_size=batch_size) 145 | 146 | self.rng, key = jax.random.split(self.rng, 2) 147 | target_update_cond = (gradient_step % self.target_update_interval == 0) 148 | 149 | self.rng, new_models, info \ 150 | = sac_update( 151 | rng=key, 152 | actor=self.actor, 153 | critic=self.critic, 154 | critic_target=self.critic_target, 155 | log_ent_coef=self.log_ent_coef, 156 | 157 | observations=replay_data.observations, 158 | 
actions=replay_data.actions, 159 | rewards=replay_data.rewards, 160 | next_observations=replay_data.next_observations, 161 | dones=replay_data.dones, 162 | 163 | gamma=self.gamma, 164 | target_entropy=self.target_entropy, 165 | tau=self.tau, 166 | target_update_cond=target_update_cond, 167 | entropy_update=self.entropy_update 168 | ) 169 | 170 | ent_coef_losses.append(info['ent_coef_loss']) 171 | ent_coefs.append(info['ent_coef']) 172 | critic_losses.append(info['critic_loss']) 173 | actor_losses.append(info['actor_loss']) 174 | 175 | self.apply_new_models(new_models) 176 | 177 | self._n_updates += gradient_steps 178 | if self.num_timesteps % 500 == 0: 179 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 180 | self.logger.record("train/ent_coef", np.mean(ent_coefs)) 181 | self.logger.record("train/actor_loss", np.mean(actor_losses)) 182 | self.logger.record("train/critic_loss", np.mean(critic_losses)) 183 | if len(ent_coef_losses) > 0: 184 | self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) 185 | 186 | def offline_train(self, gradient_steps: int, batch_size: int) -> None: 187 | raise NotImplementedError() 188 | 189 | def learn( 190 | self, 191 | total_timesteps: int, 192 | callback: MaybeCallback = None, 193 | log_interval: int = 4, 194 | eval_env: Optional[GymEnv] = None, 195 | eval_freq: int = -1, 196 | n_eval_episodes: int = 5, 197 | tb_log_name: str = "SAC", 198 | eval_log_path: Optional[str] = None, 199 | reset_num_timesteps: bool = True, 200 | ) -> OffPolicyAlgorithm: 201 | 202 | return super(SAC, self).learn( 203 | total_timesteps=total_timesteps, 204 | callback=callback, 205 | log_interval=log_interval, 206 | eval_env=eval_env, 207 | eval_freq=eval_freq, 208 | n_eval_episodes=n_eval_episodes, 209 | tb_log_name=tb_log_name, 210 | eval_log_path=eval_log_path, 211 | reset_num_timesteps=reset_num_timesteps, 212 | ) 213 | 214 | def _excluded_save_params(self) -> List[str]: 215 | return super(SAC, self)._excluded_save_params() + SAC_COMPONENTS 216 | 217 | def _get_jax_save_params(self) -> Dict[str, Params]: 218 | params_dict = {} 219 | for comp_str in SAC_COMPONENTS: 220 | comp = getattr(self, comp_str) 221 | params_dict[comp_str] = comp.params 222 | return params_dict 223 | 224 | def _get_jax_load_params(self) -> List[str]: 225 | return SAC_COMPONENTS 226 | 227 | 228 | SAC_COMPONENTS = ["actor", "critic", "critic_target", "log_ent_coef"] -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from collections import OrderedDict 3 | from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union 4 | 5 | import gym 6 | import numpy as np 7 | 8 | from stable_baselines3.common.vec_env.base_vec_env import ( 9 | CloudpickleWrapper, 10 | VecEnv, 11 | VecEnvIndices, 12 | VecEnvObs, 13 | VecEnvStepReturn, 14 | ) 15 | 16 | 17 | def _worker( 18 | remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper 19 | ) -> None: 20 | # Import here to avoid a circular import 21 | from stable_baselines3.common.env_util import is_wrapped 22 | 23 | parent_remote.close() 24 | env = env_fn_wrapper.var() 25 | while True: 26 | try: 27 | cmd, data = remote.recv() 28 | if cmd == "step": 29 | observation, reward, done, info = env.step(data) 30 | if done: 31 | # save final observation where user can get it, 
then reset 32 | info["terminal_observation"] = observation 33 | observation = env.reset() 34 | remote.send((observation, reward, done, info)) 35 | elif cmd == "seed": 36 | remote.send(env.seed(data)) 37 | elif cmd == "reset": 38 | observation = env.reset() 39 | remote.send(observation) 40 | elif cmd == "render": 41 | remote.send(env.render(data)) 42 | elif cmd == "close": 43 | env.close() 44 | remote.close() 45 | break 46 | elif cmd == "get_spaces": 47 | remote.send((env.observation_space, env.action_space)) 48 | elif cmd == "env_method": 49 | method = getattr(env, data[0]) 50 | remote.send(method(*data[1], **data[2])) 51 | elif cmd == "get_attr": 52 | remote.send(getattr(env, data)) 53 | elif cmd == "set_attr": 54 | remote.send(setattr(env, data[0], data[1])) 55 | elif cmd == "is_wrapped": 56 | remote.send(is_wrapped(env, data)) 57 | else: 58 | raise NotImplementedError(f"`{cmd}` is not implemented in the worker") 59 | except EOFError: 60 | break 61 | 62 | 63 | class SubprocVecEnv(VecEnv): 64 | """ 65 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own 66 | process, allowing significant speed up when the environment is computationally complex. 67 | 68 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the 69 | number of logical cores on your CPU. 70 | 71 | .. warning:: 72 | 73 | Only 'forkserver' and 'spawn' start methods are thread-safe, 74 | which is important when TensorFlow sessions or other non thread-safe 75 | libraries are used in the parent (see issue #217). However, compared to 76 | 'fork' they incur a small start-up cost and have restrictions on 77 | global variables. With those methods, users must wrap the code in an 78 | ``if __name__ == "__main__":`` block. 79 | For more information, see the multiprocessing documentation. 80 | 81 | :param env_fns: Environments to run in subprocesses 82 | :param start_method: method used to start the subprocesses. 83 | Must be one of the methods returned by multiprocessing.get_all_start_methods(). 84 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 
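    Example (an illustrative sketch rather than a tested snippet; the environment id is only a placeholder)::

        import gym

        from offline_baselines_jax.common.vec_env.subproc_vec_env import SubprocVecEnv

        def make_env():
            return gym.make("Pendulum-v1")

        if __name__ == "__main__":
            # Four copies of the environment, each running in its own subprocess
            vec_env = SubprocVecEnv([make_env for _ in range(4)], start_method="spawn")
            observations = vec_env.reset()
            vec_env.close()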
85 | """ 86 | 87 | def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None): 88 | self.waiting = False 89 | self.closed = False 90 | n_envs = len(env_fns) 91 | 92 | if start_method is None: 93 | # Fork is not a thread safe method (see issue #217) 94 | # but is more user friendly (does not require to wrap the code in 95 | # a `if __name__ == "__main__":`) 96 | forkserver_available = "forkserver" in mp.get_all_start_methods() 97 | start_method = "forkserver" if forkserver_available else "spawn" 98 | ctx = mp.get_context(start_method) 99 | 100 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) 101 | self.processes = [] 102 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): 103 | args = (work_remote, remote, CloudpickleWrapper(env_fn)) 104 | # daemon=True: if the main process crashes, we should not cause things to hang 105 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error 106 | process.start() 107 | self.processes.append(process) 108 | work_remote.close() 109 | 110 | self.remotes[0].send(("get_spaces", None)) 111 | observation_space, action_space = self.remotes[0].recv() 112 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 113 | 114 | def step_async(self, actions: np.ndarray) -> None: 115 | for remote, action in zip(self.remotes, actions): 116 | remote.send(("step", action)) 117 | self.waiting = True 118 | 119 | def step_wait(self) -> VecEnvStepReturn: 120 | results = [remote.recv() for remote in self.remotes] 121 | self.waiting = False 122 | obs, rews, dones, infos = zip(*results) 123 | return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos 124 | 125 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 126 | if seed is None: 127 | seed = np.random.randint(0, 2**32 - 1) 128 | for idx, remote in enumerate(self.remotes): 129 | remote.send(("seed", seed + idx)) 130 | return [remote.recv() for remote in self.remotes] 131 | 132 | def reset(self) -> VecEnvObs: 133 | for remote in self.remotes: 134 | remote.send(("reset", None)) 135 | obs = [remote.recv() for remote in self.remotes] 136 | return _flatten_obs(obs, self.observation_space) 137 | 138 | def close(self) -> None: 139 | if self.closed: 140 | return 141 | if self.waiting: 142 | for remote in self.remotes: 143 | remote.recv() 144 | for remote in self.remotes: 145 | remote.send(("close", None)) 146 | for process in self.processes: 147 | process.join() 148 | self.closed = True 149 | 150 | def get_images(self) -> Sequence[np.ndarray]: 151 | for pipe in self.remotes: 152 | # gather images from subprocesses 153 | # `mode` will be taken into account later 154 | pipe.send(("render", "rgb_array")) 155 | imgs = [pipe.recv() for pipe in self.remotes] 156 | return imgs 157 | 158 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 159 | """Return attribute from vectorized environment (see base class).""" 160 | target_remotes = self._get_target_remotes(indices) 161 | for remote in target_remotes: 162 | remote.send(("get_attr", attr_name)) 163 | return [remote.recv() for remote in target_remotes] 164 | 165 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 166 | """Set attribute inside vectorized environments (see base class).""" 167 | target_remotes = self._get_target_remotes(indices) 168 | for remote in target_remotes: 169 | remote.send(("set_attr", (attr_name, value))) 170 
| for remote in target_remotes: 171 | remote.recv() 172 | 173 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 174 | """Call instance methods of vectorized environments.""" 175 | target_remotes = self._get_target_remotes(indices) 176 | for remote in target_remotes: 177 | remote.send(("env_method", (method_name, method_args, method_kwargs))) 178 | return [remote.recv() for remote in target_remotes] 179 | 180 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 181 | """Check if worker environments are wrapped with a given wrapper""" 182 | target_remotes = self._get_target_remotes(indices) 183 | for remote in target_remotes: 184 | remote.send(("is_wrapped", wrapper_class)) 185 | return [remote.recv() for remote in target_remotes] 186 | 187 | def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: 188 | """ 189 | Get the connection object needed to communicate with the wanted 190 | envs that are in subprocesses. 191 | 192 | :param indices: refers to indices of envs. 193 | :return: Connection object to communicate between processes. 194 | """ 195 | indices = self._get_indices(indices) 196 | return [self.remotes[i] for i in indices] 197 | 198 | 199 | def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs: 200 | """ 201 | Flatten observations, depending on the observation space. 202 | 203 | :param obs: observations. 204 | A list or tuple of observations, one per environment. 205 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. 206 | :return: flattened observations. 207 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. 208 | Each NumPy array has the environment index as its first axis. 
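For example, observations from two environments that are each arrays of shape ``(3,)`` are stacked into a single array of shape ``(2, 3)``; for ``Dict`` and ``Tuple`` spaces, each sub-observation is stacked in the same way.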
209 | """ 210 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" 211 | assert len(obs) > 0, "need observations from at least one environment" 212 | 213 | if isinstance(space, gym.spaces.Dict): 214 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" 215 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" 216 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) 217 | elif isinstance(space, gym.spaces.Tuple): 218 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" 219 | obs_len = len(space.spaces) 220 | return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) 221 | else: 222 | return np.stack(obs) 223 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/jax_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List, Tuple, Type, Union, Sequence, Any, Callable 3 | 4 | import flax 5 | import flax.linen as nn 6 | import gym 7 | import jax 8 | import jax.numpy as jnp 9 | from flax.linen.initializers import zeros 10 | from jax import lax 11 | 12 | from offline_baselines_jax.common.policies import Model 13 | from offline_baselines_jax.common.preprocessing import is_image_space 14 | from offline_baselines_jax.common.type_aliases import TensorDict 15 | 16 | PRNGKey = Any 17 | Params = flax.core.FrozenDict[str, Any] 18 | Shape = Sequence[int] 19 | InfoDict = Dict[str, float] 20 | Dtype = Any # this could be a real type? 21 | Array = Any 22 | PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]] 23 | 24 | default_kernel_init = nn.initializers.xavier_normal() 25 | default_bias_init = zeros 26 | 27 | 28 | def polyak_update(source: Model, target: Model, tau: float) -> Model: 29 | new_target_params = jax.tree_multimap(lambda p, tp: p * tau + tp * (1 - tau), source.params, target.params) 30 | return target.replace(params=new_target_params) 31 | 32 | 33 | def calculate_gain(nonlinearity, param=None): 34 | r"""Return the recommended gain value for the given nonlinearity function. 35 | The values are as follows: 36 | 37 | ================= ==================================================== 38 | nonlinearity gain 39 | ================= ==================================================== 40 | Linear / Identity :math:`1` 41 | Conv{1,2,3}D :math:`1` 42 | Sigmoid :math:`1` 43 | Tanh :math:`\frac{5}{3}` 44 | ReLU :math:`\sqrt{2}` 45 | Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` 46 | SELU :math:`\frac{3}{4}` 47 | ================= ==================================================== 48 | 49 | .. warning:: 50 | In order to implement `Self-Normalizing Neural Networks`_ , 51 | you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. 52 | This gives the initial weights a variance of ``1 / N``, 53 | which is necessary to induce a stable fixed point in the forward pass. 54 | In contrast, the default gain for ``SELU`` sacrifices the normalisation 55 | effect for more stable gradient flow in rectangular layers. 
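(Concretely, ``calculate_gain('linear')`` returns ``1.0`` here, while ``calculate_gain('selu')`` returns ``0.75``.)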
56 | 57 | Args: 58 | nonlinearity: the non-linear function (`nn.functional` name) 59 | param: optional parameter for the non-linear function 60 | 61 | Examples: 62 | >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 63 | 64 | .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html 65 | """ 66 | linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] 67 | if nonlinearity in linear_fns or nonlinearity == 'sigmoid': 68 | return 1 69 | elif nonlinearity == 'tanh': 70 | return 5.0 / 3 71 | elif nonlinearity == 'relu': 72 | return math.sqrt(2.0) 73 | elif nonlinearity == 'leaky_relu': 74 | if param is None: 75 | negative_slope = 0.01 76 | elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): 77 | # True/False are instances of int, hence check above 78 | negative_slope = param 79 | else: 80 | raise ValueError("negative_slope {} not a valid number".format(param)) 81 | return math.sqrt(2.0 / (1 + negative_slope ** 2)) 82 | elif nonlinearity == 'selu': 83 | return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) 84 | else: 85 | raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) 86 | 87 | 88 | def create_mlp( 89 | output_dim: int, 90 | net_arch: List[int], 91 | activation_fn: Callable = nn.relu, 92 | dropout: float = 0.0, 93 | squash_output: bool = False, 94 | layernorm: bool = False, 95 | batchnorm: bool = False, 96 | use_bias: bool = True, 97 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init, 98 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros 99 | ) -> nn.Module: 100 | 101 | if output_dim > 0: 102 | net_arch = list(net_arch) 103 | net_arch.append(output_dim) 104 | return MLP(net_arch, activation_fn, dropout, squash_output, layernorm, batchnorm, use_bias, kernel_init, bias_init) 105 | 106 | 107 | def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]: 108 | """ 109 | Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG). 110 | 111 | The ``net_arch`` parameter allows to specify the amount and size of the hidden layers, 112 | which can be different for the actor and the critic. 113 | It is assumed to be a list of ints or a dict. 114 | 115 | 1. If it is a list, actor and critic networks will have the same architecture. 116 | The architecture is represented by a list of integers (of arbitrary length (zero allowed)) 117 | each specifying the number of units per layer. 118 | If the number of ints is zero, the network will be linear. 119 | 2. If it is a dict, it should have the following structure: 120 | ``dict(qf=[], pi=[])``. 121 | where the network architecture is a list as described in 1. 122 | 123 | For example, to have actor and critic that share the same network architecture, 124 | you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each). 125 | 126 | If you want a different architecture for the actor and the critic, 127 | then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``. 128 | 129 | .. note:: 130 | Compared to their on-policy counterparts, no shared layers (other than the features extractor) 131 | between the actor and the critic are allowed (to prevent issues with target networks). 132 | 133 | :param net_arch: The specification of the actor and critic networks. 
134 | See above for details on its formatting. 135 | :return: The network architectures for the actor and the critic 136 | """ 137 | try: 138 | net_arch = list(net_arch) 139 | except: 140 | pass 141 | 142 | if isinstance(net_arch, list): 143 | actor_arch, critic_arch = net_arch, net_arch 144 | else: 145 | assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict" 146 | assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network" 147 | assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network" 148 | actor_arch, critic_arch = net_arch["pi"], net_arch["qf"] 149 | return actor_arch, critic_arch 150 | 151 | 152 | class MLP(nn.Module): 153 | net_arch: List 154 | activation_fn: Callable 155 | dropout: float = 0.0 156 | squashed_out: bool = False 157 | 158 | layernorm: bool = False 159 | batchnorm: bool = False 160 | use_bias: bool = True 161 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init 162 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros 163 | 164 | @nn.compact 165 | def __call__(self, x, deterministic: bool = False, training: bool = True): 166 | 167 | for feature in self.net_arch[:-1]: 168 | x = nn.Dense(feature, kernel_init=self.kernel_init, use_bias=self.use_bias, bias_init=self.bias_init)(x) 169 | if self.batchnorm: x = nn.BatchNorm(use_running_average=not training, momentum=0.1)(x) 170 | if self.layernorm: x = nn.LayerNorm()(x) 171 | x = self.activation_fn(x) 172 | x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic) 173 | 174 | if len(self.net_arch) > 0: 175 | x = nn.Dense( 176 | self.net_arch[-1], 177 | kernel_init=self.kernel_init, 178 | use_bias=self.use_bias, 179 | bias_init=self.bias_init 180 | )(x) 181 | 182 | if self.squashed_out: return nn.tanh(x) 183 | else: return x 184 | 185 | 186 | class Sequential(nn.Module): 187 | layers: Sequence[nn.Module] 188 | 189 | @nn.compact 190 | def __call__(self, x, *args, **kwargs): 191 | for layer in self.layers: 192 | x = layer(x, *args, **kwargs) 193 | return x 194 | 195 | 196 | class BaseFeaturesExtractor(nn.Module): 197 | """ 198 | Base class that represents a features extractor. 199 | 200 | :param observation_space: 201 | :param features_dim: Number of features extracted. 202 | """ 203 | 204 | _observation_space: gym.Space 205 | _feature_dim: int = 0 206 | 207 | @property 208 | def features_dim(self) -> int: 209 | return self._feature_dim 210 | 211 | def __call__(self, observations: jnp.array) -> jnp.array: 212 | raise NotImplementedError() 213 | 214 | 215 | class FlattenExtractor(BaseFeaturesExtractor): 216 | @nn.compact 217 | def __call__(self, x: jnp.ndarray): 218 | return x.reshape((x.shape[0], -1)) 219 | 220 | 221 | class NatureCNN(BaseFeaturesExtractor): 222 | """ 223 | CNN from DQN nature paper: 224 | Mnih, Volodymyr, et al. 225 | "Human-level control through deep reinforcement learning." 226 | Nature 518.7540 (2015): 529-533. 227 | 228 | feature_dim: Number of features extracted. 229 | This corresponds to the number of units in the last layer.
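As a rough sketch of the shapes involved (assuming Flax's default channels-last inputs), a batch of image observations of shape ``(batch, 84, 84, 4)`` is mapped to a feature array of shape ``(batch, feature_dim)``.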
230 | """ 231 | feature_dim: int = 512 232 | 233 | @nn.compact 234 | def __call__(self, observations: jnp.array) -> jnp.array: 235 | x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4))(observations) 236 | x = nn.relu(x) 237 | x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2))(x) 238 | x = nn.relu(x) 239 | x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x) 240 | x = nn.relu(x) 241 | x = x.reshape((x.shape[0], -1)) # flatten 242 | 243 | x = nn.Dense(features=self.feature_dim)(x) 244 | x = nn.relu(x) 245 | return x 246 | 247 | 248 | class CombinedExtractor(BaseFeaturesExtractor): 249 | """ 250 | Combined feature extractor for Dict observation spaces. 251 | Builds a feature extractor for each key of the space. Input from each space 252 | is fed through a separate submodule (CNN or MLP, depending on input shape), 253 | the output features are concatenated and fed through additional MLP network ("combined"). 254 | 255 | :param observation_space: 256 | :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to 257 | 256 to avoid exploding network sizes. 258 | """ 259 | _observation_space: gym.spaces.Dict 260 | cnn_output_dim: int = 256 261 | 262 | @nn.compact 263 | def __call__(self, observation: TensorDict): 264 | encoded_tensor_list = [] 265 | for key, subspace in self._observation_space.spaces.items(): 266 | if is_image_space(subspace): 267 | encoded_tensor_list.append(NatureCNN(self.cnn_output_dim)(observation[key])) 268 | else: 269 | # The observation key is a vector, flatten it if needed 270 | encoded_tensor_list.append(observation[key].reshape((observation[key].shape[0], -1))) 271 | return jnp.concatenate(encoded_tensor_list, axis=1) 272 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/stacked_observations.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | from gym import spaces 6 | 7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first 8 | 9 | 10 | class StackedObservations: 11 | """ 12 | Frame stacking wrapper for data. 13 | 14 | Dimension to stack over is either first (channels-first) or 15 | last (channels-last), which is detected automatically using 16 | ``common.preprocessing.is_image_space_channels_first`` if 17 | observation is an image space. 18 | 19 | :param num_envs: number of environments 20 | :param n_stack: Number of frames to stack 21 | :param observation_space: Environment observation space. 22 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 23 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
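For example, with ``n_stack=4`` and a channels-last image space of shape ``(84, 84, 1)``, the stacked observation space has shape ``(84, 84, 4)``, while the internal buffer additionally keeps the vec-env dimension first, i.e. ``(num_envs, 84, 84, 4)``.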
24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_envs: int, 29 | n_stack: int, 30 | observation_space: spaces.Space, 31 | channels_order: Optional[str] = None, 32 | ): 33 | 34 | self.n_stack = n_stack 35 | ( 36 | self.channels_first, 37 | self.stack_dimension, 38 | self.stackedobs, 39 | self.repeat_axis, 40 | ) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order) 41 | super().__init__() 42 | 43 | @staticmethod 44 | def compute_stacking( 45 | num_envs: int, 46 | n_stack: int, 47 | observation_space: spaces.Box, 48 | channels_order: Optional[str] = None, 49 | ) -> Tuple[bool, int, np.ndarray, int]: 50 | """ 51 | Calculates the parameters in order to stack observations 52 | 53 | :param num_envs: Number of environments in the stack 54 | :param n_stack: The number of observations to stack 55 | :param observation_space: The observation space 56 | :param channels_order: The order of the channels 57 | :return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis 58 | """ 59 | channels_first = False 60 | if channels_order is None: 61 | # Detect channel location automatically for images 62 | if is_image_space(observation_space): 63 | channels_first = is_image_space_channels_first(observation_space) 64 | else: 65 | # Default behavior for non-image space, stack on the last axis 66 | channels_first = False 67 | else: 68 | assert channels_order in { 69 | "last", 70 | "first", 71 | }, "`channels_order` must be one of following: 'last', 'first'" 72 | 73 | channels_first = channels_order == "first" 74 | 75 | # This includes the vec-env dimension (first) 76 | stack_dimension = 1 if channels_first else -1 77 | repeat_axis = 0 if channels_first else -1 78 | low = np.repeat(observation_space.low, n_stack, axis=repeat_axis) 79 | stackedobs = np.zeros((num_envs,) + low.shape, low.dtype) 80 | return channels_first, stack_dimension, stackedobs, repeat_axis 81 | 82 | def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box: 83 | """ 84 | Given an observation space, returns a new observation space with stacked observations 85 | 86 | :return: New observation space with stacked dimensions 87 | """ 88 | low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) 89 | high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) 90 | return spaces.Box(low=low, high=high, dtype=observation_space.dtype) 91 | 92 | def reset(self, observation: np.ndarray) -> np.ndarray: 93 | """ 94 | Resets the stackedobs, adds the reset observation to the stack, and returns the stack 95 | 96 | :param observation: Reset observation 97 | :return: The stacked reset observation 98 | """ 99 | self.stackedobs[...] = 0 100 | if self.channels_first: 101 | self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation 102 | else: 103 | self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation 104 | return self.stackedobs 105 | 106 | def update( 107 | self, 108 | observations: np.ndarray, 109 | dones: np.ndarray, 110 | infos: List[Dict[str, Any]], 111 | ) -> Tuple[np.ndarray, List[Dict[str, Any]]]: 112 | """ 113 | Adds the observations to the stack and uses the dones to update the infos. 
114 | 115 | :param observations: numpy array of observations 116 | :param dones: numpy array of done info 117 | :param infos: list of info dicts 118 | :return: tuple of the stacked observations and the updated infos 119 | """ 120 | stack_ax_size = observations.shape[self.stack_dimension] 121 | self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) 122 | for i, done in enumerate(dones): 123 | if done: 124 | if "terminal_observation" in infos[i]: 125 | old_terminal = infos[i]["terminal_observation"] 126 | if self.channels_first: 127 | new_terminal = np.concatenate( 128 | (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), 129 | axis=0, # self.stack_dimension - 1, as there is no batch dim 130 | ) 131 | else: 132 | new_terminal = np.concatenate( 133 | (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), 134 | axis=self.stack_dimension, 135 | ) 136 | infos[i]["terminal_observation"] = new_terminal 137 | else: 138 | warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") 139 | self.stackedobs[i] = 0 140 | if self.channels_first: 141 | self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations 142 | else: 143 | self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations 144 | return self.stackedobs, infos 145 | 146 | 147 | class StackedDictObservations(StackedObservations): 148 | """ 149 | Frame stacking wrapper for dictionary data. 150 | 151 | Dimension to stack over is either first (channels-first) or 152 | last (channels-last), which is detected automatically using 153 | ``common.preprocessing.is_image_space_channels_first`` if 154 | observation is an image space. 155 | 156 | :param num_envs: number of environments 157 | :param n_stack: Number of frames to stack :param observation_space: Environment observation space. 158 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. 159 | If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
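A per-key setting is also accepted, e.g. ``channels_order={"image": "first", "state": "last"}`` (the key names here are purely illustrative).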
160 | """ 161 | 162 | def __init__( 163 | self, 164 | num_envs: int, 165 | n_stack: int, 166 | observation_space: spaces.Dict, 167 | channels_order: Optional[Union[str, Dict[str, str]]] = None, 168 | ): 169 | self.n_stack = n_stack 170 | self.channels_first = {} 171 | self.stack_dimension = {} 172 | self.stackedobs = {} 173 | self.repeat_axis = {} 174 | 175 | for key, subspace in observation_space.spaces.items(): 176 | assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box" 177 | if isinstance(channels_order, str) or channels_order is None: 178 | subspace_channel_order = channels_order 179 | else: 180 | subspace_channel_order = channels_order[key] 181 | ( 182 | self.channels_first[key], 183 | self.stack_dimension[key], 184 | self.stackedobs[key], 185 | self.repeat_axis[key], 186 | ) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order) 187 | 188 | def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict: 189 | """ 190 | Returns the stacked verson of a Dict observation space 191 | 192 | :param observation_space: Dict observation space to stack 193 | :return: stacked observation space 194 | """ 195 | spaces_dict = {} 196 | for key, subspace in observation_space.spaces.items(): 197 | low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key]) 198 | high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key]) 199 | spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) 200 | return spaces.Dict(spaces=spaces_dict) 201 | 202 | def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 203 | """ 204 | Resets the stacked observations, adds the reset observation to the stack, and returns the stack 205 | 206 | :param observation: Reset observation 207 | :return: Stacked reset observations 208 | """ 209 | for key, obs in observation.items(): 210 | self.stackedobs[key][...] = 0 211 | if self.channels_first[key]: 212 | self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs 213 | else: 214 | self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs 215 | return self.stackedobs 216 | 217 | def update( 218 | self, 219 | observations: Dict[str, np.ndarray], 220 | dones: np.ndarray, 221 | infos: List[Dict[str, Any]], 222 | ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: 223 | """ 224 | Adds the observations to the stack and uses the dones to update the infos. 
225 | 226 | :param observations: Dict of numpy arrays of observations 227 | :param dones: numpy array of dones 228 | :param infos: dict of infos 229 | :return: tuple of the stacked observations and the updated infos 230 | """ 231 | for key in self.stackedobs.keys(): 232 | stack_ax_size = observations[key].shape[self.stack_dimension[key]] 233 | self.stackedobs[key] = np.roll( 234 | self.stackedobs[key], 235 | shift=-stack_ax_size, 236 | axis=self.stack_dimension[key], 237 | ) 238 | 239 | for i, done in enumerate(dones): 240 | if done: 241 | if "terminal_observation" in infos[i]: 242 | old_terminal = infos[i]["terminal_observation"][key] 243 | if self.channels_first[key]: 244 | new_terminal = np.vstack( 245 | ( 246 | self.stackedobs[key][i, :-stack_ax_size, ...], 247 | old_terminal, 248 | ) 249 | ) 250 | else: 251 | new_terminal = np.concatenate( 252 | ( 253 | self.stackedobs[key][i, ..., :-stack_ax_size], 254 | old_terminal, 255 | ), 256 | axis=self.stack_dimension[key], 257 | ) 258 | infos[i]["terminal_observation"][key] = new_terminal 259 | else: 260 | warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") 261 | self.stackedobs[key][i] = 0 262 | if self.channels_first[key]: 263 | self.stackedobs[key][:, -stack_ax_size:, ...] = observations[key] 264 | else: 265 | self.stackedobs[key][..., -stack_ax_size:] = observations[key] 266 | return self.stackedobs, infos 267 | -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/vec_normalize.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | from copy import deepcopy 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | import gym 7 | import numpy as np 8 | 9 | from stable_baselines3.common import utils 10 | from stable_baselines3.common.running_mean_std import RunningMeanStd 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper 12 | 13 | 14 | class VecNormalize(VecEnvWrapper): 15 | """ 16 | A moving average, normalizing wrapper for vectorized environment. 17 | has support for saving/loading moving average, 18 | 19 | :param venv: the vectorized environment to wrap 20 | :param training: Whether to update or not the moving average 21 | :param norm_obs: Whether to normalize observation or not (default: True) 22 | :param norm_reward: Whether to normalize rewards or not (default: True) 23 | :param clip_obs: Max absolute value for observation 24 | :param clip_reward: Max value absolute for discounted reward 25 | :param gamma: discount factor 26 | :param epsilon: To avoid division by zero 27 | :param norm_obs_keys: Which keys from observation dict to normalize. 28 | If not specified, all keys will be normalized. 
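Example (an illustrative sketch added for clarity, not from the original docs; assumes a ``make_env`` callable returning a ``gym.Env`` and ``DummyVecEnv`` from this package's ``vec_env`` module): ``venv = VecNormalize(DummyVecEnv([make_env]), norm_obs=True, norm_reward=True, clip_obs=10.0)``; the running statistics can later be persisted with ``venv.save(path)`` and restored with ``VecNormalize.load(path, venv)``.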
29 | """ 30 | 31 | def __init__( 32 | self, 33 | venv: VecEnv, 34 | training: bool = True, 35 | norm_obs: bool = True, 36 | norm_reward: bool = True, 37 | clip_obs: float = 10.0, 38 | clip_reward: float = 10.0, 39 | gamma: float = 0.99, 40 | epsilon: float = 1e-8, 41 | norm_obs_keys: Optional[List[str]] = None, 42 | ): 43 | VecEnvWrapper.__init__(self, venv) 44 | 45 | self.norm_obs = norm_obs 46 | self.norm_obs_keys = norm_obs_keys 47 | # Check observation spaces 48 | if self.norm_obs: 49 | self._sanity_checks() 50 | 51 | if isinstance(self.observation_space, gym.spaces.Dict): 52 | self.obs_spaces = self.observation_space.spaces 53 | self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} 54 | else: 55 | self.obs_spaces = None 56 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) 57 | 58 | self.ret_rms = RunningMeanStd(shape=()) 59 | self.clip_obs = clip_obs 60 | self.clip_reward = clip_reward 61 | # Returns: discounted rewards 62 | self.returns = np.zeros(self.num_envs) 63 | self.gamma = gamma 64 | self.epsilon = epsilon 65 | self.training = training 66 | self.norm_obs = norm_obs 67 | self.norm_reward = norm_reward 68 | self.old_obs = np.array([]) 69 | self.old_reward = np.array([]) 70 | 71 | def _sanity_checks(self) -> None: 72 | """ 73 | Check the observations that are going to be normalized are of the correct type (spaces.Box). 74 | """ 75 | if isinstance(self.observation_space, gym.spaces.Dict): 76 | # By default, we normalize all keys 77 | if self.norm_obs_keys is None: 78 | self.norm_obs_keys = list(self.observation_space.spaces.keys()) 79 | # Check that all keys are of type Box 80 | for obs_key in self.norm_obs_keys: 81 | if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box): 82 | raise ValueError( 83 | f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " 84 | f"is of type {self.observation_space.spaces[obs_key]}. " 85 | "You should probably explicitely pass the observation keys " 86 | " that should be normalized via the `norm_obs_keys` parameter." 87 | ) 88 | 89 | elif isinstance(self.observation_space, gym.spaces.Box): 90 | if self.norm_obs_keys is not None: 91 | raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces") 92 | 93 | else: 94 | raise ValueError( 95 | "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, " 96 | f"not {self.observation_space}" 97 | ) 98 | 99 | def __getstate__(self) -> Dict[str, Any]: 100 | """ 101 | Gets state for pickling. 102 | 103 | Excludes self.venv, as in general VecEnv's may not be pickleable.""" 104 | state = self.__dict__.copy() 105 | # these attributes are not pickleable 106 | del state["venv"] 107 | del state["class_attributes"] 108 | # these attributes depend on the above and so we would prefer not to pickle 109 | del state["returns"] 110 | return state 111 | 112 | def __setstate__(self, state: Dict[str, Any]) -> None: 113 | """ 114 | Restores pickled state. 115 | 116 | User must call set_venv() after unpickling before using. 
117 | 118 | :param state:""" 119 | # Backward compatibility 120 | if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict): 121 | state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) 122 | self.__dict__.update(state) 123 | assert "venv" not in state 124 | self.venv = None 125 | 126 | def set_venv(self, venv: VecEnv) -> None: 127 | """ 128 | Sets the vector environment to wrap to venv. 129 | 130 | Also sets attributes derived from this such as `num_env`. 131 | 132 | :param venv: 133 | """ 134 | if self.venv is not None: 135 | raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") 136 | VecEnvWrapper.__init__(self, venv) 137 | 138 | # Check only that the observation_space match 139 | utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space) 140 | self.returns = np.zeros(self.num_envs) 141 | 142 | def step_wait(self) -> VecEnvStepReturn: 143 | """ 144 | Apply sequence of actions to sequence of environments 145 | actions -> (observations, rewards, dones) 146 | 147 | where ``dones`` is a boolean vector indicating whether each element is new. 148 | """ 149 | obs, rewards, dones, infos = self.venv.step_wait() 150 | self.old_obs = obs 151 | self.old_reward = rewards 152 | 153 | if self.training and self.norm_obs: 154 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 155 | for key in self.obs_rms.keys(): 156 | self.obs_rms[key].update(obs[key]) 157 | else: 158 | self.obs_rms.update(obs) 159 | 160 | obs = self.normalize_obs(obs) 161 | 162 | if self.training: 163 | self._update_reward(rewards) 164 | rewards = self.normalize_reward(rewards) 165 | 166 | # Normalize the terminal observations 167 | for idx, done in enumerate(dones): 168 | if not done: 169 | continue 170 | if "terminal_observation" in infos[idx]: 171 | infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) 172 | 173 | self.returns[dones] = 0 174 | return obs, rewards, dones, infos 175 | 176 | def _update_reward(self, reward: np.ndarray) -> None: 177 | """Update reward normalization statistics.""" 178 | self.returns = self.returns * self.gamma + reward 179 | self.ret_rms.update(self.returns) 180 | 181 | def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: 182 | """ 183 | Helper to normalize observation. 184 | :param obs: 185 | :param obs_rms: associated statistics 186 | :return: normalized observation 187 | """ 188 | return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) 189 | 190 | def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: 191 | """ 192 | Helper to unnormalize observation. 193 | :param obs: 194 | :param obs_rms: associated statistics 195 | :return: unnormalized observation 196 | """ 197 | return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean 198 | 199 | def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: 200 | """ 201 | Normalize observations using this VecNormalize's observations statistics. 202 | Calling this method does not update statistics. 
203 | """ 204 | # Avoid modifying by reference the original object 205 | obs_ = deepcopy(obs) 206 | if self.norm_obs: 207 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 208 | # Only normalize the specified keys 209 | for key in self.norm_obs_keys: 210 | obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) 211 | else: 212 | obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32) 213 | return obs_ 214 | 215 | def normalize_reward(self, reward: np.ndarray) -> np.ndarray: 216 | """ 217 | Normalize rewards using this VecNormalize's rewards statistics. 218 | Calling this method does not update statistics. 219 | """ 220 | if self.norm_reward: 221 | reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) 222 | return reward 223 | 224 | def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: 225 | # Avoid modifying by reference the original object 226 | obs_ = deepcopy(obs) 227 | if self.norm_obs: 228 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 229 | for key in self.norm_obs_keys: 230 | obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) 231 | else: 232 | obs_ = self._unnormalize_obs(obs, self.obs_rms) 233 | return obs_ 234 | 235 | def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray: 236 | if self.norm_reward: 237 | return reward * np.sqrt(self.ret_rms.var + self.epsilon) 238 | return reward 239 | 240 | def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 241 | """ 242 | Returns an unnormalized version of the observations from the most recent 243 | step or reset. 244 | """ 245 | return deepcopy(self.old_obs) 246 | 247 | def get_original_reward(self) -> np.ndarray: 248 | """ 249 | Returns an unnormalized version of the rewards from the most recent step. 250 | """ 251 | return self.old_reward.copy() 252 | 253 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 254 | """ 255 | Reset all environments 256 | :return: first observation of the episode 257 | """ 258 | obs = self.venv.reset() 259 | self.old_obs = obs 260 | self.returns = np.zeros(self.num_envs) 261 | if self.training and self.norm_obs: 262 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 263 | for key in self.obs_rms.keys(): 264 | self.obs_rms[key].update(obs[key]) 265 | else: 266 | self.obs_rms.update(obs) 267 | return self.normalize_obs(obs) 268 | 269 | @staticmethod 270 | def load(load_path: str, venv: VecEnv) -> "VecNormalize": 271 | """ 272 | Loads a saved VecNormalize object. 273 | 274 | :param load_path: the path to load from. 275 | :param venv: the VecEnv to wrap. 276 | :return: 277 | """ 278 | with open(load_path, "rb") as file_handler: 279 | vec_normalize = pickle.load(file_handler) 280 | vec_normalize.set_venv(venv) 281 | return vec_normalize 282 | 283 | def save(self, save_path: str) -> None: 284 | """ 285 | Save current VecNormalize object with 286 | all running statistics and settings (e.g. clip_obs) 287 | 288 | :param save_path: The path to save to 289 | """ 290 | with open(save_path, "wb") as file_handler: 291 | pickle.dump(self, file_handler) 292 | 293 | @property 294 | def ret(self) -> np.ndarray: 295 | warnings.warn("`VecNormalize` `ret` attribute is deprecated. 
Please use `returns` instead.", DeprecationWarning) 296 | return self.returns 297 | -------------------------------------------------------------------------------- /offline_baselines_jax/td3/policies.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple 3 | 4 | import flax.linen as nn 5 | import gym 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | from offline_baselines_jax.common.jax_layers import ( 12 | BaseFeaturesExtractor, 13 | CombinedExtractor, 14 | FlattenExtractor, 15 | NatureCNN, 16 | create_mlp, 17 | get_actor_critic_arch, 18 | ) 19 | from offline_baselines_jax.common.policies import Model 20 | from offline_baselines_jax.common.preprocessing import get_action_dim, preprocess_obs 21 | from offline_baselines_jax.common.type_aliases import Schedule, Params 22 | 23 | 24 | @functools.partial(jax.jit, static_argnames=("actor_apply_fn", "deterministic")) 25 | def sample_actions( 26 | rng: jnp.ndarray, 27 | actor_apply_fn: Callable[..., Any], 28 | actor_params: Params, 29 | observations: [np.ndarray, Dict], 30 | deterministic: bool, 31 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 32 | 33 | rng, dropout_key = jax.random.split(rng) 34 | rngs = {"dropout": dropout_key} 35 | action = actor_apply_fn({'params': actor_params}, observations, deterministic=deterministic, rngs=rngs) 36 | return rng, action 37 | 38 | 39 | class Actor(nn.Module): 40 | features_extractor: nn.Module 41 | observation_space: gym.spaces.Space 42 | action_space: gym.spaces.Space 43 | net_arch: List[int] 44 | activation_fn: Type[nn.Module] = nn.relu 45 | dropout: float = 0.0 46 | 47 | mu = None 48 | 49 | def setup(self): 50 | action_dim = get_action_dim(self.action_space) 51 | self.mu = create_mlp(action_dim, self.net_arch, self.activation_fn, self.dropout, squash_output=True) 52 | 53 | def __call__(self, *args, **kwargs): 54 | return self.forward(*args, **kwargs) 55 | 56 | def forward(self, observations: jnp.ndarray, deterministic: bool = False): 57 | observations = preprocess_obs(observations, self.observation_space) 58 | features = self.features_extractor(observations) 59 | mu = self.mu(features, deterministic=deterministic) 60 | return mu 61 | 62 | 63 | class SingleCritic(nn.Module): 64 | features_extractor: nn.Module 65 | observation_space: gym.spaces.Space 66 | net_arch: List[int] 67 | dropout: float 68 | activation_fn: Type[nn.Module] = nn.relu 69 | 70 | q_net = None 71 | 72 | def setup(self): 73 | self.q_net = create_mlp( 74 | output_dim=1, 75 | net_arch=self.net_arch, 76 | dropout=self.dropout 77 | ) 78 | 79 | def __call__(self, *args, **kwargs): 80 | return self.forward(*args, **kwargs) 81 | 82 | def forward( 83 | self, 84 | observations: jnp.ndarray, 85 | actions: jnp.ndarray, 86 | deterministic: bool = False 87 | ): 88 | observations = preprocess_obs(observations, self.observation_space) 89 | features = self.features_extractor(observations) 90 | q_input = jnp.concatenate((features, actions), axis=1) 91 | return self.q_net(q_input, deterministic=deterministic) 92 | 93 | 94 | class Critic(nn.Module): 95 | features_extractor: nn.Module 96 | observation_space: gym.spaces.Space 97 | net_arch: List[int] 98 | dropout: float = 0.0 99 | activation_fn: Type[nn.Module] = nn.relu 100 | n_critics: int = 2 101 | 102 | q_networks = None 103 | 104 | def setup(self): 105 | batch_qs = nn.vmap( 106 | SingleCritic, 107 | in_axes=None, 108 | out_axes=1, 109 | 
variable_axes={"params": 1}, 110 | split_rngs={"params": True, "dropout": True}, 111 | axis_size=self.n_critics 112 | ) 113 | self.q_networks = batch_qs( 114 | self.features_extractor, 115 | self.observation_space, 116 | self.net_arch, 117 | self.dropout, 118 | self.activation_fn 119 | ) 120 | 121 | def __call__(self, *args, **kwargs): 122 | return self.forward(*args, **kwargs) 123 | 124 | def forward(self, observations: jnp.ndarray, actions: jnp.ndarray, deterministic: bool = False): 125 | return self.q_networks(observations, actions, deterministic) 126 | 127 | 128 | class TD3Policy(object): 129 | def __init__( 130 | self, 131 | rng, 132 | observation_space: gym.spaces.Space, 133 | action_space: gym.spaces.Space, 134 | lr_schedule: Schedule, 135 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, 136 | activation_fn: Type[nn.Module] = nn.relu, 137 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, 138 | features_extractor_kwargs: Optional[Dict[str, Any]] = None, 139 | n_critics: int = 2, 140 | dropout: float = 0.0, 141 | ): 142 | self.observation_space = observation_space 143 | self.action_space = action_space 144 | 145 | self.rng, actor_key, critic_key, features_key, dropout_key = jax.random.split(rng, 5) 146 | 147 | # Default network architecture, from the original paper 148 | if net_arch is None: 149 | if features_extractor_class == NatureCNN: 150 | net_arch = [] 151 | else: 152 | net_arch = [400, 300] 153 | 154 | if features_extractor_kwargs is None: 155 | features_extractor_kwargs = {} 156 | 157 | actor_arch, critic_arch = get_actor_critic_arch(net_arch) 158 | 159 | features_extractor_def = features_extractor_class( 160 | _observation_space=observation_space, 161 | **features_extractor_kwargs 162 | ) 163 | 164 | actor_def = Actor( 165 | features_extractor=features_extractor_def, 166 | observation_space=observation_space, 167 | action_space=action_space, 168 | net_arch=actor_arch, 169 | activation_fn=activation_fn, 170 | dropout=dropout 171 | ) 172 | 173 | critic_def = Critic( 174 | features_extractor=features_extractor_def, 175 | observation_space=observation_space, 176 | net_arch=critic_arch, 177 | activation_fn=activation_fn, 178 | n_critics=n_critics 179 | ) 180 | 181 | # Init dummy inputs 182 | if isinstance(observation_space, gym.spaces.Dict): 183 | observation = observation_space.sample() 184 | for key, _ in observation_space.spaces.items(): 185 | observation[key] = observation[key][np.newaxis, ...] 186 | else: 187 | observation = observation_space.sample()[np.newaxis, ...]
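# Note (editorial, descriptive only): the batched dummy observation built above, together with the dummy action created further down, is only consumed by Model.create to initialize the Flax parameter PyTrees; it carries no training data.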
188 | 189 | actor_rngs = {"params": actor_key, "dropout": dropout_key} 190 | actor = Model.create( 191 | actor_def, 192 | inputs=[actor_rngs, observation], 193 | tx=optax.adam(learning_rate=lr_schedule) 194 | ) 195 | actor_target = Model.create( 196 | actor_def, 197 | inputs=[actor_rngs, observation], 198 | tx=optax.adam(learning_rate=lr_schedule) 199 | ) 200 | 201 | if isinstance(observation_space, gym.spaces.Dict): 202 | observation = observation_space.sample() 203 | for key, _ in observation_space.spaces.items(): 204 | observation[key] = np.expand_dims(observation[key], axis=0) 205 | else: 206 | observation = np.expand_dims(observation_space.sample(), axis=0) 207 | action = np.expand_dims(action_space.sample(), axis=0) 208 | 209 | critic_rngs = {"params": critic_key, "dropout": dropout_key} 210 | critic = Model.create( 211 | critic_def, 212 | inputs=[critic_rngs, observation, action], 213 | tx=optax.adam(learning_rate=lr_schedule) 214 | ) 215 | critic_target = Model.create(critic_def, inputs=[critic_key, observation, action]) 216 | 217 | self.actor, self.actor_target = actor, actor_target 218 | self.critic, self.critic_target = critic, critic_target 219 | 220 | def _predict(self, observation: jnp.ndarray, deterministic: bool) -> jnp.ndarray: 221 | rng, actions = sample_actions(self.rng, self.actor.apply_fn, self.actor.params, observation, deterministic) 222 | self.rng = rng 223 | return np.asarray(actions) 224 | 225 | def predict(self, observation: jnp.ndarray, deterministic: bool = True) -> np.ndarray: 226 | actions = self._predict(observation, deterministic) 227 | if isinstance(self.action_space, gym.spaces.Box): 228 | # Actions could be on arbitrary scale, so clip the actions to avoid 229 | # out of bound error (e.g. if sampling from a Gaussian distribution) 230 | actions = np.clip(actions, self.action_space.low, self.action_space.high) 231 | return actions, None 232 | 233 | def scale_action(self, action: np.ndarray) -> np.ndarray: 234 | """ 235 | Rescale the action from [low, high] to [-1, 1] 236 | (no need for symmetric action space) 237 | :param action: Action to scale 238 | :return: Scaled action 239 | """ 240 | low, high = self.action_space.low, self.action_space.high 241 | return 2.0 * ((action - low) / (high - low)) - 1.0 242 | 243 | def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: 244 | """ 245 | Rescale the action from [-1, 1] to [low, high] 246 | (no need for symmetric action space) 247 | :param scaled_action: Action to un-scale 248 | """ 249 | low, high = self.action_space.low, self.action_space.high 250 | return low + (0.5 * (scaled_action + 1.0) * (high - low)) 251 | 252 | 253 | MlpPolicy = TD3Policy 254 | 255 | 256 | class CnnPolicy(TD3Policy): 257 | """ 258 | Policy class (with both actor and critic) for TD3. 259 | 260 | :param observation_space: Observation space 261 | :param action_space: Action space 262 | :param lr_schedule: Learning rate schedule (could be constant) 263 | :param net_arch: The specification of the policy and value networks. 264 | :param activation_fn: Activation function 265 | :param features_extractor_class: Features extractor to use. 266 | :param features_extractor_kwargs: Keyword arguments 267 | to pass to the features extractor. 
268 | :param normalize_images: Whether to normalize images or not, 269 | dividing by 255.0 (True by default) 270 | :param optimizer_class: The optimizer to use, 271 | ``th.optim.Adam`` by default 272 | :param optimizer_kwargs: Additional keyword arguments, 273 | excluding the learning rate, to pass to the optimizer 274 | :param n_critics: Number of critic networks to create. 275 | :param share_features_extractor: Whether to share or not the features extractor 276 | between the actor and the critic (this saves computation time) 277 | """ 278 | 279 | def __init__( 280 | self, 281 | rng, 282 | observation_space: gym.spaces.Space, 283 | action_space: gym.spaces.Space, 284 | lr_schedule: Schedule, 285 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, 286 | activation_fn: Type[nn.Module] = nn.relu, 287 | features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, 288 | features_extractor_kwargs: Optional[Dict[str, Any]] = None, 289 | n_critics: int = 2, 290 | ): 291 | super(CnnPolicy, self).__init__( 292 | rng, 293 | observation_space, 294 | action_space, 295 | lr_schedule, 296 | net_arch, 297 | activation_fn, 298 | features_extractor_class, 299 | features_extractor_kwargs, 300 | n_critics, 301 | ) 302 | 303 | 304 | class MultiInputPolicy(TD3Policy): 305 | """ 306 | Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces. 307 | 308 | :param observation_space: Observation space 309 | :param action_space: Action space 310 | :param lr_schedule: Learning rate schedule (could be constant) 311 | :param net_arch: The specification of the policy and value networks. 312 | :param activation_fn: Activation function 313 | :param features_extractor_class: Features extractor to use. 314 | :param features_extractor_kwargs: Keyword arguments 315 | to pass to the features extractor. 316 | :param normalize_images: Whether to normalize images or not, 317 | dividing by 255.0 (True by default) 318 | :param optimizer_class: The optimizer to use, 319 | ``th.optim.Adam`` by default 320 | :param optimizer_kwargs: Additional keyword arguments, 321 | excluding the learning rate, to pass to the optimizer 322 | :param n_critics: Number of critic networks to create. 
323 | :param share_features_extractor: Whether to share or not the features extractor 324 | between the actor and the critic (this saves computation time) 325 | """ 326 | 327 | def __init__( 328 | self, 329 | rng, 330 | observation_space: gym.spaces.Dict, 331 | action_space: gym.spaces.Space, 332 | lr_schedule: Schedule, 333 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, 334 | activation_fn: Type[nn.Module] = nn.relu, 335 | features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, 336 | features_extractor_kwargs: Optional[Dict[str, Any]] = None, 337 | n_critics: int = 2, 338 | ): 339 | super(MultiInputPolicy, self).__init__( 340 | rng, 341 | observation_space, 342 | action_space, 343 | lr_schedule, 344 | net_arch, 345 | activation_fn, 346 | features_extractor_class, 347 | features_extractor_kwargs, 348 | n_critics, 349 | ) -------------------------------------------------------------------------------- /offline_baselines_jax/cql/cql.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Type, Union 2 | 3 | import flax.linen as nn 4 | import gym 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import optax 9 | from stable_baselines3.common.noise import ActionNoise 10 | 11 | from offline_baselines_jax.common.buffers import ReplayBuffer 12 | from offline_baselines_jax.common.off_policy_algorithm import OffPolicyAlgorithm 13 | from offline_baselines_jax.common.policies import Model 14 | from offline_baselines_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule, Params 15 | from offline_baselines_jax.sac.policies import SACPolicy 16 | from .core import update_cql 17 | 18 | 19 | class LogAlphaCoef(nn.Module): 20 | init_value: float = 1.0 21 | 22 | @nn.compact 23 | def __call__(self) -> jnp.ndarray: 24 | log_temp = self.param('log_alpha', init_fn=lambda key: jnp.full((), jnp.log(self.init_value))) 25 | return log_temp 26 | 27 | class LogEntropyCoef(nn.Module): 28 | init_value: float = 1.0 29 | 30 | @nn.compact 31 | def __call__(self) -> jnp.ndarray: 32 | log_temp = self.param('log_temp', init_fn=lambda key: jnp.full((), jnp.log(self.init_value))) 33 | return log_temp 34 | 35 | 36 | class CQL(OffPolicyAlgorithm): 37 | """ 38 | Conservative Q Learning (CQL) 39 | 40 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 41 | :param env: The environment to learn from (if registered in Gym, can be str) 42 | :param learning_rate: learning rate for adam optimizer, 43 | the same learning rate will be used for all networks (Q-Values, Actor and Value function) 44 | it can be a function of the current progress remaining (from 1 to 0) 45 | :param buffer_size: size of the replay buffer 46 | :param learning_starts: how many steps of the model to collect transitions for before learning starts 47 | :param batch_size: Minibatch size for each gradient update 48 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1) 49 | :param gamma: the discount factor 50 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit 51 | like ``(5, "step")`` or ``(2, "episode")``. 52 | :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) 53 | Set to ``-1`` means to do as many gradient steps as steps done in the environment 54 | during the rollout. 
55 | :param action_noise: the action noise type (None by default), this can help 56 | for hard exploration problem. Cf common.noise for the different action noise type. 57 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). 58 | If ``None``, it will be automatically selected. 59 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. 60 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer 61 | at a cost of more complexity. 62 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 63 | :param ent_coef: Entropy regularization coefficient. (Equivalent to 64 | inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. 65 | Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) 66 | :param target_update_interval: update the target network every ``target_network_update_freq`` 67 | gradient steps. 68 | :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) 69 | :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) 70 | instead of action noise exploration (default: False) 71 | :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE 72 | Default: -1 (only sample at the beginning of the rollout) 73 | :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling 74 | during the warm up phase (before learning starts) 75 | :param create_eval_env: Whether to create a second environment that will be 76 | used for evaluating the agent periodically. (Only available when passing string for the environment) 77 | :param policy_kwargs: additional arguments to be passed to the policy on creation 78 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug 79 | :param seed: Seed for the pseudo random generators 80 | :param _init_setup_model: Whether or not to build the network at the creation of the instance 81 | """ 82 | 83 | def __init__( 84 | self, 85 | policy: Union[str, Type[SACPolicy]], 86 | env: Union[GymEnv, str], 87 | learning_rate: Union[float, Schedule] = 3e-4, 88 | buffer_size: int = 1_000_000, # 1e6 89 | learning_starts: int = 0, 90 | batch_size: int = 256, 91 | tau: float = 0.005, 92 | gamma: float = 0.99, 93 | train_freq: Union[int, Tuple[int, str]] = 1, 94 | gradient_steps: int = 1, 95 | action_noise: Optional[ActionNoise] = None, 96 | replay_buffer_class: Optional[ReplayBuffer] = None, 97 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None, 98 | optimize_memory_usage: bool = False, 99 | ent_coef: Union[str, float] = "auto", 100 | target_update_interval: int = 1, 101 | target_entropy: Union[str, float] = "auto", 102 | tensorboard_log: Optional[str] = None, 103 | create_eval_env: bool = False, 104 | policy_kwargs: Optional[Dict[str, Any]] = None, 105 | verbose: int = 0, 106 | seed: int = 0, 107 | _init_setup_model: bool = True, 108 | 109 | # Add for CQL 110 | alpha_coef: float = "auto", 111 | lagrange_thresh: int = 10.0, 112 | without_exploration: bool = True, 113 | conservative_weight: float = 10.0, 114 | ): 115 | 116 | super(CQL, self).__init__( 117 | policy, 118 | env, 119 | learning_rate, 120 | buffer_size, 121 | learning_starts, 122 | batch_size, 123 | tau, 124 | gamma, 125 | train_freq, 126 | gradient_steps, 127 | action_noise, 128 | replay_buffer_class=replay_buffer_class, 129 | replay_buffer_kwargs=replay_buffer_kwargs, 130 | policy_kwargs=policy_kwargs, 131 | 
tensorboard_log=tensorboard_log, 132 | verbose=verbose, 133 | create_eval_env=create_eval_env, 134 | seed=seed, 135 | optimize_memory_usage=optimize_memory_usage, 136 | supported_action_spaces=(gym.spaces.Box), 137 | support_multi_env=True, 138 | without_exploration=without_exploration, 139 | ) 140 | 141 | self.target_entropy = target_entropy 142 | self.log_ent_coef = None 143 | self.log_alpha_coef = None 144 | # Entropy coefficient / Entropy temperature 145 | # Inverse of the reward scale 146 | self.ent_coef = ent_coef 147 | self.target_update_interval = target_update_interval 148 | self.entropy_update = True 149 | self.alpha_update = True 150 | self.lagrange_thresh = lagrange_thresh 151 | self.alpha_coef = alpha_coef 152 | self.conservative_weight = conservative_weight 153 | 154 | if _init_setup_model: 155 | self._setup_model() 156 | 157 | def _setup_model(self) -> None: 158 | super(CQL, self)._setup_model() 159 | self._create_aliases() 160 | # Target entropy is used when learning the entropy coefficient 161 | if self.target_entropy == "auto": 162 | # automatically set target entropy if needed 163 | self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) 164 | else: 165 | # Force conversion 166 | # this will also throw an error for unexpected string 167 | self.target_entropy = float(self.target_entropy) 168 | # The entropy coefficient or entropy can be learned automatically 169 | # see Automating Entropy Adjustment for Maximum Entropy RL section 170 | # of https://arxiv.org/abs/1812.05905 171 | if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"): 172 | # Default initial value of ent_coef when learned 173 | init_value = 1.0 174 | if "_" in self.ent_coef: 175 | init_value = float(self.ent_coef.split("_")[1]) 176 | assert init_value > 0.0, "The initial value of ent_coef must be greater than 0" 177 | 178 | # Note: we optimize the log of the entropy coeff which is slightly different from the paper 179 | # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 180 | log_ent_coef_def = LogEntropyCoef(init_value) 181 | self.key, temp_key = jax.random.split(self.key, 2) 182 | self.log_ent_coef = Model.create(log_ent_coef_def, inputs=[temp_key], 183 | tx=optax.adam(learning_rate=self.lr_schedule(1))) 184 | 185 | else: 186 | # Force conversion to float 187 | # this will throw an error if a malformed string (different from 'auto') 188 | # is passed 189 | log_ent_coef_def = LogEntropyCoef(self.ent_coef) 190 | self.key, temp_key = jax.random.split(self.key, 2) 191 | self.log_ent_coef = Model.create(log_ent_coef_def, inputs=[temp_key]) 192 | self.entropy_update = False 193 | 194 | if isinstance(self.alpha_coef, str) and self.alpha_coef.startswith("auto"): 195 | # Default initial value of alpha_coef when learned 196 | init_value = 1.0 197 | log_alpha_coef_def = LogAlphaCoef(init_value) 198 | self.key, temp_key = jax.random.split(self.key, 2) 199 | self.log_alpha_coef = Model.create(log_alpha_coef_def, inputs=[temp_key], 200 | tx=optax.adam(learning_rate=self.lr_schedule(1))) 201 | 202 | else: 203 | # Force conversion to float 204 | # this will throw an error if a malformed string (different from 'auto') 205 | # is passed 206 | log_alpha_coef_def = LogAlphaCoef(self.alpha_coef) 207 | self.key, temp_key = jax.random.split(self.key, 2) 208 | self.log_alpha_coef = Model.create(log_alpha_coef_def, inputs=[temp_key]) 209 | self.alpha_update = False 210 | 211 | def _create_aliases(self) -> None: 212 | self.actor = self.policy.actor 213 | self.critic =
self.policy.critic 214 | self.critic_target = self.policy.critic_target 215 | 216 | def train(self, gradient_steps: int, batch_size: int = 64) -> None: 217 | pass 218 | 219 | def offline_train(self, gradient_steps: int, batch_size: int) -> None: 220 | ent_coef_losses, ent_coefs = [], [] 221 | actor_losses, critic_losses = [], [] 222 | alpha_coef_losses, alpha_coefs, conservative_losses = [], [], [] 223 | 224 | for gradient_step in range(gradient_steps): 225 | # Sample replay buffer 226 | replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) 227 | self.key, key = jax.random.split(self.key, 2) 228 | 229 | target_update_cond = gradient_step % self.target_update_interval == 0 230 | self.key, new_actor, new_critic, new_critic_target, new_log_ent_coef, new_log_alpha_coef, info \ 231 | = update_cql(key, self.actor, self.critic, self.critic_target, self.log_ent_coef, self.log_alpha_coef, replay_data, 232 | self.gamma, self.target_entropy, self.tau, target_update_cond, self.entropy_update, self.alpha_update, 233 | self.conservative_weight, self.lagrange_thresh) 234 | 235 | ent_coef_losses.append(info['ent_coef_loss']) 236 | ent_coefs.append(info['ent_coef']) 237 | critic_losses.append(info['critic_loss']) 238 | actor_losses.append(info['actor_loss']) 239 | alpha_coefs.append(info['alpha_coef']) 240 | alpha_coef_losses.append(info['alpha_coef_loss']) 241 | conservative_losses.append(info['conservative_loss']) 242 | 243 | self.policy.actor = new_actor 244 | self.policy.critic = new_critic 245 | self.policy.critic_target = new_critic_target 246 | self.log_ent_coef = new_log_ent_coef 247 | self.log_alpha_coef = new_log_alpha_coef 248 | 249 | self._create_aliases() 250 | 251 | self._n_updates += gradient_steps 252 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 253 | self.logger.record("train/ent_coef", np.mean(ent_coefs)) 254 | self.logger.record("train/alpha_coef", np.mean(alpha_coefs)) 255 | self.logger.record("train/actor_loss", np.mean(actor_losses)) 256 | self.logger.record("train/critic_loss", np.mean(critic_losses)) 257 | self.logger.record("train/conservative_loss", np.mean(conservative_losses)) 258 | if len(ent_coef_losses) > 0: 259 | self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) 260 | self.logger.record("train/alpha_coef_loss", np.mean(alpha_coef_losses)) 261 | 262 | def learn( 263 | self, 264 | total_timesteps: int, 265 | callback: MaybeCallback = None, 266 | log_interval: int = 1000, 267 | eval_env: Optional[GymEnv] = None, 268 | eval_freq: int = -1, 269 | n_eval_episodes: int = 5, 270 | tb_log_name: str = "SAC", 271 | eval_log_path: Optional[str] = None, 272 | reset_num_timesteps: bool = True, 273 | ) -> OffPolicyAlgorithm: 274 | 275 | return super(CQL, self).learn( 276 | total_timesteps=total_timesteps, 277 | callback=callback, 278 | log_interval=log_interval, 279 | eval_env=eval_env, 280 | eval_freq=eval_freq, 281 | n_eval_episodes=n_eval_episodes, 282 | tb_log_name=tb_log_name, 283 | eval_log_path=eval_log_path, 284 | reset_num_timesteps=reset_num_timesteps, 285 | ) 286 | 287 | def _excluded_save_params(self) -> List[str]: 288 | return super(CQL, self)._excluded_save_params() + ["actor", "critic", "critic_target", "log_ent_coef", "log_alpha_coef"] 289 | 290 | def _get_jax_save_params(self) -> Dict[str, Params]: 291 | params_dict = {} 292 | params_dict['actor'] = self.actor.params 293 | params_dict['critic'] = self.critic.params 294 | params_dict['critic_target'] = self.critic_target.params 295 | 
params_dict['log_ent_coef'] = self.log_ent_coef.params 296 | params_dict['log_alpha_coef'] = self.log_alpha_coef.params 297 | return params_dict 298 | 299 | def _get_jax_load_params(self) -> List[str]: 300 | return ['actor', 'critic', 'critic_target', 'log_ent_coef', 'log_alpha_coef'] -------------------------------------------------------------------------------- /offline_baselines_jax/common/vec_env/base_vec_env.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union 5 | 6 | import cloudpickle 7 | import gym 8 | import numpy as np 9 | 10 | # Define type aliases here to avoid circular import 11 | # Used when we want to access one or more VecEnv 12 | VecEnvIndices = Union[None, int, Iterable[int]] 13 | # VecEnvObs is what is returned by the reset() method 14 | # it contains the observation for each env 15 | VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] 16 | # VecEnvStepReturn is what is returned by the step() method 17 | # it contains the observation, reward, done, info for each env 18 | VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] 19 | 20 | 21 | def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover 22 | """ 23 | Tile N images into one big PxQ image 24 | (P,Q) are chosen to be as close as possible, and if N 25 | is square, then P=Q. 26 | 27 | :param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc 28 | n = batch index, h = height, w = width, c = channel 29 | :return: img_HWc, ndim=3 30 | """ 31 | img_nhwc = np.asarray(img_nhwc) 32 | n_images, height, width, n_channels = img_nhwc.shape 33 | # new_height was named H before 34 | new_height = int(np.ceil(np.sqrt(n_images))) 35 | # new_width was named W before 36 | new_width = int(np.ceil(float(n_images) / new_height)) 37 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) 38 | # img_HWhwc 39 | out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels)) 40 | # img_HhWwc 41 | out_image = out_image.transpose(0, 2, 1, 3, 4) 42 | # img_Hh_Ww_c 43 | out_image = out_image.reshape((new_height * height, new_width * width, n_channels)) 44 | return out_image 45 | 46 | 47 | class VecEnv(ABC): 48 | """ 49 | An abstract asynchronous, vectorized environment. 50 | 51 | :param num_envs: the number of environments 52 | :param observation_space: the observation space 53 | :param action_space: the action space 54 | """ 55 | 56 | metadata = {"render.modes": ["human", "rgb_array"]} 57 | 58 | def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space): 59 | self.num_envs = num_envs 60 | self.observation_space = observation_space 61 | self.action_space = action_space 62 | 63 | @abstractmethod 64 | def reset(self) -> VecEnvObs: 65 | """ 66 | Reset all the environments and return an array of 67 | observations, or a tuple of observation arrays. 68 | 69 | If step_async is still doing work, that work will 70 | be cancelled and step_wait() should not be called 71 | until step_async() is invoked again. 72 | 73 | :return: observation 74 | """ 75 | raise NotImplementedError() 76 | 77 | @abstractmethod 78 | def step_async(self, actions: np.ndarray) -> None: 79 | """ 80 | Tell all the environments to start taking a step 81 | with the given actions. 
82 | Call step_wait() to get the results of the step. 83 | 84 | You should not call this if a step_async run is 85 | already pending. 86 | """ 87 | raise NotImplementedError() 88 | 89 | @abstractmethod 90 | def step_wait(self) -> VecEnvStepReturn: 91 | """ 92 | Wait for the step taken with step_async(). 93 | 94 | :return: observation, reward, done, information 95 | """ 96 | raise NotImplementedError() 97 | 98 | @abstractmethod 99 | def close(self) -> None: 100 | """ 101 | Clean up the environment's resources. 102 | """ 103 | raise NotImplementedError() 104 | 105 | @abstractmethod 106 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 107 | """ 108 | Return attribute from vectorized environment. 109 | 110 | :param attr_name: The name of the attribute whose value to return 111 | :param indices: Indices of envs to get attribute from 112 | :return: List of values of 'attr_name' in all environments 113 | """ 114 | raise NotImplementedError() 115 | 116 | @abstractmethod 117 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 118 | """ 119 | Set attribute inside vectorized environments. 120 | 121 | :param attr_name: The name of attribute to assign new value 122 | :param value: Value to assign to `attr_name` 123 | :param indices: Indices of envs to assign value 124 | :return: 125 | """ 126 | raise NotImplementedError() 127 | 128 | @abstractmethod 129 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 130 | """ 131 | Call instance methods of vectorized environments. 132 | 133 | :param method_name: The name of the environment method to invoke. 134 | :param indices: Indices of envs whose method to call 135 | :param method_args: Any positional arguments to provide in the call 136 | :param method_kwargs: Any keyword arguments to provide in the call 137 | :return: List of items returned by the environment's method call 138 | """ 139 | raise NotImplementedError() 140 | 141 | @abstractmethod 142 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 143 | """ 144 | Check if environments are wrapped with a given wrapper. 145 | 146 | :param method_name: The name of the environment method to invoke. 147 | :param indices: Indices of envs whose method to call 148 | :param method_args: Any positional arguments to provide in the call 149 | :param method_kwargs: Any keyword arguments to provide in the call 150 | :return: True if the env is wrapped, False otherwise, for each env queried. 
151 | """ 152 | raise NotImplementedError() 153 | 154 | def step(self, actions: np.ndarray) -> VecEnvStepReturn: 155 | """ 156 | Step the environments with the given action 157 | 158 | :param actions: the action 159 | :return: observation, reward, done, information 160 | """ 161 | self.step_async(actions) 162 | return self.step_wait() 163 | 164 | def get_images(self) -> Sequence[np.ndarray]: 165 | """ 166 | Return RGB images from each environment 167 | """ 168 | raise NotImplementedError 169 | 170 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 171 | """ 172 | Gym environment rendering 173 | 174 | :param mode: the rendering type 175 | """ 176 | try: 177 | imgs = self.get_images() 178 | except NotImplementedError: 179 | warnings.warn(f"Render not defined for {self}") 180 | return 181 | 182 | # Create a big image by tiling images from subprocesses 183 | bigimg = tile_images(imgs) 184 | if mode == "human": 185 | import cv2 # pytype:disable=import-error 186 | 187 | cv2.imshow("vecenv", bigimg[:, :, ::-1]) 188 | cv2.waitKey(1) 189 | elif mode == "rgb_array": 190 | return bigimg 191 | else: 192 | raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs") 193 | 194 | @abstractmethod 195 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 196 | """ 197 | Sets the random seeds for all environments, based on a given seed. 198 | Each individual environment will still get its own seed, by incrementing the given seed. 199 | 200 | :param seed: The random seed. May be None for completely random seeding. 201 | :return: Returns a list containing the seeds for each individual env. 202 | Note that all list elements may be None, if the env does not return anything when being seeded. 203 | """ 204 | pass 205 | 206 | @property 207 | def unwrapped(self) -> "VecEnv": 208 | if isinstance(self, VecEnvWrapper): 209 | return self.venv.unwrapped 210 | else: 211 | return self 212 | 213 | def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: 214 | """Check if an attribute reference is being hidden in a recursive call to __getattr__ 215 | 216 | :param name: name of attribute to check for 217 | :param already_found: whether this attribute has already been found in a wrapper 218 | :return: name of module whose attribute is being shadowed, if any. 219 | """ 220 | if hasattr(self, name) and already_found: 221 | return f"{type(self).__module__}.{type(self).__name__}" 222 | else: 223 | return None 224 | 225 | def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: 226 | """ 227 | Convert a flexibly-typed reference to environment indices to an implied list of indices. 228 | 229 | :param indices: refers to indices of envs. 230 | :return: the implied list of indices. 
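Example (illustrative): ``indices=None`` expands to ``range(self.num_envs)``, ``indices=2`` becomes ``[2]``, and an iterable such as ``[0, 3]`` is returned unchanged.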
231 | """ 232 | if indices is None: 233 | indices = range(self.num_envs) 234 | elif isinstance(indices, int): 235 | indices = [indices] 236 | return indices 237 | 238 | 239 | class VecEnvWrapper(VecEnv): 240 | """ 241 | Vectorized environment base class 242 | 243 | :param venv: the vectorized environment to wrap 244 | :param observation_space: the observation space (can be None to load from venv) 245 | :param action_space: the action space (can be None to load from venv) 246 | """ 247 | 248 | def __init__( 249 | self, 250 | venv: VecEnv, 251 | observation_space: Optional[gym.spaces.Space] = None, 252 | action_space: Optional[gym.spaces.Space] = None, 253 | ): 254 | self.venv = venv 255 | VecEnv.__init__( 256 | self, 257 | num_envs=venv.num_envs, 258 | observation_space=observation_space or venv.observation_space, 259 | action_space=action_space or venv.action_space, 260 | ) 261 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 262 | 263 | def step_async(self, actions: np.ndarray) -> None: 264 | self.venv.step_async(actions) 265 | 266 | @abstractmethod 267 | def reset(self) -> VecEnvObs: 268 | pass 269 | 270 | @abstractmethod 271 | def step_wait(self) -> VecEnvStepReturn: 272 | pass 273 | 274 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: 275 | return self.venv.seed(seed) 276 | 277 | def close(self) -> None: 278 | return self.venv.close() 279 | 280 | def render(self, mode: str = "human") -> Optional[np.ndarray]: 281 | return self.venv.render(mode=mode) 282 | 283 | def get_images(self) -> Sequence[np.ndarray]: 284 | return self.venv.get_images() 285 | 286 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 287 | return self.venv.get_attr(attr_name, indices) 288 | 289 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 290 | return self.venv.set_attr(attr_name, value, indices) 291 | 292 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 293 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) 294 | 295 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 296 | return self.venv.env_is_wrapped(wrapper_class, indices=indices) 297 | 298 | def __getattr__(self, name: str) -> Any: 299 | """Find attribute from wrapped venv(s) if this wrapper does not have it. 300 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers 301 | which have unique attributes of interest. 302 | """ 303 | blocked_class = self.getattr_depth_check(name, already_found=False) 304 | if blocked_class is not None: 305 | own_class = f"{type(self).__module__}.{type(self).__name__}" 306 | error_str = ( 307 | f"Error: Recursive attribute lookup for {name} from {own_class} is " 308 | f"ambiguous and hides attribute from {blocked_class}" 309 | ) 310 | raise AttributeError(error_str) 311 | 312 | return self.getattr_recursive(name) 313 | 314 | def _get_all_attributes(self) -> Dict[str, Any]: 315 | """Get all (inherited) instance and class attributes 316 | 317 | :return: all_attributes 318 | """ 319 | all_attributes = self.__dict__.copy() 320 | all_attributes.update(self.class_attributes) 321 | return all_attributes 322 | 323 | def getattr_recursive(self, name: str) -> Any: 324 | """Recursively check wrappers to find attribute. 
325 | 326 | :param name: name of attribute to look for 327 | :return: attribute 328 | """ 329 | all_attributes = self._get_all_attributes() 330 | if name in all_attributes: # attribute is present in this wrapper 331 | attr = getattr(self, name) 332 | elif hasattr(self.venv, "getattr_recursive"): 333 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr 334 | # to avoid a duplicate call to getattr_depth_check. 335 | attr = self.venv.getattr_recursive(name) 336 | else: # attribute not present, child is an unwrapped VecEnv 337 | attr = getattr(self.venv, name) 338 | 339 | return attr 340 | 341 | def getattr_depth_check(self, name: str, already_found: bool) -> str: 342 | """See base class. 343 | 344 | :return: name of module whose attribute is being shadowed, if any. 345 | """ 346 | all_attributes = self._get_all_attributes() 347 | if name in all_attributes and already_found: 348 | # this venv's attribute is being hidden because of a higher venv. 349 | shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}" 350 | elif name in all_attributes and not already_found: 351 | # we have found the first reference to the attribute. Now check for duplicates. 352 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) 353 | else: 354 | # this wrapper does not have the attribute. Keep searching. 355 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) 356 | 357 | return shadowed_wrapper_class 358 | 359 | 360 | class CloudpickleWrapper: 361 | """ 362 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 363 | 364 | :param var: the variable you wish to wrap for pickling with cloudpickle 365 | """ 366 | 367 | def __init__(self, var: Any): 368 | self.var = var 369 | 370 | def __getstate__(self) -> Any: 371 | return cloudpickle.dumps(self.var) 372 | 373 | def __setstate__(self, var: Any) -> None: 374 | self.var = cloudpickle.loads(var) 375 | -------------------------------------------------------------------------------- /offline_baselines_jax/sac/policies.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable 3 | 4 | import flax.linen as nn 5 | import gym 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | from tensorflow_probability.substrates import jax as tfp 11 | 12 | from offline_baselines_jax.common.jax_layers import ( 13 | BaseFeaturesExtractor, 14 | CombinedExtractor, 15 | FlattenExtractor, 16 | NatureCNN, 17 | create_mlp, 18 | get_actor_critic_arch, 19 | ) 20 | from offline_baselines_jax.common.policies import Model 21 | from offline_baselines_jax.common.preprocessing import get_action_dim, preprocess_obs 22 | from offline_baselines_jax.common.type_aliases import Schedule, Params 23 | from offline_baselines_jax.common.utils import get_basic_rngs 24 | 25 | tfd = tfp.distributions 26 | tfb = tfp.bijectors 27 | # CAP the standard deviation of the actor 28 | LOG_STD_MAX = 2 29 | LOG_STD_MIN = -20 30 | 31 | 32 | @functools.partial(jax.jit, static_argnames=("actor_apply_fn", "deterministic")) 33 | def sample_actions( 34 | rng: jnp.ndarray, 35 | actor_apply_fn: Callable[..., Any], 36 | actor_params: Params, 37 | observations: Union[np.ndarray, Dict], 38 | deterministic: bool 39 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 40 | rng, dropout_key = jax.random.split(rng) 41 | rngs = {"dropout": dropout_key} 42 | dist = 
actor_apply_fn({'params': actor_params}, observations, deterministic=deterministic, rngs=rngs) 43 | rng, key = jax.random.split(rng) 44 | return rng, dist.sample(seed=rng) 45 | 46 | 47 | class Actor(nn.Module): 48 | features_extractor: nn.Module 49 | observation_space: gym.spaces.Space 50 | action_space: gym.spaces.Space 51 | net_arch: List[int] 52 | activation_fn: Type[nn.Module] = nn.relu 53 | dropout: float = 0.0 54 | 55 | latent_pi = None 56 | mu = None 57 | log_std = None 58 | 59 | def setup(self): 60 | self.latent_pi = create_mlp(-1, self.net_arch, self.activation_fn, self.dropout) 61 | action_dim = get_action_dim(self.action_space) 62 | self.mu = create_mlp(action_dim, self.net_arch, self.activation_fn, self.dropout) 63 | self.log_std = create_mlp(action_dim, self.net_arch, self.activation_fn, self.dropout) 64 | 65 | def __call__(self, *args, **kwargs): 66 | return self.forward(*args, **kwargs) 67 | 68 | def forward(self, observations: jnp.ndarray, deterministic: bool = False, **kwargs) -> jnp.ndarray: 69 | mean_actions, log_stds = self.get_action_dist_params(observations, deterministic=deterministic, **kwargs) 70 | return self.actions_from_params(mean_actions, log_stds) 71 | 72 | def get_action_dist_params( 73 | self, 74 | observations: jnp.ndarray, 75 | deterministic: bool = False, 76 | **kwargs, 77 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 78 | 79 | observations = preprocess_obs(observations, self.observation_space) 80 | features = self.features_extractor(observations, **kwargs) 81 | 82 | latent_pi = self.latent_pi(features, deterministic=deterministic) 83 | mean_actions = self.mu(latent_pi, deterministic=deterministic) 84 | log_stds = self.log_std(latent_pi, deterministic=deterministic) 85 | log_stds = jnp.clip(log_stds, LOG_STD_MIN, LOG_STD_MAX) 86 | 87 | return mean_actions, log_stds 88 | 89 | def actions_from_params(self, mean: jnp.ndarray, log_std: jnp.ndarray): 90 | base_dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)) 91 | sampling_dist = tfd.TransformedDistribution(distribution=base_dist, bijector=tfb.Tanh()) 92 | return sampling_dist 93 | 94 | 95 | class SingleCritic(nn.Module): 96 | features_extractor: nn.Module 97 | observation_space: gym.spaces.Space 98 | net_arch: List[int] 99 | dropout: float 100 | activation_fn: Type[nn.Module] = nn.relu 101 | 102 | q_net = None 103 | 104 | def setup(self): 105 | self.q_net = create_mlp( 106 | output_dim=1, 107 | net_arch=self.net_arch, 108 | dropout=self.dropout, 109 | ) 110 | 111 | def __call__(self, *args, **kwargs): 112 | return self.forward(*args, **kwargs) 113 | 114 | def forward( 115 | self, 116 | observations: jnp.ndarray, 117 | actions: jnp.ndarray, 118 | deterministic: bool = False 119 | ): 120 | observations = preprocess_obs(observations, self.observation_space) 121 | features = self.features_extractor(observations) 122 | q_input = jnp.concatenate((features, actions), axis=1) 123 | return self.q_net(q_input, deterministic=deterministic) 124 | 125 | 126 | class Critic(nn.Module): 127 | features_extractor: nn.Module 128 | observation_space: gym.spaces.Space 129 | net_arch: List[int] 130 | dropout: float = 0.0 131 | activation_fn: Type[nn.Module] = nn.relu 132 | n_critics: int = 2 133 | 134 | q_networks = None 135 | 136 | def setup(self): 137 | batch_qs = nn.vmap( 138 | SingleCritic, 139 | in_axes=None, 140 | out_axes=1, 141 | variable_axes={"params": 1, "batch_stats": 1}, 142 | split_rngs={"params": True, "dropout": True}, 143 | axis_size=self.n_critics 144 | ) 145 | self.q_networks = batch_qs( 146 | 
self.features_extractor, 147 | self.observation_space, 148 | self.net_arch, 149 | self.dropout, 150 | self.activation_fn 151 | ) 152 | 153 | def __call__(self, *args, **kwargs): 154 | return self.forward(*args, **kwargs) 155 | 156 | def forward(self, observations: jnp.ndarray, actions: jnp.ndarray, deterministic: bool = False): 157 | return self.q_networks(observations, actions, deterministic) 158 | 159 | 160 | 161 | class SACPolicy(object): 162 | def __init__( 163 | self, 164 | rng, 165 | observation_space: gym.spaces.Space, 166 | action_space: gym.spaces.Space, 167 | lr_schedule: Schedule, 168 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, 169 | activation_fn: Type[nn.Module] = nn.relu, 170 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, 171 | features_extractor_kwargs: Optional[Dict[str, Any]] = None, 172 | n_critics: int = 2, 173 | dropout: float = 0.0, 174 | ): 175 | self.observation_space = observation_space 176 | self.action_space = action_space 177 | 178 | self.rng, actor_key, critic_key, features_key, dropout_key = jax.random.split(rng, 5) 179 | 180 | if net_arch is None: 181 | if features_extractor_class == NatureCNN: 182 | net_arch = [] 183 | else: 184 | net_arch = [256, 256] 185 | 186 | if features_extractor_kwargs is None: 187 | features_extractor_kwargs = {} 188 | 189 | features_extractor_def = features_extractor_class( 190 | _observation_space=observation_space, 191 | **features_extractor_kwargs 192 | ) 193 | actor_arch, critic_arch = get_actor_critic_arch(net_arch) 194 | actor_def = Actor( 195 | features_extractor=features_extractor_def, 196 | observation_space=observation_space, 197 | action_space=action_space, 198 | net_arch=actor_arch, 199 | activation_fn=activation_fn, 200 | dropout=dropout 201 | ) 202 | 203 | critic_def = Critic( 204 | features_extractor=features_extractor_def, 205 | observation_space=observation_space, 206 | net_arch=critic_arch, 207 | activation_fn=activation_fn, 208 | n_critics=n_critics, 209 | dropout=dropout 210 | ) 211 | 212 | # Init dummy inputs 213 | if isinstance(observation_space, gym.spaces.Dict): 214 | observation = observation_space.sample() 215 | for key, _ in observation_space.spaces.items(): 216 | observation[key] = observation[key][np.newaxis, ...] 217 | else: 218 | observation = observation_space.sample()[np.newaxis, ...] 219 | action = action_space.sample()[np.newaxis, ...] 
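        # Flax modules are shape-agnostic until initialized: the dummy (batched)
        # observation and action built above exist only so that Model.create() can
        # run a single forward pass and give every parameter a concrete shape.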

        self.rng, actor_rngs = get_basic_rngs(self.rng)
        actor = Model.create(actor_def, inputs=[actor_rngs, observation], tx=optax.adam(learning_rate=lr_schedule))

        self.rng, critic_rngs = get_basic_rngs(self.rng)
        critic = Model.create(
            critic_def,
            inputs=[critic_rngs, observation, action],
            tx=optax.adam(learning_rate=lr_schedule)
        )
        critic_target = Model.create(critic_def, inputs=[critic_rngs, observation, action])
        self.actor = actor
        self.critic, self.critic_target = critic, critic_target

    def _predict(self, observation: jnp.ndarray, deterministic: bool) -> np.ndarray:
        rng, actions = sample_actions(self.rng, self.actor.apply_fn, self.actor.params, observation, deterministic)

        self.rng = rng
        return np.asarray(actions)

    def predict(self, observation: jnp.ndarray, deterministic: bool = True, **kwargs) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        actions = self._predict(observation, deterministic)
        if isinstance(self.action_space, gym.spaces.Box):
            # Actions could be on arbitrary scale, so clip the actions to avoid
            # out of bound error (e.g. if sampling from a Gaussian distribution)
            actions = np.clip(actions, self.action_space.low, self.action_space.high)
        return actions, None

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: Action to scale
        :return: Scaled action
        """
        low, high = self.action_space.low, self.action_space.high
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale
        :return: Un-scaled action
        """
        low, high = self.action_space.low, self.action_space.high
        return low + (0.5 * (scaled_action + 1.0) * (high - low))


MlpPolicy = SACPolicy

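# Illustrative usage sketch (not part of the library API); the environment id and the
# constant learning-rate schedule below are placeholders:
#
#     import gym
#     import jax
#     import numpy as np
#
#     env = gym.make("Pendulum-v1")
#     policy = MlpPolicy(
#         rng=jax.random.PRNGKey(0),
#         observation_space=env.observation_space,
#         action_space=env.action_space,
#         lr_schedule=lambda _: 3e-4,  # optax also accepts a callable schedule here
#     )
#     actions, _ = policy.predict(env.reset()[np.newaxis, ...], deterministic=True)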

class CnnPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for SAC,
    defaulting to the ``NatureCNN`` features extractor for image observations.

    :param rng: PRNG key used to initialize the actor and critic networks
    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param n_critics: Number of critic networks to create.
    :param dropout: Dropout rate used inside the actor and critic MLPs.
    """

    def __init__(
        self,
        rng,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.relu,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        dropout: float = 0.0,
    ):
        super(CnnPolicy, self).__init__(
            rng,
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            n_critics,
            dropout
        )

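# The three exported aliases differ only in their default features extractor:
# MlpPolicy flattens the observation, CnnPolicy applies NatureCNN to image inputs,
# and MultiInputPolicy uses CombinedExtractor to handle Dict observation spaces.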

class MultiInputPolicy(SACPolicy):
    """
    Policy class (with both actor and critic) for SAC,
    defaulting to the ``CombinedExtractor`` features extractor for Dict observation spaces.

    :param rng: PRNG key used to initialize the actor and critic networks
    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param n_critics: Number of critic networks to create.
    :param dropout: Dropout rate used inside the actor and critic MLPs.
    """

    def __init__(
        self,
        rng,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.relu,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        dropout: float = 0.0,
    ):
        super(MultiInputPolicy, self).__init__(
            rng,
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            n_critics,
            dropout
        )
--------------------------------------------------------------------------------
/offline_baselines_jax/common/utils.py:
--------------------------------------------------------------------------------
import glob
import os
import platform
import random
from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, Optional, Tuple, Union

import jax.numpy as jnp

import gym
import jax
import numpy as np

import offline_baselines_jax

# Check if tensorboard is available (imported through the optional torch dependency)
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

from stable_baselines3.common.logger import Logger, configure
from offline_baselines_jax.common.type_aliases import GymEnv, Schedule, TrainFreq, TrainFrequencyUnit


def set_random_seed(seed: int) -> None:
    """
    Seed the different random generators.

    :param seed: the seed to use for Python's and NumPy's global RNGs
    """
    # Seed python RNG
    random.seed(seed)
    # Seed numpy RNG
    np.random.seed(seed)


def get_basic_rngs(rng: jnp.ndarray) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    # Split one carry key into the named RNG streams Flax modules expect at init time.
    rng, param_key, dropout_key, batch_key = jax.random.split(rng, 4)
    return rng, {"params": param_key, "dropout": dropout_key, "batch_stats": batch_key}


# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Computes fraction of variance that ypred explains about y.
    Returns 1 - Var[y-ypred] / Var[y]

    interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    :param y_pred: the prediction
    :param y_true: the expected value
    :return: explained variance of ypred and y
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y


def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
    """
    Transform (if needed) a constant learning rate or clip range
    into a callable schedule.
69 | 70 | :param value_schedule: 71 | :return: 72 | """ 73 | # If the passed schedule is a float 74 | # create a constant function 75 | if isinstance(value_schedule, (float, int)): 76 | # Cast to float to avoid errors 77 | value_schedule = constant_fn(float(value_schedule)) 78 | else: 79 | assert callable(value_schedule) 80 | return value_schedule 81 | 82 | 83 | def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: 84 | """ 85 | Create a function that interpolates linearly between start and end 86 | between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``. 87 | This is used in DQN for linearly annealing the exploration fraction 88 | (epsilon for the epsilon-greedy strategy). 89 | 90 | :params start: value to start with if ``progress_remaining`` = 1 91 | :params end: value to end with if ``progress_remaining`` = 0 92 | :params end_fraction: fraction of ``progress_remaining`` 93 | where end is reached e.g 0.1 then end is reached after 10% 94 | of the complete training process. 95 | :return: 96 | """ 97 | 98 | def func(progress_remaining: float) -> float: 99 | if (1 - progress_remaining) > end_fraction: 100 | return end 101 | else: 102 | return start + (1 - progress_remaining) * (end - start) / end_fraction 103 | 104 | return func 105 | 106 | 107 | def constant_fn(val: float) -> Schedule: 108 | """ 109 | Create a function that returns a constant 110 | It is useful for learning rate schedule (to avoid code duplication) 111 | 112 | :param val: 113 | :return: 114 | """ 115 | 116 | def func(_): 117 | return val 118 | 119 | return func 120 | 121 | 122 | def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int: 123 | """ 124 | Returns the latest run number for the given log name and log path, 125 | by finding the greatest number in the directories. 126 | 127 | :return: latest run number 128 | """ 129 | max_run_id = 0 130 | for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"): 131 | file_name = path.split(os.sep)[-1] 132 | ext = file_name.split("_")[-1] 133 | if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: 134 | max_run_id = int(ext) 135 | return max_run_id 136 | 137 | 138 | def configure_logger( 139 | verbose: int = 1, 140 | tensorboard_log: Optional[str] = None, 141 | tb_log_name: str = "", 142 | reset_num_timesteps: bool = True, 143 | ) -> Logger: 144 | """ 145 | Configure the logger's outputs. 146 | 147 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug 148 | :param tensorboard_log: the log location for tensorboard (if None, no logging) 149 | :param tb_log_name: tensorboard log 150 | :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not. 151 | It allows to continue a previous learning curve (``reset_num_timesteps=False``) 152 | or start from t=0 (``reset_num_timesteps=True``, the default). 
153 | :return: The logger object 154 | """ 155 | save_path, format_strings = None, ["stdout"] 156 | 157 | if tensorboard_log is not None and SummaryWriter is None: 158 | raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.") 159 | 160 | if tensorboard_log is not None and SummaryWriter is not None: 161 | latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name) 162 | if not reset_num_timesteps: 163 | # Continue training in the same directory 164 | latest_run_id -= 1 165 | save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}") 166 | if verbose >= 1: 167 | format_strings = ["stdout", "tensorboard"] 168 | else: 169 | format_strings = ["tensorboard"] 170 | elif verbose == 0: 171 | format_strings = [""] 172 | 173 | return configure(save_path, format_strings=format_strings) 174 | 175 | 176 | def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None: 177 | """ 178 | Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if 179 | spaces match after loading the model with given env. 180 | Checked parameters: 181 | - observation_space 182 | - action_space 183 | 184 | :param env: Environment to check for valid spaces 185 | :param observation_space: Observation space to check against 186 | :param action_space: Action space to check against 187 | """ 188 | if observation_space != env.observation_space: 189 | raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}") 190 | if action_space != env.action_space: 191 | raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}") 192 | 193 | 194 | def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool: 195 | """ 196 | For box observation type, detects and validates the shape, 197 | then returns whether or not the observation is vectorized. 198 | 199 | :param observation: the input observation to validate 200 | :param observation_space: the observation space 201 | :return: whether the given observation is vectorized or not 202 | """ 203 | if observation.shape == observation_space.shape: 204 | return False 205 | elif observation.shape[1:] == observation_space.shape: 206 | return True 207 | else: 208 | raise ValueError( 209 | f"Error: Unexpected observation shape {observation.shape} for " 210 | + f"Box environment, please use {observation_space.shape} " 211 | + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape))) 212 | ) 213 | 214 | 215 | def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool: 216 | """ 217 | For discrete observation type, detects and validates the shape, 218 | then returns whether or not the observation is vectorized. 219 | 220 | :param observation: the input observation to validate 221 | :param observation_space: the observation space 222 | :return: whether the given observation is vectorized or not 223 | """ 224 | if isinstance(observation, int) or observation.shape == (): # A numpy array of a number, has shape empty tuple '()' 225 | return False 226 | elif len(observation.shape) == 1: 227 | return True 228 | else: 229 | raise ValueError( 230 | f"Error: Unexpected observation shape {observation.shape} for " 231 | + "Discrete environment, please use () or (n_env,) for the observation shape." 
232 | ) 233 | 234 | 235 | def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool: 236 | """ 237 | For multidiscrete observation type, detects and validates the shape, 238 | then returns whether or not the observation is vectorized. 239 | 240 | :param observation: the input observation to validate 241 | :param observation_space: the observation space 242 | :return: whether the given observation is vectorized or not 243 | """ 244 | if observation.shape == (len(observation_space.nvec),): 245 | return False 246 | elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): 247 | return True 248 | else: 249 | raise ValueError( 250 | f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " 251 | + f"environment, please use ({len(observation_space.nvec)},) or " 252 | + f"(n_env, {len(observation_space.nvec)}) for the observation shape." 253 | ) 254 | 255 | 256 | def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool: 257 | """ 258 | For multibinary observation type, detects and validates the shape, 259 | then returns whether or not the observation is vectorized. 260 | 261 | :param observation: the input observation to validate 262 | :param observation_space: the observation space 263 | :return: whether the given observation is vectorized or not 264 | """ 265 | if observation.shape == (observation_space.n,): 266 | return False 267 | elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: 268 | return True 269 | else: 270 | raise ValueError( 271 | f"Error: Unexpected observation shape {observation.shape} for MultiBinary " 272 | + f"environment, please use ({observation_space.n},) or " 273 | + f"(n_env, {observation_space.n}) for the observation shape." 274 | ) 275 | 276 | 277 | def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool: 278 | """ 279 | For dict observation type, detects and validates the shape, 280 | then returns whether or not the observation is vectorized. 281 | 282 | :param observation: the input observation to validate 283 | :param observation_space: the observation space 284 | :return: whether the given observation is vectorized or not 285 | """ 286 | # We first assume that all observations are not vectorized 287 | all_non_vectorized = True 288 | for key, subspace in observation_space.spaces.items(): 289 | # This fails when the observation is not vectorized 290 | # or when it has the wrong shape 291 | if observation[key].shape != subspace.shape: 292 | all_non_vectorized = False 293 | break 294 | 295 | if all_non_vectorized: 296 | return False 297 | 298 | all_vectorized = True 299 | # Now we check that all observation are vectorized and have the correct shape 300 | for key, subspace in observation_space.spaces.items(): 301 | if observation[key].shape[1:] != subspace.shape: 302 | all_vectorized = False 303 | break 304 | 305 | if all_vectorized: 306 | return True 307 | else: 308 | # Retrieve error message 309 | error_msg = "" 310 | try: 311 | is_vectorized_observation(observation[key], observation_space.spaces[key]) 312 | except ValueError as e: 313 | error_msg = f"{e}" 314 | raise ValueError( 315 | f"There seems to be a mix of vectorized and non-vectorized observations. " 316 | f"Unexpected observation shape {observation[key].shape} for key {key} " 317 | f"of type {observation_space.spaces[key]}. 
{error_msg}" 318 | ) 319 | 320 | 321 | def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool: 322 | """ 323 | For every observation type, detects and validates the shape, 324 | then returns whether or not the observation is vectorized. 325 | 326 | :param observation: the input observation to validate 327 | :param observation_space: the observation space 328 | :return: whether the given observation is vectorized or not 329 | """ 330 | 331 | is_vec_obs_func_dict = { 332 | gym.spaces.Box: is_vectorized_box_observation, 333 | gym.spaces.Discrete: is_vectorized_discrete_observation, 334 | gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation, 335 | gym.spaces.MultiBinary: is_vectorized_multibinary_observation, 336 | gym.spaces.Dict: is_vectorized_dict_observation, 337 | } 338 | 339 | for space_type, is_vec_obs_func in is_vec_obs_func_dict.items(): 340 | if isinstance(observation_space, space_type): 341 | return is_vec_obs_func(observation, observation_space) 342 | else: 343 | # for-else happens if no break is called 344 | raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.") 345 | 346 | 347 | def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: 348 | """ 349 | Compute the mean of an array if there is at least one element. 350 | For empty array, return NaN. It is used for logging only. 351 | 352 | :param arr: 353 | :return: 354 | """ 355 | return np.nan if len(arr) == 0 else np.mean(arr) 356 | 357 | 358 | def zip_strict(*iterables: Iterable) -> Iterable: 359 | r""" 360 | ``zip()`` function but enforces that iterables are of equal length. 361 | Raises ``ValueError`` if iterables not of equal length. 362 | Code inspired by Stackoverflow answer for question #32954486. 363 | 364 | :param \*iterables: iterables to ``zip()`` 365 | """ 366 | # As in Stackoverflow #32954486, use 367 | # new object for "empty" in case we have 368 | # Nones in iterable. 369 | sentinel = object() 370 | for combo in zip_longest(*iterables, fillvalue=sentinel): 371 | if sentinel in combo: 372 | raise ValueError("Iterables have different lengths") 373 | yield combo 374 | 375 | 376 | def should_collect_more_steps( 377 | train_freq: TrainFreq, 378 | num_collected_steps: int, 379 | num_collected_episodes: int, 380 | ) -> bool: 381 | """ 382 | Helper used in ``collect_rollouts()`` of off-policy algorithms 383 | to determine the termination condition. 384 | 385 | :param train_freq: How much experience should be collected before updating the policy. 386 | :param num_collected_steps: The number of already collected steps. 387 | :param num_collected_episodes: The number of already collected episodes. 388 | :return: Whether to continue or not collecting experience 389 | by doing rollouts of the current policy. 390 | """ 391 | if train_freq.unit == TrainFrequencyUnit.STEP: 392 | return num_collected_steps < train_freq.frequency 393 | 394 | elif train_freq.unit == TrainFrequencyUnit.EPISODE: 395 | return num_collected_episodes < train_freq.frequency 396 | 397 | else: 398 | raise ValueError( 399 | "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP " 400 | f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!" 401 | ) 402 | 403 | 404 | def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: 405 | """ 406 | Retrieve system and python env info for the current system. 

    :param print_info: Whether to print the gathered information or not
    :return: Dictionary summing up the version for each relevant package
        and a formatted string.
    """
    env_info = {
        "OS": f"{platform.platform()} {platform.version()}",
        "Python": platform.python_version(),
        "Offline_Baselines_JAX": offline_baselines_jax.__version__,
        "JAX": jax.__version__,
        "Numpy": np.__version__,
        "Gym": gym.__version__,
    }
    env_info_str = ""
    for key, value in env_info.items():
        env_info_str += f"{key}: {value}\n"
    if print_info:
        print(env_info_str)
    return env_info, env_info_str
--------------------------------------------------------------------------------