├── networks ├── __init__.py ├── README.md ├── policy.py ├── networks.py └── distributions.py ├── .gitignore ├── images ├── perf_plot_Ant.png ├── perf_plot_Hopper.png ├── perf_plot_Humanoid.png └── perf_plot_HalfCheetah.png ├── envs ├── README.md ├── transition.py ├── __init__.py ├── evaluate.py ├── make_env.py ├── state.py ├── utils.py ├── atari_wrappers.py ├── normalize.py ├── batched_env.py └── base.py ├── LICENSE ├── requirements.txt ├── README.md ├── reinforce.py ├── a2c.py └── ppo.py /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | __pycache__/ 3 | Exp*/ 4 | *.pkl 5 | experiments/ 6 | weights/ 7 | videos/ -------------------------------------------------------------------------------- /images/perf_plot_Ant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matt00n/PolicyGradientsJax/HEAD/images/perf_plot_Ant.png -------------------------------------------------------------------------------- /images/perf_plot_Hopper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matt00n/PolicyGradientsJax/HEAD/images/perf_plot_Hopper.png -------------------------------------------------------------------------------- /images/perf_plot_Humanoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matt00n/PolicyGradientsJax/HEAD/images/perf_plot_Humanoid.png -------------------------------------------------------------------------------- /images/perf_plot_HalfCheetah.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Matt00n/PolicyGradientsJax/HEAD/images/perf_plot_HalfCheetah.png -------------------------------------------------------------------------------- /networks/README.md: -------------------------------------------------------------------------------- 1 | Logic to create policy and value networks as well as action distributions. Mostly based on [Brax](https://github.com/google/brax). -------------------------------------------------------------------------------- /envs/README.md: -------------------------------------------------------------------------------- 1 | Utils to create batched environments and apply wrappers. Mostly based on [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3).
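A minimal usage sketch of these environment utilities follows. It is illustrative only: the environment id, `num_envs`, and the SB3-style `reset`/`step`/`action_space` interface of the batched env are assumptions based on the Stable Baselines3 heritage, not code shown in this repository.

```python
import numpy as np

from envs import make_env

# Build 8 MuJoCo environments in parallel worker processes, with observation
# and reward normalization plus episode-score tracking via the Evaluator wrapper.
envs = make_env(
    "HalfCheetah-v4",    # any Gymnasium id; pass is_atari=True for ALE ids
    num_envs=8,
    parallel=True,       # ParallelBatchedEnv; False -> SequencedBatchedEnv
    norm_obs=True,       # wraps the batch in VecNormalize
    norm_reward=True,
    evaluate=True,       # adds RecordScores + Evaluator for episode metrics
)

state = envs.reset()     # envs.State pytree: obs, reward, done, metrics, info
# Assumed SB3-style vectorized step with one action per sub-environment.
actions = np.stack([envs.action_space.sample() for _ in range(8)])
state = envs.step(actions)
returns, lengths = envs.evaluate()  # completed-episode returns/lengths so far
```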
-------------------------------------------------------------------------------- /envs/transition.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax.numpy as jnp 4 | 5 | NestedArray = jnp.ndarray 6 | 7 | class Transition(NamedTuple): 8 | """Container for a transition.""" 9 | observation: NestedArray 10 | action: NestedArray 11 | reward: NestedArray 12 | discount: NestedArray 13 | next_observation: NestedArray 14 | extras: NestedArray = () # pytype: disable=annotation-type-mismatch # jax-ndarray -------------------------------------------------------------------------------- /networks/policy.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Protocol, Tuple, TypeVar 2 | 3 | import jax.numpy as jnp 4 | 5 | 6 | 7 | NetworkType = TypeVar('NetworkType') 8 | 9 | 10 | class Policy(Protocol): 11 | 12 | def __call__( 13 | self, 14 | observation: jnp.ndarray, 15 | key: jnp.ndarray, 16 | ) -> Tuple[jnp.ndarray, Mapping[str, Any]]: 17 | pass 18 | 19 | 20 | class NetworkFactory(Protocol[NetworkType]): 21 | 22 | def __call__( 23 | self, 24 | observation_size: int, 25 | action_size: int, 26 | preprocess_observations_fn, 27 | ) -> NetworkType: 28 | pass -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Matt00n 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium import spaces 2 | 3 | from envs.batched_env import SequencedBatchedEnv, ParallelBatchedEnv 4 | from envs.normalize import VecNormalize 5 | from envs.make_env import make_env 6 | from envs.state import State 7 | from envs.transition import Transition 8 | from envs.evaluate import RecordScores, Evaluator 9 | 10 | 11 | ATARI_ENVS = [ 12 | "adventure", 13 | "airraid", 14 | "alien", 15 | "amidar", 16 | "assault", 17 | "asterix", 18 | "asteroids", 19 | "atlantis", 20 | "bankheist", 21 | "battlezone", 22 | "beamrider", 23 | "berzerk", 24 | "bowling", 25 | "boxing", 26 | "breakout", 27 | "carnival", 28 | "centipede", 29 | "choppercommand", 30 | "crazyclimber", 31 | "defender", 32 | "demonattack", 33 | "doubledunk", 34 | "elevator_action", 35 | "enduro", 36 | "fishingderby", 37 | "freeway", 38 | "frostbite", 39 | "gopher", 40 | "gravitar", 41 | "hero", 42 | "icehockey", 43 | "jamesbond", 44 | "journeyescape", 45 | "kangaroo", 46 | "krull", 47 | "kungfumaster", 48 | "montezumarevenge", 49 | "mspacman", 50 | "namethisgame", 51 | "phoenix", 52 | "pitfall", 53 | "pong", 54 | "pooyan", 55 | "privateeye", 56 | "qbert", 57 | "riverraid", 58 | "roadrunner", 59 | "robotank", 60 | "seaquest", 61 | "skiing", 62 | "solaris", 63 | "spaceinvaders", 64 | "stargunner", 65 | "tennis", 66 | "timepilot", 67 | "tutankham", 68 | "upndown", 69 | "venture", 70 | "videopinball", 71 | "wizardofwor", 72 | "yarsrevenge", 73 | "zaxxon", 74 | ] 75 | 76 | def has_discrete_action_space(env): 77 | return isinstance(env.action_space, spaces.Discrete) 78 | 79 | 80 | def is_atari_env(env_id): 81 | env_id = env_id.lower() 82 | for env in ATARI_ENVS: 83 | if env in env_id: 84 | return True 85 | return False -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | ale-py==0.8.1 3 | AutoROM==0.4.2 4 | AutoROM.accept-rom-license==0.6.1 5 | cachetools==5.3.1 6 | certifi==2023.7.22 7 | charset-normalizer==3.3.0 8 | chex==0.1.83 9 | click==8.1.7 10 | cloudpickle==2.2.1 11 | contourpy==1.1.1 12 | cycler==0.12.0 13 | dm-control==1.0.14 14 | dm-env==1.6 15 | dm-haiku==0.0.10 16 | dm-tree==0.1.8 17 | etils==1.5.0 18 | Farama-Notifications==0.0.4 19 | filelock==3.12.4 20 | flatbuffers==23.5.26 21 | flax==0.7.4 22 | fonttools==4.43.0 23 | fsspec==2023.9.2 24 | gast==0.5.4 25 | glfw==2.6.2 26 | google-auth==2.23.2 27 | google-auth-oauthlib==1.0.0 28 | google-pasta==0.2.0 29 | grpcio==1.59.0 30 | gymnasium==0.29.1 31 | h5py==3.9.0 32 | idna==3.4 33 | imageio==2.31.5 34 | importlib-resources==6.1.0 35 | jax==0.4.17 36 | jaxlib==0.4.14 37 | Jinja2==3.1.2 38 | jmp==0.0.4 39 | keras==2.14.0 40 | kiwisolver==1.4.5 41 | labmaze==1.0.6 42 | libclang==16.0.6 43 | lxml==4.9.3 44 | Markdown==3.4.4 45 | markdown-it-py==3.0.0 46 | MarkupSafe==2.1.3 47 | matplotlib==3.8.0 48 | mdurl==0.1.2 49 | ml-dtypes==0.2.0 50 | mpmath==1.3.0 51 | msgpack==1.0.7 52 | mujoco==2.3.7 53 | networkx==3.1 54 | numpy==1.26.0 55 | oauthlib==3.2.2 56 | opencv-python==4.8.1.78 57 | opt-einsum==3.3.0 58 | optax==0.1.7 59 | orbax-checkpoint==0.4.1 60 | Pillow==10.0.1 61 | protobuf==4.24.3 62 | pyasn1==0.5.0 63 | pyasn1-modules==0.3.0 64 | pygame==2.5.2 65 | PyOpenGL==3.1.7 66 | pyparsing==3.1.1 67 | PyYAML==6.0.1 68 | requests==2.31.0 69 | 
requests-oauthlib==1.3.1 70 | rich==13.6.0 71 | rsa==4.9 72 | scipy==1.11.3 73 | Shimmy==0.2.1 74 | sympy==1.12 75 | tabulate==0.9.0 76 | tensorboard==2.14.1 77 | tensorboard-data-server==0.7.1 78 | tensorflow==2.14.0 79 | tensorflow-estimator==2.14.0 80 | tensorflow-io-gcs-filesystem==0.34.0 81 | tensorflow-macos==2.14.0 82 | tensorflow-probability==0.22.0 83 | tensorstore==0.1.45 84 | termcolor==2.3.0 85 | toolz==0.12.0 86 | torch==2.1.0 87 | tqdm==4.66.1 88 | typing_extensions==4.5.0 89 | urllib3==2.0.6 90 | Werkzeug==3.0.0 91 | wrapt==1.14.1 92 | -------------------------------------------------------------------------------- /envs/evaluate.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from typing import List, Tuple 3 | 4 | import gymnasium as gym 5 | 6 | from envs.base import VecEnv, VecEnvWrapper 7 | from envs.state import State 8 | 9 | 10 | 11 | class RecordScores(gym.Wrapper): 12 | """ 13 | Records episode scores. 14 | """ 15 | def __init__(self, env: gym.Env): 16 | gym.Wrapper.__init__(self, env) 17 | self.episode_return = 0 18 | self.episode_length = 0 19 | 20 | def step(self, action): 21 | obs, reward, term, trunc, info = self.env.step(action) 22 | self.episode_return += reward 23 | self.episode_length += 1 24 | if term or trunc: 25 | info['episode_return'] = self.episode_return 26 | info['episode_length'] = self.episode_length 27 | return obs, reward, term, trunc, info 28 | 29 | def reset(self, seed=None): 30 | self.episode_return = 0 31 | self.episode_length = 0 32 | return self.env.reset(seed=seed) 33 | 34 | 35 | class Evaluator(VecEnvWrapper): 36 | """ 37 | Accumulates episode metrics generated by RecordScores and tracks 38 | the number of completed episodes. 39 | 40 | :param venv: the vectorized environment to wrap. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | venv: VecEnv, 46 | ): 47 | VecEnvWrapper.__init__(self, venv) 48 | 49 | self.returns = [] 50 | self.episode_lengths = [] 51 | 52 | 53 | def step_wait(self) -> Tuple[State, List]: 54 | """ 55 | Apply sequence of actions to sequence of environments 56 | actions -> (observations, rewards, dones) 57 | 58 | where ``dones`` is a boolean vector indicating whether each element is new. 59 | """ 60 | env_state = self.venv.step_wait() 61 | for info in env_state.info: 62 | if 'episode_return' in info: 63 | self.returns.append(info['episode_return']) 64 | self.episode_lengths.append(info['episode_length']) 65 | return env_state 66 | 67 | def evaluate(self) -> Tuple[List, List]: 68 | """Report recorded metrics and reset them internally.""" 69 | returns = copy(self.returns) 70 | episode_lengths = copy(self.episode_lengths) 71 | self.returns = [] 72 | self.episode_lengths = [] 73 | return returns, episode_lengths 74 | 75 | def reset(self) -> State: 76 | """ 77 | Reset all environments 78 | :return: first observation of the episode 79 | """ 80 | return self.venv.reset() 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On-Policy Policy Gradient Algorithms in JAX 2 | 3 | This Deep Reinforcement Learning repository contains the most prominent **On-Policy Policy Gradient** Algorithms. 4 | All algorithms are implemented in **JAX**. Our implementations are based on [Brax's](https://github.com/google/brax) implementation of PPO. 
We use [Brax's](https://github.com/google/brax) logic for policy networks and distributions and [Stable Baselines3's](https://github.com/DLR-RM/stable-baselines3) environment infrastructure to create batched environments. Inspired by [CleanRL](https://github.com/vwxyzjn/cleanrl), we 5 | provide all algorithm logic including hyperparameters in a single file. However, for efficiency, we use shared files for creating networks and distributions. 6 | 7 | 8 | ## Algorithms 9 | 10 | We implemented the following algorithms in JAX: 11 | * [REINFORCE](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) 12 | * [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783) 13 | * [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) 14 | * [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) 15 | * [V-MPO*](https://arxiv.org/abs/1909.12238) 16 | 17 | You can read more about these algorithms in our upcoming comprehensive overview of Policy Gradient Algorithms. 18 | 19 | *on-policy variant of [Maximum a Posteriori Policy Optimization (MPO)](https://arxiv.org/abs/1806.06920) 20 | 21 | 22 | ## Benchmark Results 23 | 24 | We report the performance of our implementations on common MuJoCo environments (v4), interfaced through [Gymnasium](https://gymnasium.farama.org). 25 | 26 | |![](/images/perf_plot_HalfCheetah.png) | ![](/images/perf_plot_Ant.png)| 27 | :-------------------------:|:-------------------------: 28 | |![](/images/perf_plot_Humanoid.png) | ![](/images/perf_plot_Hopper.png)| 29 | 30 | 31 | 32 | ## Get started 33 | 34 | Prerequisites: 35 | * Tested with Python 3.11.6 36 | * See requirements.txt for further dependencies (note that this file is bloated; not all libraries are actually needed).
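Each algorithm file implements its objective directly in JAX. As a flavor of what lives in `ppo.py`, here is a condensed, illustrative sketch of the PPO clipped surrogate loss; the function name and signature below are hypothetical stand-ins, not the repository's actual code.

```python
import jax.numpy as jnp

def ppo_clip_loss(log_prob, old_log_prob, advantages, clip_eps=0.2):
    """Simplified PPO clipped surrogate objective, returned as a loss to minimize."""
    ratio = jnp.exp(log_prob - old_log_prob)             # pi_theta / pi_theta_old
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))    # pessimistic (clipped) bound
```

The real `ppo.py` contains the full training loop and hyperparameters, as described above.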
37 | 38 | To run the algorithms locally, simply run the respective python file: 39 | 40 | ```bash 41 | python ppo.py 42 | ``` 43 | 44 | 45 | ## Citing PolicyGradientsJax 46 | 47 | If you use this repository in your work or find it useful, please cite our [paper](https://arxiv.org/abs/2401.13662): 48 | 49 | ```bibtex 50 | @article{lehmann2024definitive, 51 | title={The Definitive Guide to Policy Gradients in Deep Reinforcement Learning: Theory, Algorithms and Implementations}, 52 | author={Matthias Lehmann}, 53 | year={2024}, 54 | eprint={2401.13662}, 55 | archivePrefix={arXiv}, 56 | primaryClass={cs.LG} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /envs/make_env.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Union 2 | 3 | import gymnasium as gym 4 | 5 | from envs.atari_wrappers import ( 6 | ClipRewardEnv, 7 | EpisodicLifeEnv, 8 | FireResetEnv, 9 | MaxAndSkipEnv, 10 | NoopResetEnv, 11 | ) 12 | from envs.batched_env import SequencedBatchedEnv, ParallelBatchedEnv 13 | from envs.evaluate import RecordScores, Evaluator 14 | from envs.normalize import VecNormalize 15 | 16 | 17 | 18 | def make_env( 19 | env_id: Union[str, Any], 20 | num_envs: int, 21 | parallel: bool = True, 22 | clip_actions: bool = False, 23 | norm_obs: bool = False, 24 | norm_reward: bool = False, 25 | clip_obs: float = 10., 26 | clip_rewards: float = 10., 27 | gamma: float = 0.99, 28 | epsilon: float = 1e-8, 29 | norm_obs_keys: Optional[List[str]] = None, 30 | evaluate: bool = False, 31 | is_atari: bool = False, 32 | capture_video: bool = False 33 | ) -> Union[SequencedBatchedEnv, ParallelBatchedEnv]: 34 | 35 | # TODO seed env 36 | # TODO handle atari 37 | def f(): 38 | if is_atari: 39 | env_kwargs = { 40 | "repeat_action_probability": 0., # 0.25, 41 | "full_action_space": False, 42 | "frameskip": 1, 43 | } 44 | else: 45 | env_kwargs = {} 46 | if capture_video: 47 | env_kwargs['render_mode'] = 'rgb_array' 48 | if isinstance(env_id, str): 49 | env = gym.make(env_id, **env_kwargs) # TODO env kwards 50 | else: 51 | env = env_id 52 | if capture_video: 53 | env = gym.wrappers.RecordVideo(env, "./videos", episode_trigger=lambda x: x%10==1) 54 | if evaluate: 55 | env = RecordScores(env) 56 | # env = gym.wrappers.RecordEpisodeStatistics(env, 10) 57 | if is_atari: 58 | env = NoopResetEnv(env, noop_max=30) 59 | env = MaxAndSkipEnv(env, skip=4) 60 | if not evaluate: 61 | env = EpisodicLifeEnv(env) 62 | if "FIRE" in env.unwrapped.get_action_meanings(): 63 | env = FireResetEnv(env) 64 | if not evaluate: 65 | env = ClipRewardEnv(env) 66 | env = gym.wrappers.ResizeObservation(env, (84, 84)) 67 | env = gym.wrappers.GrayScaleObservation(env) 68 | env = gym.wrappers.FrameStack(env, 4) 69 | 70 | if clip_actions: 71 | env = gym.wrappers.ClipAction(env) 72 | return env 73 | 74 | if parallel: 75 | envs = ParallelBatchedEnv([f for _ in range(num_envs)]) 76 | else: 77 | envs = SequencedBatchedEnv([f for _ in range(num_envs)]) 78 | 79 | 80 | if norm_obs or norm_reward: 81 | envs = VecNormalize( 82 | venv=envs, 83 | norm_obs=norm_obs, 84 | norm_reward=norm_reward, 85 | clip_obs=clip_obs, 86 | clip_reward=clip_rewards, 87 | gamma=gamma, 88 | epsilon=epsilon, 89 | norm_obs_keys=norm_obs_keys, 90 | ) 91 | 92 | if evaluate: 93 | envs = Evaluator(envs) 94 | 95 | return envs 96 | -------------------------------------------------------------------------------- /envs/state.py: 
-------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import functools 4 | from typing import Any, Dict, Optional, Sequence, Union 5 | 6 | from flax import struct 7 | import jax 8 | from jax import numpy as jp 9 | from jax import vmap 10 | from jax.tree_util import tree_map 11 | 12 | 13 | 14 | @struct.dataclass 15 | class Base: 16 | """Base functionality extending all brax types. 17 | 18 | These methods allow for brax types to be operated like arrays/matrices. 19 | """ 20 | 21 | def __add__(self, o: Any) -> Any: 22 | return tree_map(lambda x, y: x + y, self, o) 23 | 24 | def __sub__(self, o: Any) -> Any: 25 | return tree_map(lambda x, y: x - y, self, o) 26 | 27 | def __mul__(self, o: Any) -> Any: 28 | return tree_map(lambda x: x * o, self) 29 | 30 | def __neg__(self) -> Any: 31 | return tree_map(lambda x: -x, self) 32 | 33 | def __truediv__(self, o: Any) -> Any: 34 | return tree_map(lambda x: x / o, self) 35 | 36 | def reshape(self, shape: Sequence[int]) -> Any: 37 | return tree_map(lambda x: x.reshape(shape), self) 38 | 39 | def select(self, o: Any, cond: jp.ndarray) -> Any: 40 | return tree_map(lambda x, y: (x.T * cond + y.T * (1 - cond)).T, self, o) 41 | 42 | def slice(self, beg: int, end: int) -> Any: 43 | return tree_map(lambda x: x[beg:end], self) 44 | 45 | def take(self, i, axis=0) -> Any: 46 | return tree_map(lambda x: jp.take(x, i, axis=axis, mode='wrap'), self) 47 | 48 | def concatenate(self, *others: Any, axis: int = 0) -> Any: 49 | return tree_map(lambda *x: jp.concatenate(x, axis=axis), self, *others) 50 | 51 | def index_set( 52 | self, idx: Union[jp.ndarray, Sequence[jp.ndarray]], o: Any 53 | ) -> Any: 54 | return tree_map(lambda x, y: x.at[idx].set(y), self, o) 55 | 56 | def index_sum( 57 | self, idx: Union[jp.ndarray, Sequence[jp.ndarray]], o: Any 58 | ) -> Any: 59 | return tree_map(lambda x, y: x.at[idx].add(y), self, o) 60 | 61 | def vmap(self, in_axes=0, out_axes=0): 62 | """Returns an object that vmaps each follow-on instance method call.""" 63 | 64 | # TODO: i think this is kinda handy, but maybe too clever? 65 | 66 | outer_self = self 67 | 68 | class VmapField: 69 | """Returns instance method calls as vmapped.""" 70 | 71 | def __init__(self, in_axes, out_axes): 72 | self.in_axes = [in_axes] 73 | self.out_axes = [out_axes] 74 | 75 | def vmap(self, in_axes=0, out_axes=0): 76 | self.in_axes.append(in_axes) 77 | self.out_axes.append(out_axes) 78 | return self 79 | 80 | def __getattr__(self, attr): 81 | fun = getattr(outer_self.__class__, attr) 82 | # load the stack from the bottom up 83 | vmap_order = reversed(list(zip(self.in_axes, self.out_axes))) 84 | for in_axes, out_axes in vmap_order: 85 | fun = vmap(fun, in_axes=in_axes, out_axes=out_axes) 86 | fun = functools.partial(fun, outer_self) 87 | return fun 88 | 89 | return VmapField(in_axes, out_axes) 90 | 91 | def tree_replace( 92 | self, params: Dict[str, Optional[jax.typing.ArrayLike]] 93 | ) -> 'Base': 94 | """Creates a new object with parameters set. 
95 | 96 | Args: 97 | params: a dictionary of key value pairs to replace 98 | 99 | Returns: 100 | data clas with new values 101 | 102 | Example: 103 | If a system has 3 links, the following code replaces the mass 104 | of each link in the System: 105 | >>> sys = sys.tree_replace( 106 | >>> {'link.inertia.mass', jp.array([1.0, 1.2, 1.3])}) 107 | """ 108 | new = self 109 | for k, v in params.items(): 110 | new = _tree_replace(new, k.split('.'), v) 111 | return new 112 | 113 | @property 114 | def T(self): # pylint:disable=invalid-name 115 | return tree_map(lambda x: x.T, self) 116 | 117 | 118 | def _tree_replace( 119 | base: Base, 120 | attr: Sequence[str], 121 | val: Optional[jax.typing.ArrayLike], 122 | ) -> Base: 123 | """Sets attributes in a struct.dataclass with values.""" 124 | if not attr: 125 | return base 126 | 127 | # special case for List attribute 128 | if len(attr) > 1 and isinstance(getattr(base, attr[0]), list): 129 | lst = copy.deepcopy(getattr(base, attr[0])) 130 | 131 | for i, g in enumerate(lst): 132 | if not hasattr(g, attr[1]): 133 | continue 134 | v = val if not hasattr(val, '__iter__') else val[i] 135 | lst[i] = _tree_replace(g, attr[1:], v) 136 | 137 | return base.replace(**{attr[0]: lst}) 138 | 139 | if len(attr) == 1: 140 | return base.replace(**{attr[0]: val}) 141 | 142 | return base.replace( 143 | **{attr[0]: _tree_replace(getattr(base, attr[0]), attr[1:], val)} 144 | ) 145 | 146 | 147 | @struct.dataclass 148 | class State(Base): 149 | """Environment state for training and inference.""" 150 | 151 | obs: jp.ndarray 152 | reward: jp.ndarray 153 | done: jp.ndarray 154 | metrics: Dict[str, jp.ndarray] = struct.field(default_factory=dict) 155 | info: Dict[str, Any] = struct.field(default_factory=dict) -------------------------------------------------------------------------------- /networks/networks.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, Callable, Sequence 3 | 4 | from flax import linen 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | 9 | 10 | ActivationFn = Callable[[jnp.ndarray], jnp.ndarray] 11 | Initializer = Callable[..., Any] 12 | 13 | 14 | @dataclasses.dataclass 15 | class FeedForwardNetwork: 16 | init: Callable[..., Any] 17 | apply: Callable[..., Any] 18 | 19 | 20 | class MLP(linen.Module): 21 | """MLP module.""" 22 | layer_sizes: Sequence[int] 23 | activation: ActivationFn = linen.relu 24 | kernel_init: Initializer = jax.nn.initializers.lecun_uniform() 25 | activate_final: bool = False 26 | bias: bool = True 27 | 28 | @linen.compact 29 | def __call__(self, data: jnp.ndarray): 30 | hidden = data 31 | for i, hidden_size in enumerate(self.layer_sizes): 32 | hidden = linen.Dense( 33 | hidden_size, 34 | name=f'hidden_{i}', 35 | kernel_init=self.kernel_init, 36 | use_bias=self.bias)( 37 | hidden) 38 | if i != len(self.layer_sizes) - 1 or self.activate_final: 39 | hidden = self.activation(hidden) 40 | return hidden 41 | 42 | 43 | class AtariTorso(linen.Module): 44 | """ConvNet Feature Extractor.""" 45 | layer_sizes: Sequence[int] = (512,) 46 | activation: ActivationFn = linen.relu 47 | kernel_init: Initializer = jax.nn.initializers.orthogonal(jnp.sqrt(2)) 48 | bias: bool = True 49 | 50 | @linen.compact 51 | def __call__(self, data: jnp.ndarray): 52 | hidden = jnp.moveaxis(data, -3, -1) # jnp.transpose(data, (0, 2, 3, 1)) 53 | hidden = hidden / (255.0) 54 | hidden = linen.Conv( 55 | 32, 56 | kernel_size=(8, 8), 57 | strides=(4, 4), 58 | padding="VALID", 59 | 
kernel_init=self.kernel_init, 60 | bias_init=jax.nn.initializers.constant(0.0), 61 | )(hidden) 62 | hidden = self.activation(hidden) 63 | hidden = linen.Conv( 64 | 64, 65 | kernel_size=(4, 4), 66 | strides=(2, 2), 67 | padding="VALID", 68 | kernel_init=self.kernel_init, 69 | bias_init=jax.nn.initializers.constant(0.0), 70 | )(hidden) 71 | hidden = self.activation(hidden) 72 | hidden = linen.Conv( 73 | 64, 74 | kernel_size=(3, 3), 75 | strides=(1, 1), 76 | padding="VALID", 77 | kernel_init=self.kernel_init, 78 | bias_init=jax.nn.initializers.constant(0.0), 79 | )(hidden) 80 | hidden = self.activation(hidden) 81 | hidden = hidden.reshape(hidden.shape[:-3] + (-1,)) 82 | hidden = linen.Dense(512, 83 | kernel_init=self.kernel_init, 84 | bias_init=jax.nn.initializers.constant(0.0) 85 | )(hidden) 86 | for i, hidden_size in enumerate(self.layer_sizes): 87 | hidden = linen.Dense( 88 | hidden_size, 89 | name=f'hidden_{i}', 90 | kernel_init=self.kernel_init, 91 | bias_init=jax.nn.initializers.constant(0.0), 92 | use_bias=self.bias)( 93 | hidden) 94 | hidden = self.activation(hidden) 95 | return hidden 96 | 97 | 98 | def make_atari_feature_extractor( 99 | obs_size: int, 100 | hidden_layer_sizes: Sequence[int] = (256, 256), 101 | activation: ActivationFn = linen.relu 102 | ) -> FeedForwardNetwork: 103 | """Creates a CNN feature extractor.""" 104 | feature_extractor = AtariTorso( 105 | layer_sizes=list(hidden_layer_sizes), 106 | activation=activation, 107 | ) 108 | 109 | def apply(policy_params, obs): 110 | return feature_extractor.apply(policy_params, obs) 111 | 112 | dummy_obs = jnp.zeros((1,) + obs_size) 113 | return FeedForwardNetwork( 114 | init=lambda key: feature_extractor.init(key, dummy_obs), apply=apply) 115 | 116 | def make_policy_network( 117 | param_size: int, 118 | obs_size: int, 119 | hidden_layer_sizes: Sequence[int] = (256, 256), 120 | activation: ActivationFn = linen.relu) -> FeedForwardNetwork: 121 | """Creates a policy network.""" 122 | policy_module = MLP( 123 | layer_sizes=list(hidden_layer_sizes) + [param_size], 124 | activation=activation, 125 | kernel_init=jax.nn.initializers.lecun_uniform()) 126 | 127 | def apply(policy_params, obs): 128 | return policy_module.apply(policy_params, obs) 129 | 130 | dummy_obs = jnp.zeros((1, obs_size)) 131 | return FeedForwardNetwork( 132 | init=lambda key: policy_module.init(key, dummy_obs), apply=apply) 133 | 134 | 135 | def make_value_network( 136 | obs_size: int, 137 | hidden_layer_sizes: Sequence[int] = (256, 256), 138 | activation: ActivationFn = linen.relu) -> FeedForwardNetwork: 139 | """Creates a policy network.""" 140 | value_module = MLP( 141 | layer_sizes=list(hidden_layer_sizes) + [1], 142 | activation=activation, 143 | kernel_init=jax.nn.initializers.lecun_uniform()) 144 | 145 | def apply(policy_params, obs): 146 | return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1) 147 | 148 | dummy_obs = jnp.zeros((1, obs_size)) 149 | return FeedForwardNetwork( 150 | init=lambda key: value_module.init(key, dummy_obs), apply=apply) -------------------------------------------------------------------------------- /envs/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Any, Dict, List, Tuple 3 | import warnings 4 | 5 | from gymnasium import spaces 6 | import numpy as np 7 | 8 | from envs.base import VecEnvObs 9 | 10 | 11 | 12 | def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None: 13 | """ 14 | If the spaces are 
Box, check that they have the same shape. 15 | 16 | If the spaces are Dict, it recursively checks the subspaces. 17 | 18 | :param space1: Space 19 | :param space2: Other space 20 | """ 21 | if isinstance(space1, spaces.Dict): 22 | assert isinstance(space2, spaces.Dict), "spaces must be of the same type" 23 | assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys" 24 | for key in space1.spaces.keys(): 25 | check_shape_equal(space1.spaces[key], space2.spaces[key]) 26 | elif isinstance(space1, spaces.Box): 27 | assert space1.shape == space2.shape, "spaces must have the same shape" 28 | 29 | 30 | def check_for_nested_spaces(obs_space: spaces.Space) -> None: 31 | """ 32 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). 33 | If so, raise an Exception informing that there is no support for this. 34 | 35 | :param obs_space: an observation space 36 | """ 37 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): 38 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces 39 | for sub_space in sub_spaces: 40 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): 41 | raise NotImplementedError( 42 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." 43 | ) 44 | 45 | 46 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 47 | """ 48 | Deep-copy a dict of numpy arrays. 49 | 50 | :param obs: a dict of numpy arrays. 51 | :return: a dict of copied numpy arrays. 52 | """ 53 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" 54 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 55 | 56 | 57 | def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: 58 | """ 59 | Convert an internal representation raw_obs into the appropriate type 60 | specified by space. 61 | 62 | :param obs_space: an observation space. 63 | :param obs_dict: a dict of numpy arrays. 64 | :return: returns an observation of the same type as space. 65 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; 66 | otherwise, space is unstructured and returns the value raw_obs[None]. 67 | """ 68 | if isinstance(obs_space, spaces.Dict): 69 | return obs_dict 70 | elif isinstance(obs_space, spaces.Tuple): 71 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" 72 | return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) 73 | else: 74 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 75 | return obs_dict[None] 76 | 77 | 78 | def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: 79 | """ 80 | Get dict-structured information about a gym.Space. 81 | 82 | Dict spaces are represented directly by their dict of subspaces. 83 | Tuple spaces are converted into a dict with keys indexing into the tuple. 84 | Unstructured spaces are represented by {None: obs_space}. 85 | 86 | :param obs_space: an observation space 87 | :return: A tuple (keys, shapes, dtypes): 88 | keys: a list of dict keys. 89 | shapes: a dict mapping keys to shapes. 90 | dtypes: a dict mapping keys to dtypes. 
91 | """ 92 | check_for_nested_spaces(obs_space) 93 | if isinstance(obs_space, spaces.Dict): 94 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 95 | subspaces = obs_space.spaces 96 | elif isinstance(obs_space, spaces.Tuple): 97 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] 98 | else: 99 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" 100 | subspaces = {None: obs_space} # type: ignore[assignment] 101 | keys = [] 102 | shapes = {} 103 | dtypes = {} 104 | for key, box in subspaces.items(): 105 | keys.append(key) 106 | shapes[key] = box.shape 107 | dtypes[key] = box.dtype 108 | return keys, shapes, dtypes 109 | 110 | 111 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool: 112 | """ 113 | Check if an image observation space (see ``is_image_space``) 114 | is channels-first (CxHxW, True) or channels-last (HxWxC, False). 115 | 116 | Use a heuristic that channel dimension is the smallest of the three. 117 | If second dimension is smallest, raise an exception (no support). 118 | 119 | :param observation_space: 120 | :return: True if observation space is channels-first image, False if channels-last. 121 | """ 122 | smallest_dimension = np.argmin(observation_space.shape).item() 123 | if smallest_dimension == 1: 124 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") 125 | return smallest_dimension == 0 126 | 127 | 128 | def is_image_space( 129 | observation_space: spaces.Space, 130 | check_channels: bool = False, 131 | normalized_image: bool = False, 132 | ) -> bool: 133 | """ 134 | Check if a observation space has the shape, limits and dtype 135 | of a valid image. 136 | The check is conservative, so that it returns False if there is a doubt. 137 | 138 | Valid images: RGB, RGBD, GrayScale with values in [0, 255] 139 | 140 | :param observation_space: 141 | :param check_channels: Whether to do or not the check for the number of channels. 142 | e.g., with frame-stacking, the observation space may have more channels than expected. 143 | :param normalized_image: Whether to assume that the image is already normalized 144 | or not (this disables dtype and bounds checks): when True, it only checks that 145 | the space is a Box and has 3 dimensions. 146 | Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). 
147 | :return: 148 | """ 149 | check_dtype = check_bounds = not normalized_image 150 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: 151 | # Check the type 152 | if check_dtype and observation_space.dtype != np.uint8: 153 | return False 154 | 155 | # Check the value range 156 | incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255) 157 | if check_bounds and incorrect_bounds: 158 | return False 159 | 160 | # Skip channels check 161 | if not check_channels: 162 | return True 163 | # Check the number of channels 164 | if is_image_space_channels_first(observation_space): 165 | n_channels = observation_space.shape[0] 166 | else: 167 | n_channels = observation_space.shape[-1] 168 | # GrayScale, RGB, RGBD 169 | return n_channels in [1, 3, 4] 170 | return False -------------------------------------------------------------------------------- /networks/distributions.py: -------------------------------------------------------------------------------- 1 | """Probability distributions in JAX.""" 2 | 3 | import abc 4 | import jax 5 | import jax.numpy as jnp 6 | import tensorflow_probability.substrates.jax.distributions as tfd 7 | 8 | 9 | class ParametricDistribution(abc.ABC): 10 | """Abstract class for parametric (action) distribution.""" 11 | 12 | def __init__(self, param_size, postprocessor, event_ndims, reparametrizable): 13 | """Abstract class for parametric (action) distribution. 14 | 15 | Specifies how to transform distribution parameters (i.e. actor output) 16 | into a distribution over actions. 17 | 18 | Args: 19 | param_size: size of the parameters for the distribution 20 | postprocessor: bijector which is applied after sampling (in practice, it's 21 | tanh or identity) 22 | event_ndims: rank of the distribution sample (i.e. 
action) 23 | reparametrizable: is the distribution reparametrizable 24 | """ 25 | self._param_size = param_size 26 | self._postprocessor = postprocessor 27 | self._event_ndims = event_ndims # rank of events 28 | self._reparametrizable = reparametrizable 29 | assert event_ndims in [0, 1] 30 | 31 | @abc.abstractmethod 32 | def create_dist(self, parameters): 33 | """Creates distribution from parameters.""" 34 | pass 35 | 36 | @property 37 | def param_size(self): 38 | return self._param_size 39 | 40 | @property 41 | def reparametrizable(self): 42 | return self._reparametrizable 43 | 44 | def postprocess(self, event): 45 | return self._postprocessor.forward(event) 46 | 47 | def inverse_postprocess(self, event): 48 | return self._postprocessor.inverse(event) 49 | 50 | def sample_no_postprocessing(self, parameters, seed): 51 | return self.create_dist(parameters).sample(seed=seed) 52 | 53 | def sample(self, parameters, seed): 54 | """Returns a sample from the postprocessed distribution.""" 55 | return self.postprocess(self.sample_no_postprocessing(parameters, seed)) 56 | 57 | def mode(self, parameters): 58 | """Returns the mode of the postprocessed distribution.""" 59 | return self.postprocess(self.create_dist(parameters).mode()) 60 | 61 | def log_prob(self, parameters, actions): 62 | """Compute the log probability of actions.""" 63 | dist = self.create_dist(parameters) 64 | log_probs = dist.log_prob(actions) 65 | log_probs -= self._postprocessor.forward_log_det_jacobian(actions) 66 | if self._event_ndims == 1: 67 | log_probs = jnp.sum(log_probs, axis=-1) # sum over action dimension 68 | return log_probs 69 | 70 | def entropy(self, parameters, seed): 71 | """Return the entropy of the given distribution.""" 72 | dist = self.create_dist(parameters) 73 | entropy = dist.entropy() 74 | entropy += self._postprocessor.forward_log_det_jacobian( 75 | dist.sample(seed=seed)) 76 | if self._event_ndims == 1: 77 | entropy = jnp.sum(entropy, axis=-1) 78 | return entropy 79 | 80 | def kl_divergence(self, p_parameters, q_parameters): 81 | """Return the KL divergence of the given distributions.""" 82 | p_distribution = self.create_dist(p_parameters) 83 | q_distribution = self.create_dist(q_parameters) 84 | 85 | diff_log_scale = jnp.log(p_distribution.scale) - jnp.log(q_distribution.scale) 86 | return ( 87 | 0.5 * jnp.square(p_distribution.loc / q_distribution.scale - q_distribution.loc / q_distribution.scale) + 88 | 0.5 * (jnp.exp(2. * diff_log_scale) - 1) - 89 | diff_log_scale) 90 | 91 | def kl_divergence_mu(self, p_parameters, q_parameters): 92 | """Return the decoupled KL divergence for the mean of the given distributions.""" 93 | p_distribution = self.create_dist(p_parameters) 94 | q_distribution = self.create_dist(q_parameters) 95 | 96 | diff_loc = q_distribution.loc - p_distribution.loc 97 | return 0.5 * jnp.sum(diff_loc / p_distribution.scale * diff_loc, axis=-1) # transposing needed? 
98 | 99 | def kl_divergence_sigma(self, p_parameters, q_parameters): 100 | """Return the decoupled KL divergence for the covariance of the given distributions.""" 101 | p_distribution = self.create_dist(p_parameters) 102 | q_distribution = self.create_dist(q_parameters) 103 | 104 | return 0.5 * (jnp.sum(p_distribution.scale / q_distribution.scale, axis=-1) - 105 | q_distribution.scale.shape[-1] + 106 | jnp.prod(q_distribution.scale, axis=-1) / jnp.prod(p_distribution.scale, axis=-1)) 107 | 108 | 109 | class NormalDistribution: 110 | """Normal distribution.""" 111 | 112 | def __init__(self, loc, scale): 113 | self.loc = loc 114 | self.scale = scale 115 | 116 | def sample(self, seed): 117 | return jax.random.normal(seed, shape=self.loc.shape) * self.scale + self.loc 118 | 119 | def mode(self): 120 | return self.loc 121 | 122 | def log_prob(self, x): 123 | log_unnormalized = -0.5 * jnp.square(x / self.scale - self.loc / self.scale) 124 | log_normalization = 0.5 * jnp.log(2. * jnp.pi) + jnp.log(self.scale) 125 | return log_unnormalized - log_normalization 126 | 127 | def entropy(self): 128 | log_normalization = 0.5 * jnp.log(2. * jnp.pi) + jnp.log(self.scale) 129 | entropy = 0.5 + log_normalization 130 | return entropy * jnp.ones_like(self.loc) 131 | 132 | 133 | class TanhBijector: 134 | """Tanh Bijector.""" 135 | 136 | def forward(self, x): 137 | return jnp.tanh(x) 138 | 139 | def inverse(self, y): 140 | return jnp.arctanh(y) 141 | 142 | def forward_log_det_jacobian(self, x): 143 | return 2. * (jnp.log(2.) - x - jax.nn.softplus(-2. * x)) 144 | 145 | 146 | class NormalTanhDistribution(ParametricDistribution): 147 | """Normal distribution followed by tanh.""" 148 | 149 | def __init__(self, event_size, min_std=0.001): 150 | """Initialize the distribution. 151 | 152 | Args: 153 | event_size: the size of events (i.e. actions). 154 | min_std: minimum std for the gaussian. 155 | """ 156 | # We apply tanh to gaussian actions to bound them. 157 | # Normally we would use TransformedDistribution to automatically 158 | # apply tanh to the distribution. 159 | # We can't do it here because of tanh saturation 160 | # which would make log_prob computations impossible. Instead, most 161 | # of the code operate on pre-tanh actions and we take the postprocessor 162 | # jacobian into account in log_prob computations. 163 | super().__init__( 164 | param_size=2 * event_size, 165 | postprocessor=TanhBijector(), 166 | event_ndims=1, 167 | reparametrizable=True) 168 | self._min_std = min_std 169 | 170 | def create_dist(self, parameters): 171 | loc, scale = jnp.split(parameters, 2, axis=-1) 172 | scale = jax.nn.softplus(scale) + self._min_std 173 | return NormalDistribution(loc=loc, scale=scale) 174 | 175 | 176 | class IdentityPostprocessor: 177 | """Tanh Bijector.""" 178 | 179 | def forward(self, x): 180 | return x 181 | 182 | def inverse(self, y): 183 | return y 184 | 185 | def forward_log_det_jacobian(self, x): 186 | return 0 187 | 188 | 189 | class PolicyNormalDistribution(ParametricDistribution): 190 | """Normal distribution for clipping.""" 191 | 192 | def __init__(self, event_size, min_std=0.001): 193 | """Initialize the distribution. 194 | 195 | Args: 196 | event_size: the size of events (i.e. actions). 197 | min_std: minimum std for the gaussian. 198 | """ 199 | # We apply tanh to gaussian actions to bound them. 200 | # Normally we would use TransformedDistribution to automatically 201 | # apply tanh to the distribution. 
202 | # We can't do it here because of tanh saturation 203 | # which would make log_prob computations impossible. Instead, most 204 | # of the code operate on pre-tanh actions and we take the postprocessor 205 | # jacobian into account in log_prob computations. 206 | super().__init__( 207 | param_size=2 * event_size, 208 | postprocessor=IdentityPostprocessor(), 209 | event_ndims=1, 210 | reparametrizable=True) 211 | self._min_std = min_std 212 | 213 | def create_dist(self, parameters): 214 | loc, scale = jnp.split(parameters, 2, axis=-1) 215 | scale = jax.nn.softplus(scale) + self._min_std 216 | return NormalDistribution(loc=loc, scale=scale) 217 | 218 | 219 | class DiscreteDistribution(abc.ABC): 220 | """Discrete (action) distribution.""" 221 | 222 | def __init__(self, param_size): 223 | """Discrete (action) distribution. 224 | 225 | Args: 226 | param_size: size of the parameters for the distribution, i.e. number of 227 | discrete actions. 228 | """ 229 | self._param_size = param_size 230 | self._event_ndims = 1 # rank of events 231 | self._reparametrizable = False 232 | 233 | @property 234 | def param_size(self): 235 | return self._param_size 236 | 237 | @property 238 | def reparametrizable(self): 239 | return self._reparametrizable 240 | 241 | def postprocess(self, event): 242 | return event 243 | 244 | def inverse_postprocess(self, event): 245 | return event 246 | 247 | def sample_no_postprocessing(self, parameters, seed): 248 | return tfd.Categorical(logits=parameters).sample(seed=seed) 249 | 250 | def sample(self, parameters, seed): 251 | """Returns a sample from the postprocessed distribution.""" 252 | return self.postprocess(self.sample_no_postprocessing(parameters, seed)) 253 | 254 | def mode(self, parameters): 255 | """Returns the mode of the discrete distribution.""" 256 | return tfd.Categorical(logits=parameters).mode() 257 | 258 | def log_prob(self, parameters, actions): 259 | """Compute the log probability of actions.""" 260 | return tfd.Categorical(logits=parameters).log_prob(actions) 261 | 262 | def entropy(self, parameters, seed): 263 | """Return the entropy of the given distribution.""" 264 | return tfd.Categorical(logits=parameters).entropy() 265 | 266 | def kl_divergence(self, p_parameters, q_parameters): 267 | """Return the KL divergence of the given distributions.""" 268 | p_distribution = tfd.Categorical(logits=p_parameters) 269 | q_distribution = tfd.Categorical(logits=q_parameters) 270 | return tfd.kl_divergence(p_distribution, q_distribution) -------------------------------------------------------------------------------- /envs/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | """Atari wrappers from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/atari_wrappers.py""" 2 | 3 | from typing import Dict, SupportsFloat 4 | 5 | from typing import Any, Tuple 6 | 7 | import gymnasium as gym 8 | import numpy as np 9 | from gymnasium import spaces 10 | 11 | try: 12 | import cv2 # pytype:disable=import-error 13 | 14 | cv2.ocl.setUseOpenCL(False) 15 | except ImportError: 16 | cv2 = None # type: ignore[assignment] 17 | 18 | 19 | 20 | AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] 21 | AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] 22 | 23 | 24 | class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 25 | """ 26 | Sticky action. 
27 | 28 | Paper: https://arxiv.org/abs/1709.06009 29 | Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment 30 | 31 | :param env: Environment to wrap 32 | :param action_repeat_probability: Probability of repeating the last action 33 | """ 34 | 35 | def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: 36 | super().__init__(env) 37 | self.action_repeat_probability = action_repeat_probability 38 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] 39 | 40 | def reset(self, **kwargs) -> AtariResetReturn: 41 | self._sticky_action = 0 # NOOP 42 | return self.env.reset(**kwargs) 43 | 44 | def step(self, action: int) -> AtariStepReturn: 45 | if self.np_random.random() >= self.action_repeat_probability: 46 | self._sticky_action = action 47 | return self.env.step(self._sticky_action) 48 | 49 | 50 | class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 51 | """ 52 | Sample initial states by taking random number of no-ops on reset. 53 | No-op is assumed to be action 0. 54 | 55 | :param env: Environment to wrap 56 | :param noop_max: Maximum value of no-ops to run 57 | """ 58 | 59 | def __init__(self, env: gym.Env, noop_max: int = 30) -> None: 60 | super().__init__(env) 61 | self.noop_max = noop_max 62 | self.override_num_noops = None 63 | self.noop_action = 0 64 | assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] 65 | 66 | def reset(self, **kwargs) -> AtariResetReturn: 67 | self.env.reset(**kwargs) 68 | if self.override_num_noops is not None: 69 | noops = self.override_num_noops 70 | else: 71 | noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) 72 | assert noops > 0 73 | obs = np.zeros(0) 74 | info: Dict = {} 75 | for _ in range(noops): 76 | obs, _, terminated, truncated, info = self.env.step(self.noop_action) 77 | if terminated or truncated: 78 | obs, info = self.env.reset(**kwargs) 79 | return obs, info 80 | 81 | 82 | class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 83 | """ 84 | Take action on reset for environments that are fixed until firing. 85 | 86 | :param env: Environment to wrap 87 | """ 88 | 89 | def __init__(self, env: gym.Env) -> None: 90 | super().__init__(env) 91 | assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] 92 | assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] 93 | 94 | def reset(self, **kwargs) -> AtariResetReturn: 95 | self.env.reset(**kwargs) 96 | obs, _, terminated, truncated, _ = self.env.step(1) 97 | if terminated or truncated: 98 | self.env.reset(**kwargs) 99 | obs, _, terminated, truncated, _ = self.env.step(2) 100 | if terminated or truncated: 101 | self.env.reset(**kwargs) 102 | return obs, {} 103 | 104 | 105 | class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 106 | """ 107 | Make end-of-life == end-of-episode, but only reset on true game over. 108 | Done by DeepMind for the DQN and co. since it helps value estimation. 
109 | 110 | :param env: Environment to wrap 111 | """ 112 | 113 | def __init__(self, env: gym.Env) -> None: 114 | super().__init__(env) 115 | self.lives = 0 116 | self.was_real_done = True 117 | 118 | def step(self, action: int) -> AtariStepReturn: 119 | obs, reward, terminated, truncated, info = self.env.step(action) 120 | self.was_real_done = terminated or truncated 121 | # check current lives, make loss of life terminal, 122 | # then update lives to handle bonus lives 123 | lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] 124 | if 0 < lives < self.lives: 125 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 126 | # so its important to keep lives > 0, so that we only reset once 127 | # the environment advertises done. 128 | terminated = True 129 | self.lives = lives 130 | return obs, reward, terminated, truncated, info 131 | 132 | def reset(self, **kwargs) -> AtariResetReturn: 133 | """ 134 | Calls the Gym environment reset, only when lives are exhausted. 135 | This way all states are still reachable even though lives are episodic, 136 | and the learner need not know about any of this behind-the-scenes. 137 | 138 | :param kwargs: Extra keywords passed to env.reset() call 139 | :return: the first observation of the environment 140 | """ 141 | if self.was_real_done: 142 | obs, info = self.env.reset(**kwargs) 143 | else: 144 | # no-op step to advance from terminal/lost life state 145 | obs, _, terminated, truncated, info = self.env.step(0) 146 | 147 | # The no-op step can lead to a game over, so we need to check it again 148 | # to see if we should reset the environment and avoid the 149 | # monitor.py `RuntimeError: Tried to step environment that needs reset` 150 | if terminated or truncated: 151 | obs, info = self.env.reset(**kwargs) 152 | self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] 153 | return obs, info 154 | 155 | 156 | class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 157 | """ 158 | Return only every ``skip``-th frame (frameskipping) 159 | and return the max between the two last frames. 160 | 161 | :param env: Environment to wrap 162 | :param skip: Number of ``skip``-th frame 163 | The same action will be taken ``skip`` times. 164 | """ 165 | 166 | def __init__(self, env: gym.Env, skip: int = 4) -> None: 167 | super().__init__(env) 168 | # most recent raw observations (for max pooling across time steps) 169 | assert env.observation_space.dtype is not None, "No dtype specified for the observation space" 170 | assert env.observation_space.shape is not None, "No shape defined for the observation space" 171 | self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype) 172 | self._skip = skip 173 | 174 | def step(self, action: int) -> AtariStepReturn: 175 | """ 176 | Step the environment with the given action 177 | Repeat action, sum reward, and max over last observations. 
178 | 179 | :param action: the action 180 | :return: observation, reward, terminated, truncated, information 181 | """ 182 | total_reward = 0.0 183 | terminated = truncated = False 184 | for i in range(self._skip): 185 | obs, reward, terminated, truncated, info = self.env.step(action) 186 | done = terminated or truncated 187 | if i == self._skip - 2: 188 | self._obs_buffer[0] = obs 189 | if i == self._skip - 1: 190 | self._obs_buffer[1] = obs 191 | total_reward += float(reward) 192 | if done: 193 | break 194 | # Note that the observation on the done=True frame 195 | # doesn't matter 196 | max_frame = self._obs_buffer.max(axis=0) 197 | 198 | return max_frame, total_reward, terminated, truncated, info 199 | 200 | 201 | class ClipRewardEnv(gym.RewardWrapper): 202 | """ 203 | Clip the reward to {+1, 0, -1} by its sign. 204 | 205 | :param env: Environment to wrap 206 | """ 207 | 208 | def __init__(self, env: gym.Env) -> None: 209 | super().__init__(env) 210 | 211 | def reward(self, reward: SupportsFloat) -> float: 212 | """ 213 | Bin reward to {+1, 0, -1} by its sign. 214 | 215 | :param reward: 216 | :return: 217 | """ 218 | return np.sign(float(reward)) 219 | 220 | 221 | class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]): 222 | """ 223 | Convert to grayscale and warp frames to 84x84 (default) 224 | as done in the Nature paper and later work. 225 | 226 | :param env: Environment to wrap 227 | :param width: New frame width 228 | :param height: New frame height 229 | """ 230 | 231 | def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None: 232 | super().__init__(env) 233 | self.width = width 234 | self.height = height 235 | assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}" 236 | 237 | self.observation_space = spaces.Box( 238 | low=0, 239 | high=255, 240 | shape=(self.height, self.width, 1), 241 | dtype=env.observation_space.dtype, # type: ignore[arg-type] 242 | ) 243 | 244 | def observation(self, frame: np.ndarray) -> np.ndarray: 245 | """ 246 | returns the current observation from a frame 247 | 248 | :param frame: environment frame 249 | :return: the observation 250 | """ 251 | assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`" 252 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 253 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 254 | return frame[:, :, None] 255 | 256 | 257 | class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]): 258 | """ 259 | Atari 2600 preprocessings 260 | 261 | Specifically: 262 | 263 | * Noop reset: obtain initial state by taking random number of no-ops on reset. 264 | * Frame skipping: 4 by default 265 | * Max-pooling: most recent two observations 266 | * Termination signal when a life is lost. 267 | * Resize to a square image: 84x84 by default 268 | * Grayscale observation 269 | * Clip reward to {-1, 0, 1} 270 | * Sticky actions: disabled by default 271 | 272 | See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/ 273 | for a visual explanation. 274 | 275 | .. warning:: 276 | Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``. 277 | 278 | :param env: Environment to wrap 279 | :param noop_max: Max number of no-ops 280 | :param frame_skip: Frequency at which the agent experiences the game. 281 | This correspond to repeating the action ``frame_skip`` times. 
282 | :param screen_size: Resize Atari frame 283 | :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost. 284 | :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. 285 | :param action_repeat_probability: Probability of repeating the last action 286 | """ 287 | 288 | def __init__( 289 | self, 290 | env: gym.Env, 291 | noop_max: int = 30, 292 | frame_skip: int = 4, 293 | screen_size: int = 84, 294 | terminal_on_life_loss: bool = True, 295 | clip_reward: bool = True, 296 | action_repeat_probability: float = 0.0, 297 | ) -> None: 298 | if action_repeat_probability > 0.0: 299 | env = StickyActionEnv(env, action_repeat_probability) 300 | if noop_max > 0: 301 | env = NoopResetEnv(env, noop_max=noop_max) 302 | # frame_skip=1 is the same as no frame-skip (action repeat) 303 | if frame_skip > 1: 304 | env = MaxAndSkipEnv(env, skip=frame_skip) 305 | if terminal_on_life_loss: 306 | env = EpisodicLifeEnv(env) 307 | if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined] 308 | env = FireResetEnv(env) 309 | env = WarpFrame(env, width=screen_size, height=screen_size) 310 | if clip_reward: 311 | env = ClipRewardEnv(env) 312 | 313 | super().__init__(env) -------------------------------------------------------------------------------- /envs/normalize.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pickle 3 | from copy import deepcopy 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from gymnasium import spaces 10 | 11 | from envs.base import VecEnv, VecEnvWrapper 12 | from envs.state import State 13 | from envs.utils import check_shape_equal, is_image_space 14 | 15 | 16 | 17 | class RunningMeanStd: 18 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): 19 | """ 20 | Calulates the running mean and std of a data stream 21 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 22 | 23 | :param epsilon: helps with arithmetic issues 24 | :param shape: the shape of the data stream's output 25 | """ 26 | self.mean = np.zeros(shape, np.float64) 27 | self.var = np.ones(shape, np.float64) 28 | self.count = epsilon 29 | 30 | def copy(self) -> "RunningMeanStd": 31 | """ 32 | :return: Return a copy of the current object. 33 | """ 34 | new_object = RunningMeanStd(shape=self.mean.shape) 35 | new_object.mean = self.mean.copy() 36 | new_object.var = self.var.copy() 37 | new_object.count = float(self.count) 38 | return new_object 39 | 40 | def combine(self, other: "RunningMeanStd") -> None: 41 | """ 42 | Combine stats from another ``RunningMeanStd`` object. 43 | 44 | :param other: The other object to combine with. 
45 | """ 46 | self.update_from_moments(other.mean, other.var, other.count) 47 | 48 | def update(self, arr: np.ndarray) -> None: 49 | batch_mean = np.mean(arr, axis=0) 50 | batch_var = np.var(arr, axis=0) 51 | batch_count = arr.shape[0] 52 | self.update_from_moments(batch_mean, batch_var, batch_count) 53 | 54 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float) -> None: 55 | delta = batch_mean - self.mean 56 | tot_count = self.count + batch_count 57 | 58 | new_mean = self.mean + delta * batch_count / tot_count 59 | m_a = self.var * self.count 60 | m_b = batch_var * batch_count 61 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 62 | new_var = m_2 / (self.count + batch_count) 63 | 64 | new_count = batch_count + self.count 65 | 66 | self.mean = new_mean 67 | self.var = new_var 68 | self.count = new_count 69 | 70 | 71 | class VecNormalize(VecEnvWrapper): 72 | """ 73 | A moving average, normalizing wrapper for vectorized environment. 74 | has support for saving/loading moving average, 75 | 76 | :param venv: the vectorized environment to wrap 77 | :param training: Whether to update or not the moving average 78 | :param norm_obs: Whether to normalize observation or not (default: True) 79 | :param norm_reward: Whether to normalize rewards or not (default: True) 80 | :param clip_obs: Max absolute value for observation 81 | :param clip_reward: Max value absolute for discounted reward 82 | :param gamma: discount factor 83 | :param epsilon: To avoid division by zero 84 | :param norm_obs_keys: Which keys from observation dict to normalize. 85 | If not specified, all keys will be normalized. 86 | """ 87 | 88 | def __init__( 89 | self, 90 | venv: VecEnv, 91 | training: bool = True, 92 | norm_obs: bool = True, 93 | norm_reward: bool = True, 94 | clip_obs: float = 10.0, 95 | clip_reward: float = 10.0, 96 | gamma: float = 0.99, 97 | epsilon: float = 1e-8, 98 | norm_obs_keys: Optional[List[str]] = None, 99 | ): 100 | VecEnvWrapper.__init__(self, venv) 101 | 102 | self.norm_obs = norm_obs 103 | self.norm_obs_keys = norm_obs_keys 104 | # Check observation spaces 105 | if self.norm_obs: 106 | self._sanity_checks() 107 | 108 | if isinstance(self.observation_space, spaces.Dict): 109 | self.obs_spaces = self.observation_space.spaces 110 | self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} 111 | # Update observation space when using image 112 | # See explanation below and GH #1214 113 | for key in self.obs_rms.keys(): 114 | if is_image_space(self.obs_spaces[key]): 115 | self.observation_space.spaces[key] = spaces.Box( 116 | low=-clip_obs, 117 | high=clip_obs, 118 | shape=self.obs_spaces[key].shape, 119 | dtype=np.float32, 120 | ) 121 | 122 | else: 123 | self.obs_spaces = None 124 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) 125 | # Update observation space when using image 126 | # See GH #1214 127 | # This is to raise proper error when 128 | # VecNormalize is used with an image-like input and 129 | # normalize_images=True. 130 | # For correctness, we should also update the bounds 131 | # in other cases but this will cause backward-incompatible change 132 | # and break already saved policies. 
133 | if is_image_space(self.observation_space): 134 | self.observation_space = spaces.Box( 135 | low=-clip_obs, 136 | high=clip_obs, 137 | shape=self.observation_space.shape, 138 | dtype=np.float32, 139 | ) 140 | 141 | self.ret_rms = RunningMeanStd(shape=()) 142 | self.clip_obs = clip_obs 143 | self.clip_reward = clip_reward 144 | # Returns: discounted rewards 145 | self.returns = np.zeros(self.num_envs) 146 | self.gamma = gamma 147 | self.epsilon = epsilon 148 | self.training = training 149 | self.norm_obs = norm_obs 150 | self.norm_reward = norm_reward 151 | self.old_obs = np.array([]) 152 | self.old_reward = np.array([]) 153 | 154 | def _sanity_checks(self) -> None: 155 | """ 156 | Check the observations that are going to be normalized are of the correct type (spaces.Box). 157 | """ 158 | if isinstance(self.observation_space, spaces.Dict): 159 | # By default, we normalize all keys 160 | if self.norm_obs_keys is None: 161 | self.norm_obs_keys = list(self.observation_space.spaces.keys()) 162 | # Check that all keys are of type Box 163 | for obs_key in self.norm_obs_keys: 164 | if not isinstance(self.observation_space.spaces[obs_key], spaces.Box): 165 | raise ValueError( 166 | f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " 167 | f"is of type {self.observation_space.spaces[obs_key]}. " 168 | "You should probably explicitely pass the observation keys " 169 | " that should be normalized via the `norm_obs_keys` parameter." 170 | ) 171 | 172 | elif isinstance(self.observation_space, spaces.Box): 173 | if self.norm_obs_keys is not None: 174 | raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces") 175 | 176 | else: 177 | raise ValueError( 178 | "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, " 179 | f"not {self.observation_space}" 180 | ) 181 | 182 | def __getstate__(self) -> Dict[str, Any]: 183 | """ 184 | Gets state for pickling. 185 | 186 | Excludes self.venv, as in general VecEnv's may not be pickleable.""" 187 | state = self.__dict__.copy() 188 | # these attributes are not pickleable 189 | del state["venv"] 190 | del state["class_attributes"] 191 | # these attributes depend on the above and so we would prefer not to pickle 192 | del state["returns"] 193 | return state 194 | 195 | def __setstate__(self, state: Dict[str, Any]) -> None: 196 | """ 197 | Restores pickled state. 198 | 199 | User must call set_venv() after unpickling before using. 200 | 201 | :param state:""" 202 | # Backward compatibility 203 | if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict): 204 | state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) 205 | self.__dict__.update(state) 206 | assert "venv" not in state 207 | self.venv = None 208 | 209 | def set_venv(self, venv: VecEnv) -> None: 210 | """ 211 | Sets the vector environment to wrap to venv. 212 | 213 | Also sets attributes derived from this such as `num_env`. 
214 | 215 | :param venv: 216 | """ 217 | if self.venv is not None: 218 | raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") 219 | self.venv = venv 220 | self.num_envs = venv.num_envs 221 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 222 | self.render_mode = venv.render_mode 223 | 224 | # Check that the observation_space shape match 225 | check_shape_equal(self.observation_space, venv.observation_space) 226 | self.returns = np.zeros(self.num_envs) 227 | 228 | def step_wait(self) -> State: 229 | """ 230 | Apply sequence of actions to sequence of environments 231 | actions -> (observations, rewards, dones) 232 | 233 | where ``dones`` is a boolean vector indicating whether each element is new. 234 | """ 235 | env_state = self.venv.step_wait() 236 | env_state = jax.tree_util.tree_map(np.asarray, env_state) 237 | obs, rewards, dones, infos = env_state.obs, env_state.reward, env_state.done, env_state.info 238 | self.old_obs = obs 239 | self.old_reward = rewards 240 | 241 | if self.training and self.norm_obs: 242 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 243 | for key in self.obs_rms.keys(): 244 | self.obs_rms[key].update(obs[key]) 245 | else: 246 | self.obs_rms.update(obs) 247 | 248 | obs = self.normalize_obs(obs) 249 | 250 | if self.training: 251 | self._update_reward(rewards) 252 | rewards = self.normalize_reward(rewards) 253 | 254 | # Normalize the terminal observations 255 | for idx, done in enumerate(dones): 256 | if not done: 257 | continue 258 | if "terminal_observation" in infos[idx]: 259 | infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) 260 | 261 | self.returns[dones] = 0 262 | # return obs, rewards, dones, infos 263 | return State(obs=obs, 264 | reward=rewards, 265 | done=dones, 266 | info=infos) 267 | 268 | def _update_reward(self, reward: np.ndarray) -> None: 269 | """Update reward normalization statistics.""" 270 | self.returns = self.returns * self.gamma + reward 271 | self.ret_rms.update(self.returns) 272 | 273 | def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: 274 | """ 275 | Helper to normalize observation. 276 | :param obs: 277 | :param obs_rms: associated statistics 278 | :return: normalized observation 279 | """ 280 | return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) 281 | 282 | def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: 283 | """ 284 | Helper to unnormalize observation. 285 | :param obs: 286 | :param obs_rms: associated statistics 287 | :return: unnormalized observation 288 | """ 289 | return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean 290 | 291 | def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: 292 | """ 293 | Normalize observations using this VecNormalize's observations statistics. 294 | Calling this method does not update statistics. 
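        For reference, a stand-alone sketch of the arithmetic performed by ``_normalize_obs``
        and ``normalize_reward`` below (``mean``, ``var`` and ``return_var`` are placeholders
        for the ``obs_rms``/``ret_rms`` running statistics):

            import numpy as np

            def normalize_obs_reference(obs, mean, var, clip_obs=10.0, epsilon=1e-8):
                # standardize with the running statistics, then clip element-wise
                return np.clip((obs - mean) / np.sqrt(var + epsilon), -clip_obs, clip_obs)

            def normalize_reward_reference(reward, return_var, clip_reward=10.0, epsilon=1e-8):
                # rewards are only scaled by the std of the discounted return (no centering), then clipped
                return np.clip(reward / np.sqrt(return_var + epsilon), -clip_reward, clip_reward)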
295 | """ 296 | # Avoid modifying by reference the original object 297 | obs_ = deepcopy(obs) 298 | if self.norm_obs: 299 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 300 | # Only normalize the specified keys 301 | for key in self.norm_obs_keys: 302 | obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) 303 | else: 304 | obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32) 305 | return obs_ 306 | 307 | def normalize_reward(self, reward: np.ndarray) -> np.ndarray: 308 | """ 309 | Normalize rewards using this VecNormalize's rewards statistics. 310 | Calling this method does not update statistics. 311 | """ 312 | if self.norm_reward: 313 | reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) 314 | return reward 315 | 316 | def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: 317 | # Avoid modifying by reference the original object 318 | obs_ = deepcopy(obs) 319 | if self.norm_obs: 320 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 321 | for key in self.norm_obs_keys: 322 | obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) 323 | else: 324 | obs_ = self._unnormalize_obs(obs, self.obs_rms) 325 | return obs_ 326 | 327 | def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray: 328 | if self.norm_reward: 329 | return reward * np.sqrt(self.ret_rms.var + self.epsilon) 330 | return reward 331 | 332 | def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: 333 | """ 334 | Returns an unnormalized version of the observations from the most recent 335 | step or reset. 336 | """ 337 | return deepcopy(self.old_obs) 338 | 339 | def get_original_reward(self) -> np.ndarray: 340 | """ 341 | Returns an unnormalized version of the rewards from the most recent step. 342 | """ 343 | return self.old_reward.copy() 344 | 345 | def reset(self) -> State: # Union[np.ndarray, Dict[str, np.ndarray]]: 346 | """ 347 | Reset all environments 348 | :return: first observation of the episode 349 | """ 350 | env_state = self.venv.reset() 351 | obs = np.asarray(env_state.obs) 352 | self.old_obs = obs 353 | self.returns = np.zeros(self.num_envs) 354 | if self.training and self.norm_obs: 355 | if isinstance(obs, dict) and isinstance(self.obs_rms, dict): 356 | for key in self.obs_rms.keys(): 357 | self.obs_rms[key].update(obs[key]) 358 | else: 359 | self.obs_rms.update(obs) 360 | # return self.normalize_obs(obs) 361 | return State(obs=self.normalize_obs(obs), 362 | reward=jnp.zeros(self.num_envs), 363 | done=jnp.zeros(self.num_envs)) 364 | 365 | @staticmethod 366 | def load(load_path: str, venv: VecEnv) -> "VecNormalize": 367 | """ 368 | Loads a saved VecNormalize object. 369 | 370 | :param load_path: the path to load from. 371 | :param venv: the VecEnv to wrap. 372 | :return: 373 | """ 374 | with open(load_path, "rb") as file_handler: 375 | vec_normalize = pickle.load(file_handler) 376 | vec_normalize.set_venv(venv) 377 | return vec_normalize 378 | 379 | def save(self, save_path: str) -> None: 380 | """ 381 | Save current VecNormalize object with 382 | all running statistics and settings (e.g. 
clip_obs) 383 | 384 | :param save_path: The path to save to 385 | """ 386 | with open(save_path, "wb") as file_handler: 387 | pickle.dump(self, file_handler) -------------------------------------------------------------------------------- /envs/batched_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import warnings 3 | from collections import OrderedDict 4 | from copy import deepcopy 5 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union 6 | 7 | import jax.numpy as jnp 8 | import gymnasium as gym 9 | from gymnasium import spaces 10 | import numpy as np 11 | 12 | from envs.base import ( 13 | CloudpickleWrapper, 14 | VecEnv, 15 | VecEnvIndices, 16 | VecEnvObs, 17 | VecEnvStepReturn, 18 | ) 19 | from envs.state import State 20 | from envs.utils import copy_obs_dict, dict_to_obs, obs_space_info 21 | 22 | 23 | 24 | class SequencedBatchedEnv(VecEnv): 25 | """ 26 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 27 | Python process. This is useful for computationally simple environment such as ``Cartpole-v1``, 28 | as the overhead of multiprocess or multithread outweighs the environment computation time. 29 | This can also be used for RL methods that 30 | require a vectorized environment, but that you want a single environments to train with. 31 | 32 | :param env_fns: a list of functions 33 | that return environments to vectorize 34 | :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn. 35 | """ 36 | 37 | actions: np.ndarray 38 | 39 | def __init__(self, env_fns: List[Callable[[], gym.Env]]): 40 | self.envs = [fn() for fn in env_fns] 41 | if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): 42 | raise ValueError( 43 | "You tried to create multiple environments, but the function to create them returned the same instance " 44 | "instead of creating different objects. " 45 | "You are probably using `make_vec_env(lambda: env)` or `SequencedBatchedEnv([lambda: env] * n_envs)`. " 46 | "You should replace `lambda: env` by a `make_env` function that " 47 | "creates a new instance of the environment at every call " 48 | "(using `gym.make()` for instance). You can take a look at the documentation for an example. " 49 | "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information." 
50 | ) 51 | env = self.envs[0] 52 | super().__init__(len(env_fns), env.observation_space, env.action_space) 53 | obs_space = env.observation_space 54 | self.keys, shapes, dtypes = obs_space_info(obs_space) 55 | 56 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys]) 57 | self.buf_dones = np.zeros((self.num_envs,), dtype=bool) 58 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 59 | self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)] 60 | self.metadata = env.metadata 61 | 62 | def step_async(self, actions: np.ndarray) -> None: 63 | self.actions = actions 64 | 65 | def step_wait(self) -> State: # VecEnvStepReturn: 66 | # Avoid circular imports 67 | for env_idx in range(self.num_envs): 68 | obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( 69 | self.actions[env_idx] 70 | ) 71 | # convert to SB3 VecEnv api 72 | self.buf_dones[env_idx] = terminated or truncated 73 | # See https://github.com/openai/gym/issues/3102 74 | # Gym 0.26 introduces a breaking change 75 | # self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated 76 | self.buf_infos[env_idx]['truncation'] = truncated and not terminated 77 | 78 | if self.buf_dones[env_idx]: 79 | # save final observation where user can get it, then reset 80 | self.buf_infos[env_idx]["terminal_observation"] = obs 81 | obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() 82 | self._save_obs(env_idx, obs) 83 | #return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 84 | return State(obs=self._obs_from_buf(), 85 | reward=jnp.copy(self.buf_rews), 86 | done=jnp.copy(self.buf_dones), 87 | # metrics=dict(), 88 | info=deepcopy(self.buf_infos) 89 | ) 90 | 91 | def reset(self) -> State: 92 | for env_idx in range(self.num_envs): 93 | obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx]) 94 | self._save_obs(env_idx, obs) 95 | # Seeds are only used once 96 | self._reset_seeds() 97 | obs = self._obs_from_buf() 98 | return State(obs=obs, 99 | reward=jnp.zeros(self.num_envs), 100 | done=jnp.zeros(self.num_envs), 101 | #metrics=dict() 102 | ) 103 | 104 | def close(self) -> None: 105 | for env in self.envs: 106 | env.close() 107 | 108 | def get_images(self) -> Sequence[Optional[np.ndarray]]: 109 | if self.render_mode != "rgb_array": 110 | warnings.warn( 111 | f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." 112 | ) 113 | return [None for _ in self.envs] 114 | return [env.render() for env in self.envs] # type: ignore[misc] 115 | 116 | def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: 117 | """ 118 | Gym environment rendering. If there are multiple environments then 119 | they are tiled together in one image via ``BaseVecEnv.render()``. 120 | 121 | :param mode: The rendering type. 
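        A minimal usage sketch for this class (the environment id is illustrative); note that
        each list element is its own factory, so every call produces a fresh instance as the
        ``ValueError`` above requires:

            import gymnasium as gym
            import numpy as np

            env_fns = [(lambda: gym.make("CartPole-v1")) for _ in range(4)]  # one fresh env per call
            venv = SequencedBatchedEnv(env_fns)
            state = venv.reset()                            # State with observations stacked over 4 envs
            state = venv.step(np.zeros(4, dtype=np.int64))  # one action per environment
            venv.close()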
122 | """ 123 | return super().render(mode=mode) 124 | 125 | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: 126 | for key in self.keys: 127 | if key is None: 128 | self.buf_obs[key][env_idx] = obs 129 | else: 130 | self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] 131 | 132 | def _obs_from_buf(self) -> VecEnvObs: 133 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 134 | 135 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 136 | """Return attribute from vectorized environment (see base class).""" 137 | target_envs = self._get_target_envs(indices) 138 | return [getattr(env_i, attr_name) for env_i in target_envs] 139 | 140 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 141 | """Set attribute inside vectorized environments (see base class).""" 142 | target_envs = self._get_target_envs(indices) 143 | for env_i in target_envs: 144 | setattr(env_i, attr_name, value) 145 | 146 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 147 | """Call instance methods of vectorized environments.""" 148 | target_envs = self._get_target_envs(indices) 149 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 150 | 151 | def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: 152 | indices = self._get_indices(indices) 153 | return [self.envs[i] for i in indices] 154 | 155 | 156 | 157 | def _worker( 158 | remote: mp.connection.Connection, 159 | parent_remote: mp.connection.Connection, 160 | env_fn_wrapper: CloudpickleWrapper, 161 | ) -> None: 162 | parent_remote.close() 163 | env = env_fn_wrapper.var() 164 | reset_info: Optional[Dict[str, Any]] = {} 165 | while True: 166 | try: 167 | cmd, data = remote.recv() 168 | if cmd == "step": 169 | observation, reward, terminated, truncated, info = env.step(data) 170 | # convert to SB3 VecEnv api 171 | done = terminated or truncated 172 | # info["TimeLimit.truncated"] = truncated and not terminated 173 | info['truncation'] = truncated and not terminated 174 | if done: 175 | # save final observation where user can get it, then reset 176 | info["terminal_observation"] = observation 177 | observation, reset_info = env.reset() 178 | remote.send((observation, reward, done, info, reset_info)) 179 | elif cmd == "reset": 180 | observation, reset_info = env.reset(seed=data) 181 | remote.send((observation, reset_info)) 182 | elif cmd == "render": 183 | remote.send(env.render()) 184 | elif cmd == "close": 185 | env.close() 186 | remote.close() 187 | break 188 | elif cmd == "get_spaces": 189 | remote.send((env.observation_space, env.action_space)) 190 | elif cmd == "env_method": 191 | method = getattr(env, data[0]) 192 | remote.send(method(*data[1], **data[2])) 193 | elif cmd == "get_attr": 194 | remote.send(getattr(env, data)) 195 | elif cmd == "set_attr": 196 | remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] 197 | else: 198 | raise NotImplementedError(f"`{cmd}` is not implemented in the worker") 199 | except EOFError: 200 | break 201 | 202 | 203 | class ParallelBatchedEnv(VecEnv): 204 | """ 205 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own 206 | process, allowing significant speed up when the environment is computationally complex. 
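    Usage mirrors ``SequencedBatchedEnv`` above, except that the default 'forkserver'/'spawn'
    start methods require the main-guard mentioned in the warning below (a sketch; id and
    counts are illustrative):

        import gymnasium as gym
        import numpy as np

        def make_cartpole() -> gym.Env:
            return gym.make("CartPole-v1")   # a new instance per call, created inside each worker

        if __name__ == "__main__":
            venv = ParallelBatchedEnv([make_cartpole for _ in range(8)], start_method="spawn")
            state = venv.reset()
            state = venv.step(np.zeros(8, dtype=np.int64))
            venv.close()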
207 | 208 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the 209 | number of logical cores on your CPU. 210 | 211 | .. warning:: 212 | 213 | Only 'forkserver' and 'spawn' start methods are thread-safe, 214 | which is important when TensorFlow sessions or other non thread-safe 215 | libraries are used in the parent (see issue #217). However, compared to 216 | 'fork' they incur a small start-up cost and have restrictions on 217 | global variables. With those methods, users must wrap the code in an 218 | ``if __name__ == "__main__":`` block. 219 | For more information, see the multiprocessing documentation. 220 | 221 | :param env_fns: Environments to run in subprocesses 222 | :param start_method: method used to start the subprocesses. 223 | Must be one of the methods returned by multiprocessing.get_all_start_methods(). 224 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 225 | """ 226 | 227 | def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None): 228 | self.waiting = False 229 | self.closed = False 230 | n_envs = len(env_fns) 231 | 232 | if start_method is None: 233 | # Fork is not a thread safe method (see issue #217) 234 | # but is more user friendly (does not require to wrap the code in 235 | # a `if __name__ == "__main__":`) 236 | forkserver_available = "forkserver" in mp.get_all_start_methods() 237 | start_method = "forkserver" if forkserver_available else "spawn" 238 | ctx = mp.get_context(start_method) 239 | 240 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) 241 | self.processes = [] 242 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): 243 | args = (work_remote, remote, CloudpickleWrapper(env_fn)) 244 | # daemon=True: if the main process crashes, we should not cause things to hang 245 | process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined] 246 | process.start() 247 | self.processes.append(process) 248 | work_remote.close() 249 | 250 | self.remotes[0].send(("get_spaces", None)) 251 | observation_space, action_space = self.remotes[0].recv() 252 | 253 | super().__init__(len(env_fns), observation_space, action_space) 254 | 255 | def step_async(self, actions: np.ndarray) -> None: 256 | for remote, action in zip(self.remotes, actions): 257 | remote.send(("step", action)) 258 | self.waiting = True 259 | 260 | def step_wait(self) -> State: 261 | results = [remote.recv() for remote in self.remotes] 262 | self.waiting = False 263 | obs, rews, dones, infos, self.reset_infos = zip(*results) 264 | # return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos 265 | return State(obs=_flatten_obs(obs, self.observation_space), 266 | reward=jnp.stack(rews), 267 | done=jnp.stack(dones), 268 | info=infos) 269 | 270 | def reset(self) -> State: 271 | for env_idx, remote in enumerate(self.remotes): 272 | remote.send(("reset", self._seeds[env_idx])) 273 | results = [remote.recv() for remote in self.remotes] 274 | obs, self.reset_infos = zip(*results) 275 | # Seeds are only used once 276 | self._reset_seeds() 277 | # return _flatten_obs(obs, self.observation_space) 278 | return State(obs=_flatten_obs(obs, self.observation_space), 279 | reward=jnp.zeros(self.num_envs), 280 | done=jnp.zeros(self.num_envs), 281 | ) 282 | 283 | def close(self) -> None: 284 | if self.closed: 285 | return 286 | if self.waiting: 287 | for remote in self.remotes: 288 | 
remote.recv() 289 | for remote in self.remotes: 290 | remote.send(("close", None)) 291 | for process in self.processes: 292 | process.join() 293 | self.closed = True 294 | 295 | def get_images(self) -> Sequence[Optional[np.ndarray]]: 296 | if self.render_mode != "rgb_array": 297 | warnings.warn( 298 | f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." 299 | ) 300 | return [None for _ in self.remotes] 301 | for pipe in self.remotes: 302 | # gather render return from subprocesses 303 | pipe.send(("render", None)) 304 | outputs = [pipe.recv() for pipe in self.remotes] 305 | return outputs 306 | 307 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 308 | """Return attribute from vectorized environment (see base class).""" 309 | target_remotes = self._get_target_remotes(indices) 310 | for remote in target_remotes: 311 | remote.send(("get_attr", attr_name)) 312 | return [remote.recv() for remote in target_remotes] 313 | 314 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 315 | """Set attribute inside vectorized environments (see base class).""" 316 | target_remotes = self._get_target_remotes(indices) 317 | for remote in target_remotes: 318 | remote.send(("set_attr", (attr_name, value))) 319 | for remote in target_remotes: 320 | remote.recv() 321 | 322 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 323 | """Call instance methods of vectorized environments.""" 324 | target_remotes = self._get_target_remotes(indices) 325 | for remote in target_remotes: 326 | remote.send(("env_method", (method_name, method_args, method_kwargs))) 327 | return [remote.recv() for remote in target_remotes] 328 | 329 | def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: 330 | """ 331 | Get the connection object needed to communicate with the wanted 332 | envs that are in subprocesses. 333 | 334 | :param indices: refers to indices of envs. 335 | :return: Connection object to communicate between processes. 336 | """ 337 | indices = self._get_indices(indices) 338 | return [self.remotes[i] for i in indices] 339 | 340 | 341 | def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: 342 | """ 343 | Flatten observations, depending on the observation space. 344 | 345 | :param obs: observations. 346 | A list or tuple of observations, one per environment. 347 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. 348 | :return: flattened observations. 349 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. 350 | Each NumPy array has the environment index as its first axis. 
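    An illustrative check for the simplest case, a ``Box`` space (values are made up); the
    ``Dict`` and ``Tuple`` branches stack each component in the same way:

        import numpy as np
        from gymnasium import spaces

        space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
        per_env_obs = [np.zeros(3), np.ones(3)]     # one observation per environment
        flat = _flatten_obs(per_env_obs, space)
        assert flat.shape == (2, 3)                 # environment index is the first axis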
351 | """ 352 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" 353 | assert len(obs) > 0, "need observations from at least one environment" 354 | 355 | if isinstance(space, spaces.Dict): 356 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" 357 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" 358 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) 359 | elif isinstance(space, spaces.Tuple): 360 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" 361 | obs_len = len(space.spaces) 362 | return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] 363 | else: 364 | return np.stack(obs) # type: ignore[arg-type] -------------------------------------------------------------------------------- /envs/base.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union 5 | 6 | import cloudpickle 7 | import gymnasium as gym 8 | import numpy as np 9 | from gymnasium import spaces 10 | 11 | from envs.state import State 12 | 13 | VecEnvIndices = Union[None, int, Iterable[int]] 14 | VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] 15 | VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] 16 | 17 | 18 | def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover 19 | """ 20 | Tile N images into one big PxQ image 21 | (P,Q) are chosen to be as close as possible, and if N 22 | is square, then P=Q. 23 | 24 | :param images_nhwc: list or array of images, ndim=4 once turned into array. 25 | n = batch index, h = height, w = width, c = channel 26 | :return: img_HWc, ndim=3 27 | """ 28 | img_nhwc = np.asarray(images_nhwc) 29 | n_images, height, width, n_channels = img_nhwc.shape 30 | # new_height was named H before 31 | new_height = int(np.ceil(np.sqrt(n_images))) 32 | # new_width was named W before 33 | new_width = int(np.ceil(float(n_images) / new_height)) 34 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) 35 | # img_HWhwc 36 | out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels)) 37 | # img_HhWwc 38 | out_image = out_image.transpose(0, 2, 1, 3, 4) 39 | # img_Hh_Ww_c 40 | out_image = out_image.reshape((new_height * height, new_width * width, n_channels)) 41 | return out_image 42 | 43 | 44 | class VecEnv(ABC): 45 | """ 46 | An abstract asynchronous, vectorized environment. 
47 | 48 | :param num_envs: Number of environments 49 | :param observation_space: Observation space 50 | :param action_space: Action space 51 | """ 52 | 53 | def __init__( 54 | self, 55 | num_envs: int, 56 | observation_space: spaces.Space, 57 | action_space: spaces.Space, 58 | ): 59 | self.num_envs = num_envs 60 | self.observation_space = observation_space 61 | self.action_space = action_space 62 | # store info returned by the reset method 63 | self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] 64 | # seeds to be used in the next call to env.reset() 65 | self._seeds: List[Optional[int]] = [None for _ in range(num_envs)] 66 | 67 | try: 68 | render_modes = self.get_attr("render_mode") 69 | except AttributeError: 70 | warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.") 71 | render_modes = [None for _ in range(num_envs)] 72 | 73 | assert all( 74 | render_mode == render_modes[0] for render_mode in render_modes 75 | ), "render_mode mode should be the same for all environments" 76 | self.render_mode = render_modes[0] 77 | 78 | render_modes = [] 79 | if self.render_mode is not None: 80 | if self.render_mode == "rgb_array": 81 | # SB3 uses OpenCV for the "human" mode 82 | render_modes = ["human", "rgb_array"] 83 | else: 84 | render_modes = [self.render_mode] 85 | 86 | self.metadata = {"render_modes": render_modes} 87 | 88 | def _reset_seeds(self) -> None: 89 | """ 90 | Reset the seeds that are going to be used at the next reset. 91 | """ 92 | self._seeds = [None for _ in range(self.num_envs)] 93 | 94 | @abstractmethod 95 | def reset(self) -> State: 96 | """ 97 | Reset all the environments and return an array of 98 | observations, or a tuple of observation arrays. 99 | 100 | If step_async is still doing work, that work will 101 | be cancelled and step_wait() should not be called 102 | until step_async() is invoked again. 103 | 104 | :return: observation 105 | """ 106 | raise NotImplementedError() 107 | 108 | @abstractmethod 109 | def step_async(self, actions: np.ndarray) -> None: 110 | """ 111 | Tell all the environments to start taking a step 112 | with the given actions. 113 | Call step_wait() to get the results of the step. 114 | 115 | You should not call this if a step_async run is 116 | already pending. 117 | """ 118 | raise NotImplementedError() 119 | 120 | @abstractmethod 121 | def step_wait(self) -> State: 122 | """ 123 | Wait for the step taken with step_async(). 124 | 125 | :return: observation, reward, done, information 126 | """ 127 | raise NotImplementedError() 128 | 129 | @abstractmethod 130 | def close(self) -> None: 131 | """ 132 | Clean up the environment's resources. 133 | """ 134 | raise NotImplementedError() 135 | 136 | @abstractmethod 137 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 138 | """ 139 | Return attribute from vectorized environment. 140 | 141 | :param attr_name: The name of the attribute whose value to return 142 | :param indices: Indices of envs to get attribute from 143 | :return: List of values of 'attr_name' in all environments 144 | """ 145 | raise NotImplementedError() 146 | 147 | @abstractmethod 148 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 149 | """ 150 | Set attribute inside vectorized environments. 
151 | 152 | :param attr_name: The name of attribute to assign new value 153 | :param value: Value to assign to `attr_name` 154 | :param indices: Indices of envs to assign value 155 | :return: 156 | """ 157 | raise NotImplementedError() 158 | 159 | @abstractmethod 160 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 161 | """ 162 | Call instance methods of vectorized environments. 163 | 164 | :param method_name: The name of the environment method to invoke. 165 | :param indices: Indices of envs whose method to call 166 | :param method_args: Any positional arguments to provide in the call 167 | :param method_kwargs: Any keyword arguments to provide in the call 168 | :return: List of items returned by the environment's method call 169 | """ 170 | raise NotImplementedError() 171 | 172 | def step(self, actions: np.ndarray) -> VecEnvStepReturn: 173 | """ 174 | Step the environments with the given action 175 | 176 | :param actions: the action 177 | :return: observation, reward, done, information 178 | """ 179 | self.step_async(actions) 180 | return self.step_wait() 181 | 182 | def get_images(self) -> Sequence[Optional[np.ndarray]]: 183 | """ 184 | Return RGB images from each environment when available 185 | """ 186 | raise NotImplementedError 187 | 188 | def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: 189 | """ 190 | Gym environment rendering 191 | 192 | :param mode: the rendering type 193 | """ 194 | 195 | if mode == "human" and self.render_mode != mode: 196 | # Special case, if the render_mode="rgb_array" 197 | # we can still display that image using opencv 198 | if self.render_mode != "rgb_array": 199 | warnings.warn( 200 | f"You tried to render a VecEnv with mode='{mode}' " 201 | "but the render mode defined when initializing the environment must be " 202 | f"'human' or 'rgb_array', not '{self.render_mode}'." 203 | ) 204 | return None 205 | 206 | elif mode and self.render_mode != mode: 207 | warnings.warn( 208 | f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment. 
209 | We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) 210 | has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" 211 | ) 212 | return None 213 | 214 | mode = mode or self.render_mode 215 | 216 | if mode is None: 217 | warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.") 218 | return None 219 | 220 | # mode == self.render_mode == "human" 221 | # In that case, we try to call `self.env.render()` but it might 222 | # crash for subprocesses 223 | if self.render_mode == "human": 224 | self.env_method("render") 225 | return None 226 | 227 | if mode == "rgb_array" or mode == "human": 228 | # call the render method of the environments 229 | images = self.get_images() 230 | # Create a big image by tiling images from subprocesses 231 | bigimg = tile_images(images) # type: ignore[arg-type] 232 | 233 | if mode == "human": 234 | # Display it using OpenCV 235 | import cv2 # pytype:disable=import-error 236 | 237 | cv2.imshow("vecenv", bigimg[:, :, ::-1]) 238 | cv2.waitKey(1) 239 | else: 240 | return bigimg 241 | 242 | else: 243 | # Other render modes: 244 | # In that case, we try to call `self.env.render()` but it might 245 | # crash for subprocesses 246 | # and we don't return the values 247 | self.env_method("render") 248 | return None 249 | 250 | def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: 251 | """ 252 | Sets the random seeds for all environments, based on a given seed. 253 | Each individual environment will still get its own seed, by incrementing the given seed. 254 | WARNING: since gym 0.26, those seeds will only be passed to the environment 255 | at the next reset. 256 | 257 | :param seed: The random seed. May be None for completely random seeding. 258 | :return: Returns a list containing the seeds for each individual env. 259 | Note that all list elements may be None, if the env does not return anything when being seeded. 260 | """ 261 | if seed is None: 262 | # To ensure that subprocesses have different seeds, 263 | # we still populate the seed variable when no argument is passed 264 | seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32)) 265 | 266 | self._seeds = [seed + idx for idx in range(self.num_envs)] 267 | return self._seeds 268 | 269 | @property 270 | def unwrapped(self) -> "VecEnv": 271 | if isinstance(self, VecEnvWrapper): 272 | return self.venv.unwrapped 273 | else: 274 | return self 275 | 276 | def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: 277 | """Check if an attribute reference is being hidden in a recursive call to __getattr__ 278 | 279 | :param name: name of attribute to check for 280 | :param already_found: whether this attribute has already been found in a wrapper 281 | :return: name of module whose attribute is being shadowed, if any. 282 | """ 283 | if hasattr(self, name) and already_found: 284 | return f"{type(self).__module__}.{type(self).__name__}" 285 | else: 286 | return None 287 | 288 | def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: 289 | """ 290 | Convert a flexibly-typed reference to environment indices to an implied list of indices. 291 | 292 | :param indices: refers to indices of envs. 293 | :return: the implied list of indices. 
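        The mapping is easiest to see in isolation; this stand-alone restatement of the
        branch logic (not part of the class) shows the three accepted forms:

            from typing import Iterable, Union

            def get_indices(indices: Union[None, int, Iterable[int]], num_envs: int) -> Iterable[int]:
                # mirrors VecEnv._get_indices
                if indices is None:
                    return range(num_envs)
                if isinstance(indices, int):
                    return [indices]
                return indices

            assert list(get_indices(None, 4)) == [0, 1, 2, 3]
            assert list(get_indices(2, 4)) == [2]
            assert list(get_indices([0, 3], 4)) == [0, 3]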
294 | """ 295 | if indices is None: 296 | indices = range(self.num_envs) 297 | elif isinstance(indices, int): 298 | indices = [indices] 299 | return indices 300 | 301 | 302 | class VecEnvWrapper(VecEnv): 303 | """ 304 | Vectorized environment base class 305 | 306 | :param venv: the vectorized environment to wrap 307 | :param observation_space: the observation space (can be None to load from venv) 308 | :param action_space: the action space (can be None to load from venv) 309 | """ 310 | 311 | def __init__( 312 | self, 313 | venv: VecEnv, 314 | observation_space: Optional[spaces.Space] = None, 315 | action_space: Optional[spaces.Space] = None, 316 | ): 317 | self.venv = venv 318 | 319 | super().__init__( 320 | num_envs=venv.num_envs, 321 | observation_space=observation_space or venv.observation_space, 322 | action_space=action_space or venv.action_space, 323 | ) 324 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 325 | 326 | def step_async(self, actions: np.ndarray) -> None: 327 | self.venv.step_async(actions) 328 | 329 | @abstractmethod 330 | def reset(self) -> State: 331 | pass 332 | 333 | @abstractmethod 334 | def step_wait(self) -> State: 335 | pass 336 | 337 | def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: 338 | return self.venv.seed(seed) 339 | 340 | def close(self) -> None: 341 | return self.venv.close() 342 | 343 | def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: 344 | return self.venv.render(mode=mode) 345 | 346 | def get_images(self) -> Sequence[Optional[np.ndarray]]: 347 | return self.venv.get_images() 348 | 349 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: 350 | return self.venv.get_attr(attr_name, indices) 351 | 352 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: 353 | return self.venv.set_attr(attr_name, value, indices) 354 | 355 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: 356 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) 357 | 358 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: 359 | return self.venv.env_is_wrapped(wrapper_class, indices=indices) 360 | 361 | def __getattr__(self, name: str) -> Any: 362 | """Find attribute from wrapped venv(s) if this wrapper does not have it. 363 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers 364 | which have unique attributes of interest. 365 | """ 366 | blocked_class = self.getattr_depth_check(name, already_found=False) 367 | if blocked_class is not None: 368 | own_class = f"{type(self).__module__}.{type(self).__name__}" 369 | error_str = ( 370 | f"Error: Recursive attribute lookup for {name} from {own_class} is " 371 | f"ambiguous and hides attribute from {blocked_class}" 372 | ) 373 | raise AttributeError(error_str) 374 | 375 | return self.getattr_recursive(name) 376 | 377 | def _get_all_attributes(self) -> Dict[str, Any]: 378 | """Get all (inherited) instance and class attributes 379 | 380 | :return: all_attributes 381 | """ 382 | all_attributes = self.__dict__.copy() 383 | all_attributes.update(self.class_attributes) 384 | return all_attributes 385 | 386 | def getattr_recursive(self, name: str) -> Any: 387 | """Recursively check wrappers to find attribute. 
388 | 389 | :param name: name of attribute to look for 390 | :return: attribute 391 | """ 392 | all_attributes = self._get_all_attributes() 393 | if name in all_attributes: # attribute is present in this wrapper 394 | attr = getattr(self, name) 395 | elif hasattr(self.venv, "getattr_recursive"): 396 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr 397 | # to avoid a duplicate call to getattr_depth_check. 398 | attr = self.venv.getattr_recursive(name) 399 | else: # attribute not present, child is an unwrapped VecEnv 400 | attr = getattr(self.venv, name) 401 | 402 | return attr 403 | 404 | def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: 405 | """See base class. 406 | 407 | :return: name of module whose attribute is being shadowed, if any. 408 | """ 409 | all_attributes = self._get_all_attributes() 410 | if name in all_attributes and already_found: 411 | # this venv's attribute is being hidden because of a higher venv. 412 | shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}" 413 | elif name in all_attributes and not already_found: 414 | # we have found the first reference to the attribute. Now check for duplicates. 415 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) 416 | else: 417 | # this wrapper does not have the attribute. Keep searching. 418 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) 419 | 420 | return shadowed_wrapper_class 421 | 422 | 423 | class CloudpickleWrapper: 424 | """ 425 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 426 | 427 | :param var: the variable you wish to wrap for pickling with cloudpickle 428 | """ 429 | 430 | def __init__(self, var: Any): 431 | self.var = var 432 | 433 | def __getstate__(self) -> Any: 434 | return cloudpickle.dumps(self.var) 435 | 436 | def __setstate__(self, var: Any) -> None: 437 | self.var = cloudpickle.loads(var) -------------------------------------------------------------------------------- /reinforce.py: -------------------------------------------------------------------------------- 1 | """REINFORCE.""" 2 | 3 | from absl import logging, app 4 | from functools import partial 5 | import os 6 | import pickle 7 | import random 8 | import time 9 | from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union 10 | 11 | # os.environ[ 12 | # "XLA_PYTHON_CLIENT_MEM_FRACTION" 13 | # ] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991 14 | 15 | import flax 16 | import flax.linen as nn 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | import optax 21 | 22 | from envs import make_env, Transition, has_discrete_action_space, is_atari_env 23 | from networks.policy import Policy 24 | from networks.networks import FeedForwardNetwork, ActivationFn, make_policy_network, make_value_network, make_atari_feature_extractor 25 | from networks.distributions import NormalTanhDistribution, ParametricDistribution, PolicyNormalDistribution, DiscreteDistribution 26 | 27 | class Config: 28 | # experiment 29 | experiment_name = 'reinforce_main_det_1' 30 | seed = 10 31 | platform = 'cpu' # CPU or GPU 32 | capture_video = False # Not implemented 33 | write_logs_to_file = False 34 | save_model = False 35 | 36 | # environment 37 | env_id = 'HalfCheetah-v4' 38 | num_envs = 1 # DO NOT CHANGE 39 | parallel_envs = False 40 | clip_actions = False 41 | normalize_observations = True 42 | normalize_rewards = 
True 43 | clip_observations = 10. 44 | clip_rewards = 10. 45 | eval_env = True 46 | num_eval_episodes = 10 47 | eval_every = 20 48 | deterministic_eval = True 49 | 50 | # algorithm hyperparameters 51 | total_timesteps = int(1e6) * 8 52 | learning_rate = 3e-4 53 | unroll_length = 2048 54 | anneal_lr = True 55 | gamma = 0.99 56 | batch_size = 1 57 | num_minibatches = 1 58 | update_epochs = 1 59 | entropy_cost = 0.00 60 | max_grad_norm = 0.5 61 | reward_scaling = 1. 62 | 63 | # policy params 64 | policy_hidden_layer_sizes: Sequence[int] = (32,) * 4 65 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5 66 | activation: ActivationFn = nn.swish 67 | squash_distribution: bool = True 68 | 69 | # atari params 70 | atari_dense_layer_sizes: Sequence[int] = (512,) 71 | 72 | 73 | Metrics = Mapping[str, jnp.ndarray] 74 | 75 | _PMAP_AXIS_NAME = 'i' 76 | 77 | 78 | def _unpmap(v): 79 | return jax.tree_util.tree_map(lambda x: x[0], v) 80 | 81 | 82 | def _strip_weak_type(tree): 83 | # in order to avoid extra jit recompilations we strip all weak types from user input 84 | def f(leaf): 85 | leaf = jnp.asarray(leaf) 86 | return leaf.astype(leaf.dtype) 87 | return jax.tree_util.tree_map(f, tree) 88 | 89 | 90 | @flax.struct.dataclass 91 | class NetworkParams: 92 | """Contains training state for the learner.""" 93 | policy: Any 94 | value: Any 95 | 96 | 97 | @flax.struct.dataclass 98 | class Networks: 99 | policy_network: FeedForwardNetwork 100 | value_network: FeedForwardNetwork 101 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 102 | 103 | 104 | @flax.struct.dataclass 105 | class AtariNetworkParams: 106 | """Contains training state for the learner.""" 107 | feature_extractor: Any 108 | policy: Any 109 | value: Any 110 | 111 | 112 | @flax.struct.dataclass 113 | class AtariNetworks: 114 | feature_extractor: FeedForwardNetwork 115 | policy_network: FeedForwardNetwork 116 | value_network: FeedForwardNetwork 117 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 118 | 119 | 120 | @flax.struct.dataclass 121 | class TrainingState: 122 | """Contains training state for the learner.""" 123 | optimizer_state: optax.OptState 124 | params: Union[NetworkParams, AtariNetworkParams] 125 | env_steps: jnp.ndarray 126 | 127 | 128 | def make_inference_fn(agent_networks: Union[Networks, AtariNetworks]): 129 | """Creates params and inference function for the agent.""" 130 | 131 | def make_policy(params: Any, 132 | deterministic: bool = False) -> Policy: 133 | policy_network = agent_networks.policy_network 134 | parametric_action_distribution = agent_networks.parametric_action_distribution 135 | 136 | @jax.jit 137 | def policy(observations: jnp.ndarray, 138 | key_sample: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, Any]]: 139 | logits = policy_network.apply(params, observations) 140 | if deterministic: 141 | return agent_networks.parametric_action_distribution.mode(logits), {} 142 | raw_actions = parametric_action_distribution.sample_no_postprocessing( 143 | logits, key_sample) 144 | log_prob = parametric_action_distribution.log_prob(logits, raw_actions) 145 | postprocessed_actions = parametric_action_distribution.postprocess( 146 | raw_actions) 147 | return postprocessed_actions, { 148 | 'log_prob': log_prob, 149 | 'raw_action': raw_actions 150 | } 151 | 152 | return policy 153 | 154 | return make_policy 155 | 156 | 157 | def make_feature_extraction_fn(agent_networks: AtariNetworks): 158 | """Creates feature extractor for inference.""" 159 | 160 | def 
make_feature_extractor(params: Any): 161 | shared_feature_extractor = agent_networks.feature_extractor 162 | 163 | @jax.jit 164 | def feature_extractor(observations: jnp.ndarray) -> jnp.ndarray: 165 | return shared_feature_extractor.apply(params, observations) 166 | 167 | return feature_extractor 168 | 169 | return make_feature_extractor 170 | 171 | 172 | def make_networks( 173 | observation_size: int, 174 | action_size: int, 175 | policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, 176 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5, 177 | activation: ActivationFn = nn.swish, 178 | sqash_distribution: bool = True, 179 | discrete_policy: bool = False, 180 | shared_feature_extractor: bool = False, 181 | feature_extractor_dense_hidden_layer_sizes: Optional[Sequence[int]] = (512,), 182 | ) -> Networks: 183 | """Make REINFORCE networks with preprocessor.""" 184 | if discrete_policy: 185 | parametric_action_distribution = DiscreteDistribution( 186 | param_size=action_size) 187 | elif sqash_distribution: 188 | parametric_action_distribution = NormalTanhDistribution( 189 | event_size=action_size) 190 | else: 191 | parametric_action_distribution = PolicyNormalDistribution( 192 | event_size=action_size) 193 | if shared_feature_extractor: 194 | feature_extractor = make_atari_feature_extractor( 195 | obs_size=observation_size, 196 | hidden_layer_sizes=feature_extractor_dense_hidden_layer_sizes, 197 | activation=nn.relu 198 | ) 199 | policy_network = make_policy_network( 200 | parametric_action_distribution.param_size, 201 | feature_extractor_dense_hidden_layer_sizes[-1], 202 | hidden_layer_sizes=(), 203 | activation=activation) 204 | value_network = make_value_network( 205 | feature_extractor_dense_hidden_layer_sizes[-1], 206 | hidden_layer_sizes=(), 207 | activation=activation) 208 | return AtariNetworks( 209 | feature_extractor=feature_extractor, 210 | policy_network=policy_network, 211 | value_network=value_network, 212 | parametric_action_distribution=parametric_action_distribution) 213 | policy_network = make_policy_network( 214 | parametric_action_distribution.param_size, 215 | observation_size, 216 | hidden_layer_sizes=policy_hidden_layer_sizes, 217 | activation=activation) 218 | value_network = make_value_network( 219 | observation_size, 220 | hidden_layer_sizes=value_hidden_layer_sizes, 221 | activation=activation) 222 | 223 | return Networks( 224 | policy_network=policy_network, 225 | value_network=value_network, 226 | parametric_action_distribution=parametric_action_distribution) 227 | 228 | 229 | def compute_returns(truncation: jnp.ndarray, 230 | termination: jnp.ndarray, 231 | rewards: jnp.ndarray, 232 | discount: float = 0.99): 233 | """Calculates the returns. 234 | 235 | Args: 236 | truncation: A float32 tensor of shape [T, B] with truncation signal. 237 | termination: A float32 tensor of shape [T, B] with termination signal. 238 | rewards: A float32 tensor of shape [T, B] containing rewards generated by 239 | following the behaviour policy. 240 | discount: TD discount. 241 | 242 | Returns: 243 | A float32 tensor of shape [T, B]. Can be used as target to 244 | train a baseline (V(x_t) - vs_t)^2. 245 | A float32 tensor of shape [T, B] of advantages. 
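  A tiny worked example (values chosen by hand) of what the scan below computes, with
  discount 0.5 and neither truncation nor termination:

      import jax.numpy as jnp

      rewards = jnp.ones((3, 1))                 # T=3, B=1
      zeros = jnp.zeros_like(rewards)
      returns = compute_returns(truncation=zeros, termination=zeros,
                                rewards=rewards, discount=0.5)
      # G_t = r_t + 0.5 * G_{t+1}  ->  [1.75, 1.5, 1.0]
      assert jnp.allclose(returns[:, 0], jnp.array([1.75, 1.5, 1.0]))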
246 | """ 247 | 248 | truncation_mask = 1 - truncation 249 | acc = jnp.zeros_like(truncation_mask[0]) 250 | returns = [] 251 | 252 | def compute_vs_minus_v_xs(carry, target_t): 253 | _, acc = carry 254 | truncation_mask, reward, termination = target_t 255 | acc = reward + discount * (1 - termination) * truncation_mask * acc 256 | return (_, acc), (acc) 257 | 258 | (_, _), (returns) = jax.lax.scan( 259 | compute_vs_minus_v_xs, (None, acc), 260 | (truncation_mask, rewards, termination), 261 | length=int(truncation_mask.shape[0]), 262 | reverse=True) 263 | return jax.lax.stop_gradient(returns) 264 | 265 | 266 | 267 | def compute_reinforce_loss( 268 | params: Union[NetworkParams, AtariNetworkParams], 269 | data: Transition, 270 | rng: jnp.ndarray, 271 | network: Union[Networks, AtariNetworks], 272 | entropy_cost: float = 1e-4, 273 | discounting: float = 0.9, 274 | reward_scaling: float = 1.0, 275 | shared_feature_extractor: bool = False, 276 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 277 | """Computes REINFORCE loss. 278 | 279 | Policy loss: $L_\pi = G \log \pi_\theta (a \mid s)$ 280 | 281 | Args: 282 | params: Network parameters, 283 | data: Transition that with leading dimension [B, T]. extra fields required 284 | are ['state_extras']['truncation'] ['policy_extras']['raw_action'] 285 | ['policy_extras']['log_prob'] 286 | rng: Random key 287 | network: Agent networks. 288 | entropy_cost: entropy cost. 289 | discounting: discounting, 290 | reward_scaling: reward multiplier. 291 | shared_feature_extractor: Whether networks use a shared feature extractor. 292 | 293 | Returns: 294 | A tuple (loss, metrics) 295 | """ 296 | parametric_action_distribution = network.parametric_action_distribution 297 | 298 | policy_apply = network.policy_network.apply 299 | 300 | # Put the time dimension first. 
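  # (Shape note, illustrative: a batch collected with leading dims [B, T] becomes [T, B],
  #  matching the [T, B] convention assumed by compute_returns above.)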
301 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) 302 | 303 | hidden = data.observation 304 | if shared_feature_extractor: 305 | feature_extractor_apply = network.feature_extractor.apply 306 | hidden = feature_extractor_apply(params.feature_extractor, data.observation) 307 | 308 | policy_logits = policy_apply(params.policy, 309 | hidden) 310 | 311 | rewards = data.reward * reward_scaling 312 | truncation = data.extras['state_extras']['truncation'] 313 | termination = (1 - data.discount) * (1 - truncation) 314 | 315 | target_action_log_probs = parametric_action_distribution.log_prob( 316 | policy_logits, data.extras['policy_extras']['raw_action']) 317 | behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] 318 | 319 | returns = compute_returns( 320 | truncation=truncation, 321 | termination=termination, 322 | rewards=rewards, 323 | discount=discounting) 324 | 325 | log_ratio = target_action_log_probs - behaviour_action_log_probs 326 | rho_s = jnp.exp(log_ratio) 327 | 328 | policy_loss = -jnp.mean(target_action_log_probs * returns) 329 | approx_kl = ((rho_s - 1) - log_ratio).mean() 330 | 331 | # Entropy reward 332 | entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) 333 | entropy_loss = entropy_cost * -entropy 334 | 335 | total_loss = policy_loss + entropy_loss 336 | 337 | metrics = { 338 | 'total_loss': total_loss, 339 | 'policy_loss': policy_loss, 340 | 'entropy_loss': entropy_loss, 341 | 'entropy': entropy, 342 | 'approx_kl': jax.lax.stop_gradient(approx_kl), 343 | } 344 | 345 | return total_loss, metrics 346 | 347 | 348 | 349 | 350 | def main(_): 351 | run_name = f"Exp_{Config.experiment_name}__{Config.env_id}__{Config.seed}__{int(time.time())}" 352 | 353 | if Config.write_logs_to_file: 354 | from absl import flags 355 | flags.FLAGS.alsologtostderr = True 356 | log_path = f'./training_logs/reinforce/{run_name}' 357 | if not os.path.exists(log_path): 358 | os.makedirs(log_path) 359 | logging.get_absl_handler().use_absl_log_file('logs', log_path) 360 | 361 | logging.get_absl_handler().setFormatter(None) 362 | 363 | # jax set up devices 364 | process_count = jax.process_count() 365 | process_id = jax.process_index() 366 | local_device_count = jax.local_device_count() 367 | local_devices_to_use = local_device_count 368 | device_count = local_devices_to_use * process_count 369 | assert Config.num_envs % device_count == 0 370 | 371 | 372 | assert Config.batch_size * Config.num_minibatches % Config.num_envs == 0 373 | # The number of environment steps executed for every training step. 374 | env_step_per_training_step = ( 375 | Config.batch_size * Config.unroll_length * Config.num_minibatches) 376 | 377 | # log hyperparameters 378 | logging.info("|param: value|") 379 | for key, value in vars(Config).items(): 380 | if not key.startswith('__'): 381 | logging.info(f"|{key}: {value}|") 382 | 383 | random.seed(Config.seed) 384 | np.random.seed(Config.seed) 385 | # handle / split random keys 386 | key = jax.random.PRNGKey(Config.seed) 387 | global_key, local_key = jax.random.split(key) 388 | del key 389 | local_key = jax.random.fold_in(local_key, process_id) 390 | local_key, key_envs, eval_key = jax.random.split(local_key, 3) 391 | # key_networks should be global, so that networks are initialized the same 392 | # way for different processes. 
393 | key_policy, key_value, key_feature_extractor = jax.random.split(global_key, 3) 394 | del global_key 395 | 396 | is_atari = is_atari_env(Config.env_id) 397 | envs = make_env( 398 | env_id=Config.env_id, 399 | num_envs=Config.num_envs, 400 | parallel=Config.parallel_envs, 401 | clip_actions=Config.clip_actions, 402 | norm_obs=Config.normalize_observations, 403 | norm_reward=Config.normalize_rewards, 404 | clip_obs=Config.clip_observations, 405 | clip_rewards=Config.clip_rewards, 406 | is_atari=is_atari, 407 | ) 408 | 409 | discrete_action_space = has_discrete_action_space(envs) 410 | envs.seed(int(key_envs[0])) 411 | env_state = envs.reset() 412 | 413 | 414 | if discrete_action_space: 415 | action_size = envs.action_space.n 416 | else: 417 | action_size = np.prod(envs.action_space.shape) # flatten action size for nested spaces 418 | if is_atari: 419 | observation_shape = env_state.obs.shape[-3:] 420 | else: 421 | observation_shape = env_state.obs.shape[-1] 422 | 423 | network = make_networks( 424 | observation_size=observation_shape, # NOTE only works with flattened observation space 425 | action_size=action_size, # flatten action size for nested spaces 426 | policy_hidden_layer_sizes=Config.policy_hidden_layer_sizes, 427 | value_hidden_layer_sizes=Config.value_hidden_layer_sizes, 428 | activation=Config.activation, 429 | sqash_distribution=Config.squash_distribution, 430 | discrete_policy=discrete_action_space, 431 | shared_feature_extractor=is_atari, 432 | feature_extractor_dense_hidden_layer_sizes=Config.atari_dense_layer_sizes, 433 | ) 434 | make_policy = make_inference_fn(network) 435 | if is_atari: 436 | make_feature_extractor = make_feature_extraction_fn(network) 437 | 438 | # create optimizer 439 | if Config.anneal_lr: 440 | learning_rate = optax.linear_schedule( 441 | Config.learning_rate, 442 | Config.learning_rate * 0.01, # 0 443 | transition_steps=Config.total_timesteps, 444 | ) 445 | else: 446 | learning_rate = Config.learning_rate 447 | optimizer = optax.chain( 448 | optax.clip_by_global_norm(Config.max_grad_norm), 449 | optax.adam(learning_rate), 450 | ) 451 | 452 | # create loss function via functools.partial 453 | loss_fn = partial( 454 | compute_reinforce_loss, 455 | network=network, 456 | entropy_cost=Config.entropy_cost, 457 | discounting=Config.gamma, 458 | reward_scaling=Config.reward_scaling, 459 | shared_feature_extractor=is_atari, 460 | ) 461 | 462 | 463 | def loss_and_pgrad(loss_fn: Callable[..., float], 464 | pmap_axis_name: Optional[str], 465 | has_aux: bool = False): 466 | g = jax.value_and_grad(loss_fn, has_aux=has_aux) 467 | 468 | def h(*args, **kwargs): 469 | value, grad = g(*args, **kwargs) 470 | return value, jax.lax.pmean(grad, axis_name=pmap_axis_name) 471 | 472 | return g if pmap_axis_name is None else h 473 | 474 | 475 | def gradient_update_fn(loss_fn: Callable[..., float], 476 | optimizer: optax.GradientTransformation, 477 | pmap_axis_name: Optional[str], 478 | has_aux: bool = False): 479 | """Wrapper of the loss function that apply gradient updates. 480 | 481 | Args: 482 | loss_fn: The loss function. 483 | optimizer: The optimizer to apply gradients. 484 | pmap_axis_name: If relevant, the name of the pmap axis to synchronize 485 | gradients. 486 | has_aux: Whether the loss_fn has auxiliary data. 487 | 488 | Returns: 489 | A function that takes the same argument as the loss function plus the 490 | optimizer state. The output of this function is the loss, the new parameter, 491 | and the new optimizer state. 
492 | """ 493 | loss_and_pgrad_fn = loss_and_pgrad( 494 | loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) 495 | 496 | def f(*args, optimizer_state): 497 | value, grads = loss_and_pgrad_fn(*args) 498 | params_update, optimizer_state = optimizer.update(grads, optimizer_state) 499 | params = optax.apply_updates(args[0], params_update) 500 | return value, params, optimizer_state 501 | 502 | return f 503 | 504 | gradient_update_fn = gradient_update_fn( 505 | loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True 506 | ) 507 | 508 | # minibatch training step 509 | def minibatch_step(carry, data: Transition,): 510 | optimizer_state, params, key = carry 511 | key, key_loss = jax.random.split(key) 512 | (_, metrics), params, optimizer_state = gradient_update_fn( 513 | params, 514 | data, 515 | key_loss, 516 | optimizer_state=optimizer_state) 517 | 518 | return (optimizer_state, params, key), metrics 519 | 520 | 521 | # sgd step 522 | def sgd_step(carry, unused_t, data: Transition): 523 | optimizer_state, params, key = carry 524 | key, key_perm, key_grad = jax.random.split(key, 3) 525 | 526 | def convert_data(x: jnp.ndarray): 527 | x = jax.random.permutation(key_perm, x) 528 | x = jnp.reshape(x, (Config.num_minibatches, -1) + x.shape[1:]) 529 | return x 530 | 531 | shuffled_data = jax.tree_util.tree_map(convert_data, data) 532 | (optimizer_state, params, _), metrics = jax.lax.scan( 533 | minibatch_step, 534 | (optimizer_state, params, key_grad), 535 | shuffled_data, 536 | length=Config.num_minibatches) 537 | return (optimizer_state, params, key), metrics 538 | 539 | 540 | # learning 541 | def learn( 542 | data: Transition, 543 | training_state: TrainingState, 544 | key_sgd: jnp.ndarray, 545 | ): 546 | (optimizer_state, params, _), metrics = jax.lax.scan( 547 | partial( 548 | sgd_step, data=data), 549 | (training_state.optimizer_state, training_state.params, key_sgd), (), 550 | length=Config.update_epochs) 551 | 552 | new_training_state = TrainingState( 553 | optimizer_state=optimizer_state, 554 | params=params, 555 | env_steps=training_state.env_steps + env_step_per_training_step) 556 | 557 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 558 | return new_training_state, metrics 559 | 560 | learn = jax.pmap(learn, axis_name=_PMAP_AXIS_NAME) 561 | 562 | 563 | # initialize params & training state 564 | if is_atari: 565 | init_params = AtariNetworkParams( 566 | feature_extractor=network.feature_extractor.init(key_feature_extractor), 567 | policy=network.policy_network.init(key_policy), 568 | value=network.value_network.init(key_value)) 569 | else: 570 | init_params = NetworkParams( 571 | policy=network.policy_network.init(key_policy), 572 | value=network.value_network.init(key_value)) 573 | training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray 574 | optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars 575 | params=init_params, 576 | env_steps=0) 577 | training_state = jax.device_put_replicated( 578 | training_state, 579 | jax.local_devices()[:local_devices_to_use]) 580 | 581 | 582 | # create eval env 583 | if Config.eval_env: 584 | eval_env = make_env( 585 | env_id=Config.env_id, 586 | num_envs=1, # Config.num_envs, 587 | parallel=False, # Config.parallel_envs, 588 | norm_obs=False, 589 | norm_reward=False, 590 | clip_obs=Config.clip_observations, 591 | clip_rewards=Config.clip_rewards, 592 | evaluate=True, 593 | ) 594 | eval_env.seed(int(eval_key[0])) 595 | eval_state = eval_env.reset() 596 | 597 | 598 | # 
initialize metrics 599 | global_step = 0 600 | start_time = time.time() 601 | training_walltime = 0 602 | scores = [] 603 | 604 | # training loop 605 | training_step = 0 606 | while global_step < Config.total_timesteps: 607 | update_time_start = time.time() 608 | training_step += 1 609 | 610 | new_key, local_key = jax.random.split(local_key) 611 | training_state, env_state = _strip_weak_type((training_state, env_state)) 612 | key_sgd, key_generate_unroll = jax.random.split(new_key, 2) 613 | 614 | if is_atari: 615 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 616 | policy = make_policy(_unpmap(training_state.params.policy)) 617 | 618 | data = [] 619 | transitions = [] 620 | episode_steps = 0 621 | while episode_steps < 2000: 622 | env_state = envs.reset() 623 | episode_over = False 624 | while not episode_over: 625 | episode_steps += 1 626 | current_key, key_generate_unroll = jax.random.split(key_generate_unroll) 627 | obs = env_state.obs 628 | if is_atari: 629 | obs = feature_extractor(env_state.obs) 630 | actions, policy_extras = policy(obs, current_key) 631 | actions = np.asarray(actions) 632 | nstate = envs.step(actions) 633 | # NOTE: info transformed: Array[Dict] --> Dict[Array] 634 | state_extras = {'truncation': jnp.array([info['truncation'] for info in nstate.info])} 635 | transition = Transition( 636 | observation=env_state.obs, 637 | action=actions, 638 | reward=nstate.reward, 639 | discount=1 - nstate.done, 640 | next_observation=nstate.obs, 641 | extras={ 642 | 'policy_extras': policy_extras, 643 | 'state_extras': state_extras 644 | }) 645 | transitions.append(transition) 646 | env_state = nstate 647 | 648 | episode_over = any(jnp.logical_or(state_extras['truncation'], nstate.done)) 649 | data.append(jax.tree_util.tree_map(lambda *x: np.stack(x), *transitions)) 650 | data = jax.tree_util.tree_map(lambda *x: np.stack(x), *data) 651 | 652 | epoch_rollout_time = time.time() - update_time_start 653 | update_time_start = time.time() 654 | 655 | # Have leading dimensions (batch_size * num_minibatches, unroll_length) 656 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) 657 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), 658 | data) 659 | 660 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (local_devices_to_use, -1,) + x.shape[1:]), 661 | data) 662 | 663 | # as function 664 | keys_sgd = jax.random.split(key_sgd, local_devices_to_use) 665 | new_training_state, metrics = learn(data=data, training_state=training_state, key_sgd=keys_sgd) 666 | 667 | # logging 668 | training_state, env_state, metrics = _strip_weak_type((new_training_state, env_state, metrics)) 669 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 670 | jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) 671 | epoch_update_time = time.time() - update_time_start 672 | training_walltime = time.time() - start_time 673 | 674 | sps = episode_steps / (epoch_update_time + epoch_rollout_time) 675 | global_step += episode_steps 676 | 677 | metrics = { 678 | 'training/total_steps': global_step, 679 | 'training/updates': training_step, 680 | 'training/sps': np.round(sps, 3), 681 | 'training/walltime': np.round(training_walltime, 3), 682 | 'training/rollout_time': np.round(epoch_rollout_time, 3), 683 | 'training/update_time': np.round(epoch_update_time, 3), 684 | **{f'training/{name}': float(value) for name, value in metrics.items()} 685 | } 686 | 687 | logging.info(metrics) 688 | 689 | # run eval 690 | if 
process_id == 0 and Config.eval_env and training_step % Config.eval_every == 0: 691 | eval_start_time = time.time() 692 | eval_steps = 0 693 | if is_atari: 694 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 695 | policy_params = _unpmap(training_state.params.policy) 696 | policy = make_policy(policy_params, deterministic=Config.deterministic_eval) 697 | while True: 698 | eval_steps += 1 699 | 700 | # run eval episode & record scores + lengths 701 | current_key, eval_key = jax.random.split(eval_key) 702 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 703 | if is_atari: 704 | obs = feature_extractor(env_state.obs) 705 | actions, policy_extras = policy(obs, current_key) 706 | actions = np.asarray(actions) 707 | eval_state = eval_env.step(actions) 708 | if len(eval_env.returns) >= Config.num_eval_episodes: 709 | eval_returns, eval_ep_lengths = eval_env.evaluate() 710 | break 711 | eval_state = eval_env.reset() 712 | eval_time = time.time() - eval_start_time 713 | # compute mean + std & record 714 | eval_metrics = { 715 | 'eval/num_episodes': len(eval_returns), 716 | 'eval/num_steps': eval_steps, 717 | 'eval/mean_score': np.round(np.mean(eval_returns), 3), 718 | 'eval/std_score': np.round(np.std(eval_returns), 3), 719 | 'eval/mean_episode_length': np.mean(eval_ep_lengths), 720 | 'eval/std_episode_length': np.round(np.std(eval_ep_lengths), 3), 721 | 'eval/eval_time': eval_time, 722 | } 723 | logging.info(eval_metrics) 724 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), metrics['training/approx_kl'])) 725 | 726 | logging.info('TRAINING END: training duration: %s', time.time() - start_time) 727 | 728 | # final eval 729 | if process_id == 0 and Config.eval_env: 730 | eval_steps = 0 731 | if is_atari: 732 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 733 | policy_params = _unpmap(training_state.params.policy) 734 | policy = make_policy(policy_params, deterministic=True) 735 | while True: 736 | eval_steps += 1 737 | 738 | # run eval episode & record scores + lengths 739 | current_key, eval_key = jax.random.split(eval_key) 740 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 741 | if is_atari: 742 | obs = feature_extractor(env_state.obs) 743 | actions, policy_extras = policy(obs, current_key) 744 | actions = np.asarray(actions) 745 | eval_state = eval_env.step(actions) 746 | if len(eval_env.returns) >= Config.num_eval_episodes: 747 | eval_returns, eval_ep_lengths = eval_env.evaluate() 748 | break 749 | eval_state = eval_env.reset() 750 | # compute mean + std & record 751 | eval_metrics = { 752 | 'final_eval/num_episodes': len(eval_returns), 753 | 'final_eval/num_steps': eval_steps, 754 | 'final_eval/mean_score': np.mean(eval_returns), 755 | 'final_eval/std_score': np.std(eval_returns), 756 | 'final_eval/mean_episode_length': np.mean(eval_ep_lengths), 757 | 'final_eval/std_episode_length': np.std(eval_ep_lengths), 758 | } 759 | logging.info(eval_metrics) 760 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), None)) 761 | 762 | # save scores 763 | run_dir = os.path.join('experiments', run_name) 764 | if not os.path.exists(run_dir): 765 | os.makedirs(run_dir) 766 | with open(os.path.join(run_dir, "scores.pkl"), "wb") as f: 767 | pickle.dump(scores, f) 768 | 769 | if Config.save_model: 770 | model_path = f"weights/{run_name}.params" 771 | with 
open(model_path, "wb") as f: 772 | f.write( 773 | flax.serialization.to_bytes( 774 | [ 775 | vars(Config), 776 | [ 777 | training_state.params.policy, 778 | training_state.params.value, 779 | # agent_state.params.feature_extractor, 780 | ], 781 | ] 782 | ) 783 | ) 784 | print(f"model saved to {model_path}") 785 | 786 | envs.close() 787 | 788 | 789 | if __name__ == "__main__": 790 | app.run(main) -------------------------------------------------------------------------------- /a2c.py: -------------------------------------------------------------------------------- 1 | """Advantage Actor-Critic (A2C).""" 2 | 3 | from absl import logging, app 4 | from functools import partial 5 | import os 6 | import pickle 7 | import random 8 | import time 9 | from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union 10 | 11 | # os.environ[ 12 | # "XLA_PYTHON_CLIENT_MEM_FRACTION" 13 | # ] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991 14 | 15 | import flax 16 | import flax.linen as nn 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | import optax 21 | 22 | from envs import make_env, Transition, has_discrete_action_space, is_atari_env 23 | from networks.policy import Policy 24 | from networks.networks import FeedForwardNetwork, ActivationFn, make_policy_network, make_value_network, make_atari_feature_extractor 25 | from networks.distributions import NormalTanhDistribution, ParametricDistribution, PolicyNormalDistribution, DiscreteDistribution 26 | 27 | class Config: 28 | # experiment 29 | experiment_name = 'a2c_epochs2_det_1' 30 | seed = 42 31 | platform = 'cpu' # CPU or GPU 32 | capture_video = False # not implemented 33 | write_logs_to_file = False 34 | save_model = False 35 | 36 | # environment 37 | env_id = 'HalfCheetah-v4' 38 | num_envs = 8 39 | parallel_envs = False 40 | clip_actions = False 41 | normalize_observations = True 42 | normalize_rewards = True 43 | clip_observations = 10. 44 | clip_rewards = 10. 45 | eval_env = True 46 | num_eval_episodes = 10 47 | eval_every = 2 48 | deterministic_eval = True 49 | 50 | # algorithm hyperparameters 51 | total_timesteps = int(1e6) * 8 52 | learning_rate = 3e-4 53 | unroll_length = 2048 54 | anneal_lr = True 55 | gamma = 0.99 56 | gae_lambda = 0.95 57 | batch_size = 1 # number of unrolls per minibatch 58 | num_minibatches = 8 59 | update_epochs = 1 60 | normalize_advantages = True 61 | entropy_cost = 0.1 62 | vf_cost = 0.5 63 | max_grad_norm = 0.5 64 | reward_scaling = 1. 
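# NOTE: main() below asserts num_envs % device_count == 0 and (batch_size * num_minibatches) % num_envs == 0, so these values cannot be chosen independently of one another.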
65 | 66 | # policy params 67 | policy_hidden_layer_sizes: Sequence[int] = (32,) * 4 68 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5 69 | activation: ActivationFn = nn.swish 70 | squash_distribution: bool = True 71 | 72 | # atari params 73 | atari_dense_layer_sizes: Sequence[int] = (512,) 74 | 75 | 76 | Metrics = Mapping[str, jnp.ndarray] 77 | 78 | _PMAP_AXIS_NAME = 'i' 79 | 80 | 81 | def _unpmap(v): 82 | return jax.tree_util.tree_map(lambda x: x[0], v) 83 | 84 | 85 | def _strip_weak_type(tree): 86 | # in order to avoid extra jit recompilations we strip all weak types from user input 87 | def f(leaf): 88 | leaf = jnp.asarray(leaf) 89 | return leaf.astype(leaf.dtype) 90 | return jax.tree_util.tree_map(f, tree) 91 | 92 | 93 | @flax.struct.dataclass 94 | class NetworkParams: 95 | """Contains training state for the learner.""" 96 | policy: Any 97 | value: Any 98 | 99 | 100 | @flax.struct.dataclass 101 | class Networks: 102 | policy_network: FeedForwardNetwork 103 | value_network: FeedForwardNetwork 104 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 105 | 106 | 107 | @flax.struct.dataclass 108 | class AtariNetworkParams: 109 | """Contains training state for the learner.""" 110 | feature_extractor: Any 111 | policy: Any 112 | value: Any 113 | 114 | 115 | @flax.struct.dataclass 116 | class AtariNetworks: 117 | feature_extractor: FeedForwardNetwork 118 | policy_network: FeedForwardNetwork 119 | value_network: FeedForwardNetwork 120 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 121 | 122 | 123 | @flax.struct.dataclass 124 | class TrainingState: 125 | """Contains training state for the learner.""" 126 | optimizer_state: optax.OptState 127 | params: Union[NetworkParams, AtariNetworkParams] 128 | env_steps: jnp.ndarray 129 | 130 | 131 | def make_inference_fn(agent_networks: Union[Networks, AtariNetworks]): 132 | """Creates params and inference function for the agent.""" 133 | 134 | def make_policy(params: Any, 135 | deterministic: bool = False) -> Policy: 136 | policy_network = agent_networks.policy_network 137 | parametric_action_distribution = agent_networks.parametric_action_distribution 138 | 139 | @jax.jit 140 | def policy(observations: jnp.ndarray, 141 | key_sample: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, Any]]: 142 | logits = policy_network.apply(params, observations) 143 | if deterministic: 144 | return agent_networks.parametric_action_distribution.mode(logits), {} 145 | raw_actions = parametric_action_distribution.sample_no_postprocessing( 146 | logits, key_sample) 147 | log_prob = parametric_action_distribution.log_prob(logits, raw_actions) 148 | postprocessed_actions = parametric_action_distribution.postprocess( 149 | raw_actions) 150 | return postprocessed_actions, { 151 | 'log_prob': log_prob, 152 | 'raw_action': raw_actions 153 | } 154 | 155 | return policy 156 | 157 | return make_policy 158 | 159 | 160 | def make_feature_extraction_fn(agent_networks: AtariNetworks): 161 | """Creates feature extractor for inference.""" 162 | 163 | def make_feature_extractor(params: Any): 164 | shared_feature_extractor = agent_networks.feature_extractor 165 | 166 | @jax.jit 167 | def feature_extractor(observations: jnp.ndarray) -> jnp.ndarray: 168 | return shared_feature_extractor.apply(params, observations) 169 | 170 | return feature_extractor 171 | 172 | return make_feature_extractor 173 | 174 | 175 | def make_networks( 176 | observation_size: int, 177 | action_size: int, 178 | policy_hidden_layer_sizes: 
Sequence[int] = (32,) * 4, 179 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5, 180 | activation: ActivationFn = nn.swish, 181 | sqash_distribution: bool = True, 182 | discrete_policy: bool = False, 183 | shared_feature_extractor: bool = False, 184 | feature_extractor_dense_hidden_layer_sizes: Optional[Sequence[int]] = (512,), 185 | ) -> Networks: 186 | """Make A2C networks with preprocessor.""" 187 | if discrete_policy: 188 | parametric_action_distribution = DiscreteDistribution( 189 | param_size=action_size) 190 | elif sqash_distribution: 191 | parametric_action_distribution = NormalTanhDistribution( 192 | event_size=action_size) 193 | else: 194 | parametric_action_distribution = PolicyNormalDistribution( 195 | event_size=action_size) 196 | if shared_feature_extractor: 197 | feature_extractor = make_atari_feature_extractor( 198 | obs_size=observation_size, 199 | hidden_layer_sizes=feature_extractor_dense_hidden_layer_sizes, 200 | activation=nn.relu 201 | ) 202 | policy_network = make_policy_network( 203 | parametric_action_distribution.param_size, 204 | feature_extractor_dense_hidden_layer_sizes[-1], 205 | hidden_layer_sizes=(), 206 | activation=activation) 207 | value_network = make_value_network( 208 | feature_extractor_dense_hidden_layer_sizes[-1], 209 | hidden_layer_sizes=(), 210 | activation=activation) 211 | return AtariNetworks( 212 | feature_extractor=feature_extractor, 213 | policy_network=policy_network, 214 | value_network=value_network, 215 | parametric_action_distribution=parametric_action_distribution) 216 | policy_network = make_policy_network( 217 | parametric_action_distribution.param_size, 218 | observation_size, 219 | hidden_layer_sizes=policy_hidden_layer_sizes, 220 | activation=activation) 221 | value_network = make_value_network( 222 | observation_size, 223 | hidden_layer_sizes=value_hidden_layer_sizes, 224 | activation=activation) 225 | 226 | return Networks( 227 | policy_network=policy_network, 228 | value_network=value_network, 229 | parametric_action_distribution=parametric_action_distribution) 230 | 231 | 232 | def compute_gae(truncation: jnp.ndarray, 233 | termination: jnp.ndarray, 234 | rewards: jnp.ndarray, 235 | values: jnp.ndarray, 236 | bootstrap_value: jnp.ndarray, 237 | lambda_: float = 1.0, 238 | discount: float = 0.99): 239 | """Calculates the Generalized Advantage Estimation (GAE). 240 | 241 | Args: 242 | truncation: A float32 tensor of shape [T, B] with truncation signal. 243 | termination: A float32 tensor of shape [T, B] with termination signal. 244 | rewards: A float32 tensor of shape [T, B] containing rewards generated by 245 | following the behaviour policy. 246 | values: A float32 tensor of shape [T, B] with the value function estimates 247 | wrt. the target policy. 248 | bootstrap_value: A float32 of shape [B] with the value function estimate at 249 | time T. 250 | lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). Defaults to 251 | lambda_=1. 252 | discount: TD discount. 253 | 254 | Returns: 255 | A float32 tensor of shape [T, B]. Can be used as target to 256 | train a baseline (V(x_t) - vs_t)^2. 257 | A float32 tensor of shape [T, B] of advantages. 
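Note: the scan below computes vs_t - V(x_t) via the recursion acc_t = delta_t + discount * lambda_ * (1 - termination_t) * mask_t * acc_{t+1}, where delta_t is the truncation-masked one-step TD error; the returned advantages are then recomputed as one-step TD errors against vs_{t+1}.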
258 | """ 259 | 260 | truncation_mask = 1 - truncation 261 | # Append bootstrapped value to get [v1, ..., v_t+1] 262 | values_t_plus_1 = jnp.concatenate( 263 | [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) 264 | deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values 265 | deltas *= truncation_mask 266 | 267 | acc = jnp.zeros_like(bootstrap_value) 268 | vs_minus_v_xs = [] 269 | 270 | def compute_vs_minus_v_xs(carry, target_t): 271 | lambda_, acc = carry 272 | truncation_mask, delta, termination = target_t 273 | acc = delta + discount * (1 - termination) * truncation_mask * lambda_ * acc 274 | return (lambda_, acc), (acc) 275 | 276 | (_, _), (vs_minus_v_xs) = jax.lax.scan( 277 | compute_vs_minus_v_xs, (lambda_, acc), 278 | (truncation_mask, deltas, termination), 279 | length=int(truncation_mask.shape[0]), 280 | reverse=True) 281 | # Add V(x_s) to get v_s. 282 | vs = jnp.add(vs_minus_v_xs, values) 283 | 284 | vs_t_plus_1 = jnp.concatenate( 285 | [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) 286 | advantages = (rewards + discount * 287 | (1 - termination) * vs_t_plus_1 - values) * truncation_mask 288 | return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) 289 | 290 | 291 | 292 | def compute_a2c_loss( 293 | params: Union[NetworkParams, AtariNetworkParams], 294 | data: Transition, 295 | rng: jnp.ndarray, 296 | network: Union[Networks, AtariNetworks], 297 | vf_cost: float = 0.5, 298 | entropy_cost: float = 1e-4, 299 | discounting: float = 0.9, 300 | reward_scaling: float = 1.0, 301 | gae_lambda: float = 0.95, 302 | normalize_advantage: bool = True, 303 | shared_feature_extractor: bool = False, 304 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 305 | """Computes A2C loss including value loss and entropy bonus. 306 | 307 | Policy loss: $L_\pi = \frac{1}{\lvert \mathcal{D} \rvert} 308 | \sum_{\mathcal{D}} \hat{A} \log \pi_{\theta} (a \mid s)$ 309 | 310 | Args: 311 | params: Network parameters, 312 | data: Transition that with leading dimension [B, T]. extra fields required 313 | are ['state_extras']['truncation'] ['policy_extras']['raw_action'] 314 | ['policy_extras']['log_prob'] 315 | rng: Random key 316 | network: A2C networks. 317 | entropy_cost: entropy cost. 318 | discounting: discounting, 319 | reward_scaling: reward multiplier. 320 | gae_lambda: General advantage estimation lambda. 321 | normalize_advantage: whether to normalize advantage estimate 322 | shared_feature_extractor: Whether networks use a shared feature extractor. 323 | 324 | Returns: 325 | A tuple (loss, metrics) 326 | """ 327 | parametric_action_distribution = network.parametric_action_distribution 328 | 329 | policy_apply = network.policy_network.apply 330 | value_apply = network.value_network.apply 331 | 332 | # Put the time dimension first. 
333 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) 334 | 335 | hidden = data.observation 336 | hidden_boot = data.next_observation[-1] 337 | if shared_feature_extractor: 338 | feature_extractor_apply = network.feature_extractor.apply 339 | hidden = feature_extractor_apply(params.feature_extractor, data.observation) 340 | hidden_boot = feature_extractor_apply(params.feature_extractor, 341 | data.next_observation[-1]) 342 | 343 | policy_logits = policy_apply(params.policy, 344 | hidden) 345 | 346 | baseline = value_apply(params.value, hidden) 347 | 348 | 349 | bootstrap_value = value_apply(params.value, 350 | hidden_boot) 351 | 352 | rewards = data.reward * reward_scaling 353 | truncation = data.extras['state_extras']['truncation'] 354 | termination = (1 - data.discount) * (1 - truncation) 355 | 356 | target_action_log_probs = parametric_action_distribution.log_prob( 357 | policy_logits, data.extras['policy_extras']['raw_action']) 358 | behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] 359 | 360 | vs, advantages = compute_gae( 361 | truncation=truncation, 362 | termination=termination, 363 | rewards=rewards, 364 | values=baseline, 365 | bootstrap_value=bootstrap_value, 366 | lambda_=gae_lambda, 367 | discount=discounting) 368 | if normalize_advantage: 369 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 370 | log_ratio = target_action_log_probs - behaviour_action_log_probs 371 | rho_s = jnp.exp(log_ratio) 372 | 373 | policy_loss = -jnp.mean(target_action_log_probs * advantages) 374 | approx_kl = ((rho_s - 1) - log_ratio).mean() 375 | 376 | # Value function loss 377 | v_error = vs - baseline 378 | v_loss = jnp.mean(v_error * v_error) * 0.5 * vf_cost 379 | 380 | # Entropy reward 381 | entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) 382 | entropy_loss = entropy_cost * -entropy 383 | 384 | total_loss = policy_loss + v_loss + entropy_loss 385 | 386 | metrics = { 387 | 'total_loss': total_loss, 388 | 'policy_loss': policy_loss, 389 | 'value_loss': v_loss, 390 | 'entropy_loss': entropy_loss, 391 | 'entropy': entropy, 392 | 'approx_kl': jax.lax.stop_gradient(approx_kl), 393 | } 394 | 395 | return total_loss, metrics 396 | 397 | 398 | 399 | 400 | def main(_): 401 | run_name = f"Exp_{Config.experiment_name}__{Config.env_id}__{Config.seed}__{int(time.time())}" 402 | 403 | if Config.write_logs_to_file: 404 | from absl import flags 405 | flags.FLAGS.alsologtostderr = True 406 | log_path = f'./training_logs/a2c/{run_name}' 407 | if not os.path.exists(log_path): 408 | os.makedirs(log_path) 409 | logging.get_absl_handler().use_absl_log_file('logs', log_path) 410 | 411 | logging.get_absl_handler().setFormatter(None) 412 | 413 | # jax set up devices 414 | process_count = jax.process_count() 415 | process_id = jax.process_index() 416 | local_device_count = jax.local_device_count() 417 | local_devices_to_use = local_device_count 418 | device_count = local_devices_to_use * process_count 419 | assert Config.num_envs % device_count == 0 420 | 421 | 422 | assert Config.batch_size * Config.num_minibatches % Config.num_envs == 0 423 | # The number of environment steps executed for every training step. 
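# With the Config defaults: batch_size (1) * unroll_length (2048) * num_minibatches (8) = 16384 env steps per training step.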
424 | env_step_per_training_step = ( 425 | Config.batch_size * Config.unroll_length * Config.num_minibatches) 426 | num_training_steps = np.ceil(Config.total_timesteps / env_step_per_training_step).astype(int) 427 | 428 | # log hyperparameters 429 | logging.info("|param: value|") 430 | for key, value in vars(Config).items(): 431 | if not key.startswith('__'): 432 | logging.info(f"|{key}: {value}|") 433 | 434 | random.seed(Config.seed) 435 | np.random.seed(Config.seed) 436 | # handle / split random keys 437 | key = jax.random.PRNGKey(Config.seed) 438 | global_key, local_key = jax.random.split(key) 439 | del key 440 | local_key = jax.random.fold_in(local_key, process_id) 441 | local_key, key_envs, eval_key = jax.random.split(local_key, 3) 442 | # key_networks should be global, so that networks are initialized the same 443 | # way for different processes. 444 | key_policy, key_value, key_feature_extractor = jax.random.split(global_key, 3) 445 | del global_key 446 | 447 | is_atari = is_atari_env(Config.env_id) 448 | envs = make_env( 449 | env_id=Config.env_id, 450 | num_envs=Config.num_envs, 451 | parallel=Config.parallel_envs, 452 | clip_actions=Config.clip_actions, 453 | norm_obs=Config.normalize_observations, 454 | norm_reward=Config.normalize_rewards, 455 | clip_obs=Config.clip_observations, 456 | clip_rewards=Config.clip_rewards, 457 | is_atari=is_atari, 458 | ) 459 | 460 | discrete_action_space = has_discrete_action_space(envs) 461 | envs.seed(int(key_envs[0])) 462 | env_state = envs.reset() 463 | 464 | 465 | if discrete_action_space: 466 | action_size = envs.action_space.n 467 | else: 468 | action_size = np.prod(envs.action_space.shape) # flatten action size for nested spaces 469 | if is_atari: 470 | observation_shape = env_state.obs.shape[-3:] 471 | else: 472 | observation_shape = env_state.obs.shape[-1] 473 | 474 | network = make_networks( 475 | observation_size=observation_shape, # NOTE only works with flattened observation space 476 | action_size=action_size, # flatten action size for nested spaces 477 | policy_hidden_layer_sizes=Config.policy_hidden_layer_sizes, 478 | value_hidden_layer_sizes=Config.value_hidden_layer_sizes, 479 | activation=Config.activation, 480 | sqash_distribution=Config.squash_distribution, 481 | discrete_policy=discrete_action_space, 482 | shared_feature_extractor=is_atari, 483 | feature_extractor_dense_hidden_layer_sizes=Config.atari_dense_layer_sizes, 484 | ) 485 | make_policy = make_inference_fn(network) 486 | if is_atari: 487 | make_feature_extractor = make_feature_extraction_fn(network) 488 | 489 | # create optimizer 490 | if Config.anneal_lr: 491 | learning_rate = optax.linear_schedule( 492 | Config.learning_rate, 493 | Config.learning_rate * 0.01, # 0 494 | transition_steps=Config.total_timesteps, 495 | ) 496 | else: 497 | learning_rate = Config.learning_rate 498 | optimizer = optax.chain( 499 | optax.clip_by_global_norm(Config.max_grad_norm), 500 | optax.adam(learning_rate), 501 | ) 502 | 503 | # create loss function via functools.partial 504 | loss_fn = partial( 505 | compute_a2c_loss, 506 | network=network, 507 | vf_cost=Config.vf_cost, 508 | entropy_cost=Config.entropy_cost, 509 | discounting=Config.gamma, 510 | reward_scaling=Config.reward_scaling, 511 | gae_lambda=Config.gae_lambda, 512 | normalize_advantage=Config.normalize_advantages, 513 | shared_feature_extractor=is_atari, 514 | ) 515 | 516 | 517 | def loss_and_pgrad(loss_fn: Callable[..., float], 518 | pmap_axis_name: Optional[str], 519 | has_aux: bool = False): 520 | g = 
jax.value_and_grad(loss_fn, has_aux=has_aux) 521 | 522 | def h(*args, **kwargs): 523 | value, grad = g(*args, **kwargs) 524 | return value, jax.lax.pmean(grad, axis_name=pmap_axis_name) 525 | 526 | return g if pmap_axis_name is None else h 527 | 528 | 529 | def gradient_update_fn(loss_fn: Callable[..., float], 530 | optimizer: optax.GradientTransformation, 531 | pmap_axis_name: Optional[str], 532 | has_aux: bool = False): 533 | """Wrapper of the loss function that apply gradient updates. 534 | 535 | Args: 536 | loss_fn: The loss function. 537 | optimizer: The optimizer to apply gradients. 538 | pmap_axis_name: If relevant, the name of the pmap axis to synchronize 539 | gradients. 540 | has_aux: Whether the loss_fn has auxiliary data. 541 | 542 | Returns: 543 | A function that takes the same argument as the loss function plus the 544 | optimizer state. The output of this function is the loss, the new parameter, 545 | and the new optimizer state. 546 | """ 547 | loss_and_pgrad_fn = loss_and_pgrad( 548 | loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) 549 | 550 | def f(*args, optimizer_state): 551 | value, grads = loss_and_pgrad_fn(*args) 552 | params_update, optimizer_state = optimizer.update(grads, optimizer_state) 553 | params = optax.apply_updates(args[0], params_update) 554 | return value, params, optimizer_state 555 | 556 | return f 557 | 558 | gradient_update_fn = gradient_update_fn( 559 | loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True 560 | ) 561 | 562 | # minibatch training step 563 | def minibatch_step(carry, data: Transition,): 564 | optimizer_state, params, key = carry 565 | key, key_loss = jax.random.split(key) 566 | (_, metrics), params, optimizer_state = gradient_update_fn( 567 | params, 568 | data, 569 | key_loss, 570 | optimizer_state=optimizer_state) 571 | 572 | return (optimizer_state, params, key), metrics 573 | 574 | 575 | # sgd step 576 | def sgd_step(carry, unused_t, data: Transition): 577 | optimizer_state, params, key = carry 578 | key, key_perm, key_grad = jax.random.split(key, 3) 579 | 580 | def convert_data(x: jnp.ndarray): 581 | x = jax.random.permutation(key_perm, x) 582 | x = jnp.reshape(x, (Config.num_minibatches, -1) + x.shape[1:]) 583 | return x 584 | 585 | shuffled_data = jax.tree_util.tree_map(convert_data, data) 586 | (optimizer_state, params, _), metrics = jax.lax.scan( 587 | minibatch_step, # partial(minibatch_step, normalizer_params=normalizer_params), 588 | (optimizer_state, params, key_grad), 589 | shuffled_data, 590 | length=Config.num_minibatches) 591 | return (optimizer_state, params, key), metrics 592 | 593 | 594 | # learning 595 | def learn( 596 | data: Transition, 597 | training_state: TrainingState, 598 | key_sgd: jnp.ndarray, 599 | ): 600 | (optimizer_state, params, _), metrics = jax.lax.scan( 601 | partial( 602 | sgd_step, data=data), 603 | (training_state.optimizer_state, training_state.params, key_sgd), (), 604 | length=Config.update_epochs) 605 | 606 | new_training_state = TrainingState( 607 | optimizer_state=optimizer_state, 608 | params=params, 609 | env_steps=training_state.env_steps + env_step_per_training_step) 610 | 611 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 612 | return new_training_state, metrics 613 | 614 | learn = jax.pmap(learn, axis_name=_PMAP_AXIS_NAME) 615 | 616 | # initialize params & training state 617 | if is_atari: 618 | init_params = AtariNetworkParams( 619 | feature_extractor=network.feature_extractor.init(key_feature_extractor), 620 | 
policy=network.policy_network.init(key_policy), 621 | value=network.value_network.init(key_value)) 622 | else: 623 | init_params = NetworkParams( 624 | policy=network.policy_network.init(key_policy), 625 | value=network.value_network.init(key_value)) 626 | training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray 627 | optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars 628 | params=init_params, 629 | env_steps=0) 630 | training_state = jax.device_put_replicated( 631 | training_state, 632 | jax.local_devices()[:local_devices_to_use]) 633 | 634 | 635 | # create eval env 636 | if Config.eval_env: 637 | eval_env = make_env( 638 | env_id=Config.env_id, 639 | num_envs=1, 640 | parallel=False, 641 | norm_obs=False, 642 | norm_reward=False, 643 | clip_obs=Config.clip_observations, 644 | clip_rewards=Config.clip_rewards, 645 | evaluate=True, 646 | ) 647 | eval_env.seed(int(eval_key[0])) 648 | eval_state = eval_env.reset() 649 | 650 | 651 | # initialize metrics 652 | global_step = 0 653 | start_time = time.time() 654 | training_walltime = 0 655 | scores = [] 656 | 657 | # training loop 658 | for training_step in range(1, num_training_steps + 1): 659 | update_time_start = time.time() 660 | 661 | new_key, local_key = jax.random.split(local_key) 662 | training_state, env_state = _strip_weak_type((training_state, env_state)) 663 | key_sgd, key_generate_unroll = jax.random.split(new_key, 2) 664 | 665 | if is_atari: 666 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 667 | policy = make_policy(_unpmap(training_state.params.policy)) 668 | 669 | data = [] 670 | for step in range(Config.batch_size * Config.num_minibatches // Config.num_envs): 671 | transitions = [] 672 | for unroll_step in range(Config.unroll_length): 673 | current_key, key_generate_unroll = jax.random.split(key_generate_unroll) 674 | obs = env_state.obs 675 | if is_atari: 676 | obs = feature_extractor(env_state.obs) 677 | actions, policy_extras = policy(obs, current_key) 678 | actions = np.asarray(actions) 679 | nstate = envs.step(actions) 680 | 681 | # NOTE: info is transformed as expected: Array[Dict] --> Dict[Array] 682 | state_extras = {'truncation': jnp.array([info['truncation'] for info in nstate.info])} 683 | transition = Transition( 684 | observation=env_state.obs, 685 | action=actions, 686 | reward=nstate.reward, 687 | discount=1 - nstate.done, 688 | next_observation=nstate.obs, 689 | extras={ 690 | 'policy_extras': policy_extras, 691 | 'state_extras': state_extras 692 | }) 693 | transitions.append(transition) 694 | env_state = nstate 695 | data.append(jax.tree_util.tree_map(lambda *x: np.stack(x), *transitions)) 696 | data = jax.tree_util.tree_map(lambda *x: np.stack(x), *data) 697 | 698 | epoch_rollout_time = time.time() - update_time_start 699 | update_time_start = time.time() 700 | 701 | # Have leading dimensions (batch_size * num_minibatches, unroll_length) 702 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) 703 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), 704 | data) 705 | assert data.discount.shape[1:] == (Config.unroll_length,) 706 | 707 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (local_devices_to_use, -1,) + x.shape[1:]), 708 | data) 709 | 710 | keys_sgd = jax.random.split(key_sgd, local_devices_to_use) 711 | new_training_state, metrics = learn(data=data, training_state=training_state, key_sgd=keys_sgd) 712 | 713 | # logging 714 | 
training_state, env_state, metrics = _strip_weak_type((new_training_state, env_state, metrics)) 715 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 716 | jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) 717 | epoch_update_time = time.time() - update_time_start 718 | training_walltime = time.time() - start_time 719 | sps = env_step_per_training_step / (epoch_update_time + epoch_rollout_time) 720 | global_step += env_step_per_training_step 721 | 722 | current_step = int(_unpmap(training_state.env_steps)) 723 | 724 | metrics = { 725 | 'training/total_steps': current_step, 726 | 'training/updates': training_step, 727 | 'training/sps': np.round(sps, 3), 728 | 'training/walltime': np.round(training_walltime, 3), 729 | 'training/rollout_time': np.round(epoch_rollout_time, 3), 730 | 'training/update_time': np.round(epoch_update_time, 3), 731 | **{f'training/{name}': float(value) for name, value in metrics.items()} 732 | } 733 | 734 | logging.info(metrics) 735 | 736 | # run eval 737 | if process_id == 0 and Config.eval_env and training_step % Config.eval_every == 0: 738 | eval_start_time = time.time() 739 | eval_steps = 0 740 | if is_atari: 741 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 742 | policy_params = _unpmap(training_state.params.policy) 743 | policy = make_policy(policy_params, deterministic=Config.deterministic_eval) 744 | while True: 745 | eval_steps += 1 746 | 747 | # run eval episode & record scores + lengths 748 | current_key, eval_key = jax.random.split(eval_key) 749 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 750 | if is_atari: 751 | obs = feature_extractor(env_state.obs) 752 | actions, policy_extras = policy(obs, current_key) 753 | actions = np.asarray(actions) 754 | eval_state = eval_env.step(actions) 755 | if len(eval_env.returns) >= Config.num_eval_episodes: 756 | eval_returns, eval_ep_lengths = eval_env.evaluate() 757 | break 758 | eval_state = eval_env.reset() 759 | eval_time = time.time() - eval_start_time 760 | # compute mean + std & record 761 | eval_metrics = { 762 | 'eval/num_episodes': len(eval_returns), 763 | 'eval/num_steps': eval_steps, 764 | 'eval/mean_score': np.round(np.mean(eval_returns), 3), 765 | 'eval/std_score': np.round(np.std(eval_returns), 3), 766 | 'eval/mean_episode_length': np.mean(eval_ep_lengths), 767 | 'eval/std_episode_length': np.round(np.std(eval_ep_lengths), 3), 768 | 'eval/eval_time': eval_time, 769 | } 770 | logging.info(eval_metrics) 771 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), metrics['training/approx_kl'])) 772 | 773 | logging.info('TRAINING END: training duration: %s', time.time() - start_time) 774 | 775 | # final eval 776 | if process_id == 0 and Config.eval_env: 777 | eval_steps = 0 778 | if is_atari: 779 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 780 | policy_params = _unpmap(training_state.params.policy) 781 | policy = make_policy(policy_params, deterministic=True) 782 | while True: 783 | eval_steps += 1 784 | 785 | # run eval episode & record scores + lengths 786 | current_key, eval_key = jax.random.split(eval_key) 787 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 788 | if is_atari: 789 | obs = feature_extractor(env_state.obs) 790 | actions, policy_extras = policy(obs, current_key) 791 | actions = np.asarray(actions) 792 | eval_state = eval_env.step(actions) 793 | if 
len(eval_env.returns) >= Config.num_eval_episodes: 794 | eval_returns, eval_ep_lengths = eval_env.evaluate() 795 | break 796 | eval_state = eval_env.reset() 797 | # compute mean + std & record 798 | eval_metrics = { 799 | 'final_eval/num_episodes': len(eval_returns), 800 | 'final_eval/num_steps': eval_steps, 801 | 'final_eval/mean_score': np.mean(eval_returns), 802 | 'final_eval/std_score': np.std(eval_returns), 803 | 'final_eval/mean_episode_length': np.mean(eval_ep_lengths), 804 | 'final_eval/std_episode_length': np.std(eval_ep_lengths), 805 | } 806 | logging.info(eval_metrics) 807 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), None)) 808 | 809 | # save scores 810 | run_dir = os.path.join('experiments', run_name) 811 | if not os.path.exists(run_dir): 812 | os.makedirs(run_dir) 813 | with open(os.path.join(run_dir, "scores.pkl"), "wb") as f: 814 | pickle.dump(scores, f) 815 | 816 | if Config.save_model: 817 | model_path = f"weights/{run_name}.params" 818 | with open(model_path, "wb") as f: 819 | f.write( 820 | flax.serialization.to_bytes( 821 | [ 822 | vars(Config), 823 | [ 824 | training_state.params.policy, 825 | training_state.params.value, 826 | # agent_state.params.feature_extractor, 827 | ], 828 | ] 829 | ) 830 | ) 831 | print(f"model saved to {model_path}") 832 | 833 | envs.close() 834 | 835 | 836 | if __name__ == "__main__": 837 | app.run(main) -------------------------------------------------------------------------------- /ppo.py: -------------------------------------------------------------------------------- 1 | """Proximal Policy Optimization (PPO).""" 2 | 3 | from absl import logging, app 4 | from functools import partial 5 | import os 6 | import pickle 7 | import random 8 | import time 9 | from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union 10 | 11 | # os.environ[ 12 | # "XLA_PYTHON_CLIENT_MEM_FRACTION" 13 | # ] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991 14 | 15 | import flax 16 | import flax.linen as nn 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | import optax 21 | 22 | from envs import make_env, Transition, has_discrete_action_space, is_atari_env 23 | from networks.policy import Policy 24 | from networks.networks import FeedForwardNetwork, ActivationFn, make_policy_network, make_value_network, make_atari_feature_extractor 25 | from networks.distributions import NormalTanhDistribution, ParametricDistribution, PolicyNormalDistribution, DiscreteDistribution 26 | 27 | class Config: 28 | # experiment 29 | experiment_name = 'ppo_test_video' 30 | seed = 20 31 | platform = 'cpu' # CPU or GPU 32 | capture_video = True # Not implemented 33 | write_logs_to_file = False 34 | save_model = False 35 | 36 | # environment 37 | env_id = 'Humanoid-v4' 38 | num_envs = 8 39 | parallel_envs = True 40 | clip_actions = False 41 | normalize_observations = True 42 | normalize_rewards = True 43 | clip_observations = 10. 44 | clip_rewards = 10. 
45 | eval_env = True 46 | num_eval_episodes = 10 47 | eval_every = 5 48 | deterministic_eval = True 49 | 50 | # algorithm hyperparameters 51 | total_timesteps = int(1e6) * 8 52 | learning_rate = 3e-4 53 | unroll_length = 2048 54 | anneal_lr = True 55 | gamma = 0.99 56 | gae_lambda = 0.95 57 | batch_size = 1 # number of unrolls per minibatch 58 | num_minibatches = 8 59 | update_epochs = 10 60 | normalize_advantages = True 61 | clip_eps = 0.2 62 | entropy_cost = 0.00 63 | vf_cost = 0.5 64 | max_grad_norm = 0.5 65 | target_kl = None 66 | reward_scaling = 1. 67 | 68 | # policy params 69 | policy_hidden_layer_sizes: Sequence[int] = (32,) * 4 70 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5 71 | activation: ActivationFn = nn.swish 72 | squash_distribution: bool = True 73 | 74 | # atari params 75 | atari_dense_layer_sizes: Sequence[int] = (512,) 76 | 77 | 78 | Metrics = Mapping[str, jnp.ndarray] 79 | 80 | _PMAP_AXIS_NAME = 'i' 81 | 82 | 83 | def _unpmap(v): 84 | return jax.tree_util.tree_map(lambda x: x[0], v) 85 | 86 | 87 | def _strip_weak_type(tree): 88 | # in order to avoid extra jit recompilations we strip all weak types from user input 89 | def f(leaf): 90 | leaf = jnp.asarray(leaf) 91 | return leaf.astype(leaf.dtype) 92 | return jax.tree_util.tree_map(f, tree) 93 | 94 | 95 | @flax.struct.dataclass 96 | class PPONetworkParams: 97 | """Contains training state for the learner.""" 98 | policy: Any 99 | value: Any 100 | 101 | 102 | @flax.struct.dataclass 103 | class PPONetworks: 104 | policy_network: FeedForwardNetwork 105 | value_network: FeedForwardNetwork 106 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 107 | 108 | 109 | @flax.struct.dataclass 110 | class AtariPPONetworkParams: 111 | """Contains training state for the learner.""" 112 | feature_extractor: Any 113 | policy: Any 114 | value: Any 115 | 116 | 117 | @flax.struct.dataclass 118 | class AtariPPONetworks: 119 | feature_extractor: FeedForwardNetwork 120 | policy_network: FeedForwardNetwork 121 | value_network: FeedForwardNetwork 122 | parametric_action_distribution: Union[ParametricDistribution, DiscreteDistribution] 123 | 124 | 125 | @flax.struct.dataclass 126 | class TrainingState: 127 | """Contains training state for the learner.""" 128 | optimizer_state: optax.OptState 129 | params: Union[PPONetworkParams, AtariPPONetworkParams] 130 | env_steps: jnp.ndarray 131 | 132 | 133 | def make_inference_fn(ppo_networks: Union[PPONetworks, AtariPPONetworks]): 134 | """Creates params and inference function for the PPO agent.""" 135 | 136 | def make_policy(params: Any, 137 | deterministic: bool = False) -> Policy: 138 | policy_network = ppo_networks.policy_network 139 | parametric_action_distribution = ppo_networks.parametric_action_distribution 140 | 141 | @jax.jit 142 | def policy(observations: jnp.ndarray, 143 | key_sample: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, Any]]: 144 | logits = policy_network.apply(params, observations) 145 | if deterministic: 146 | return ppo_networks.parametric_action_distribution.mode(logits), {} 147 | raw_actions = parametric_action_distribution.sample_no_postprocessing( 148 | logits, key_sample) 149 | log_prob = parametric_action_distribution.log_prob(logits, raw_actions) 150 | postprocessed_actions = parametric_action_distribution.postprocess( 151 | raw_actions) 152 | return postprocessed_actions, { 153 | 'log_prob': log_prob, 154 | 'raw_action': raw_actions 155 | } 156 | 157 | return policy 158 | 159 | return make_policy 160 | 161 | 162 | def 
make_feature_extraction_fn(ppo_networks: AtariPPONetworks): 163 | """Creates feature extractor for inference.""" 164 | 165 | def make_feature_extractor(params: Any): 166 | shared_feature_extractor = ppo_networks.feature_extractor 167 | 168 | @jax.jit 169 | def feature_extractor(observations: jnp.ndarray) -> jnp.ndarray: 170 | return shared_feature_extractor.apply(params, observations) 171 | 172 | return feature_extractor 173 | 174 | return make_feature_extractor 175 | 176 | 177 | def make_ppo_networks( 178 | observation_size: int, 179 | action_size: int, 180 | policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, 181 | value_hidden_layer_sizes: Sequence[int] = (256,) * 5, 182 | activation: ActivationFn = nn.swish, 183 | sqash_distribution: bool = True, 184 | discrete_policy: bool = False, 185 | shared_feature_extractor: bool = False, 186 | feature_extractor_dense_hidden_layer_sizes: Optional[Sequence[int]] = (512,), 187 | ) -> PPONetworks: 188 | """Make PPO networks with preprocessor.""" 189 | if discrete_policy: 190 | parametric_action_distribution = DiscreteDistribution( 191 | param_size=action_size) 192 | elif sqash_distribution: 193 | parametric_action_distribution = NormalTanhDistribution( 194 | event_size=action_size) 195 | else: 196 | parametric_action_distribution = PolicyNormalDistribution( 197 | event_size=action_size) 198 | if shared_feature_extractor: 199 | feature_extractor = make_atari_feature_extractor( 200 | obs_size=observation_size, 201 | hidden_layer_sizes=feature_extractor_dense_hidden_layer_sizes, 202 | activation=nn.relu 203 | ) 204 | policy_network = make_policy_network( 205 | parametric_action_distribution.param_size, 206 | feature_extractor_dense_hidden_layer_sizes[-1], 207 | hidden_layer_sizes=(), 208 | activation=activation) 209 | value_network = make_value_network( 210 | feature_extractor_dense_hidden_layer_sizes[-1], 211 | hidden_layer_sizes=(), 212 | activation=activation) 213 | return AtariPPONetworks( 214 | feature_extractor=feature_extractor, 215 | policy_network=policy_network, 216 | value_network=value_network, 217 | parametric_action_distribution=parametric_action_distribution) 218 | policy_network = make_policy_network( 219 | parametric_action_distribution.param_size, 220 | observation_size, 221 | hidden_layer_sizes=policy_hidden_layer_sizes, 222 | activation=activation) 223 | value_network = make_value_network( 224 | observation_size, 225 | hidden_layer_sizes=value_hidden_layer_sizes, 226 | activation=activation) 227 | 228 | return PPONetworks( 229 | policy_network=policy_network, 230 | value_network=value_network, 231 | parametric_action_distribution=parametric_action_distribution) 232 | 233 | 234 | def compute_gae(truncation: jnp.ndarray, 235 | termination: jnp.ndarray, 236 | rewards: jnp.ndarray, 237 | values: jnp.ndarray, 238 | bootstrap_value: jnp.ndarray, 239 | lambda_: float = 1.0, 240 | discount: float = 0.99): 241 | """Calculates the Generalized Advantage Estimation (GAE). 242 | 243 | Args: 244 | truncation: A float32 tensor of shape [T, B] with truncation signal. 245 | termination: A float32 tensor of shape [T, B] with termination signal. 246 | rewards: A float32 tensor of shape [T, B] containing rewards generated by 247 | following the behaviour policy. 248 | values: A float32 tensor of shape [T, B] with the value function estimates 249 | wrt. the target policy. 250 | bootstrap_value: A float32 of shape [B] with the value function estimate at 251 | time T. 252 | lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). 
Defaults to 253 | lambda_=1. 254 | discount: TD discount. 255 | 256 | Returns: 257 | A float32 tensor of shape [T, B]. Can be used as target to 258 | train a baseline (V(x_t) - vs_t)^2. 259 | A float32 tensor of shape [T, B] of advantages. 260 | """ 261 | 262 | truncation_mask = 1 - truncation 263 | # Append bootstrapped value to get [v1, ..., v_t+1] 264 | values_t_plus_1 = jnp.concatenate( 265 | [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) 266 | deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values 267 | deltas *= truncation_mask 268 | 269 | acc = jnp.zeros_like(bootstrap_value) 270 | vs_minus_v_xs = [] 271 | 272 | def compute_vs_minus_v_xs(carry, target_t): 273 | lambda_, acc = carry 274 | truncation_mask, delta, termination = target_t 275 | acc = delta + discount * (1 - termination) * truncation_mask * lambda_ * acc 276 | return (lambda_, acc), (acc) 277 | 278 | (_, _), (vs_minus_v_xs) = jax.lax.scan( 279 | compute_vs_minus_v_xs, (lambda_, acc), 280 | (truncation_mask, deltas, termination), 281 | length=int(truncation_mask.shape[0]), 282 | reverse=True) 283 | # Add V(x_s) to get v_s. 284 | vs = jnp.add(vs_minus_v_xs, values) 285 | 286 | vs_t_plus_1 = jnp.concatenate( 287 | [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) 288 | advantages = (rewards + discount * 289 | (1 - termination) * vs_t_plus_1 - values) * truncation_mask 290 | return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) 291 | 292 | 293 | 294 | def compute_ppo_loss( 295 | params: Union[PPONetworkParams, AtariPPONetworkParams], 296 | data: Transition, 297 | rng: jnp.ndarray, 298 | ppo_network: Union[PPONetworks, AtariPPONetworks], 299 | vf_cost: float = 0.5, 300 | entropy_cost: float = 1e-4, 301 | discounting: float = 0.9, 302 | reward_scaling: float = 1.0, 303 | gae_lambda: float = 0.95, 304 | clipping_epsilon: float = 0.3, 305 | normalize_advantage: bool = True, 306 | shared_feature_extractor: bool = False, 307 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 308 | """Computes PPO loss including value loss and entropy bonus. 309 | 310 | Policy loss: $L_\pi = \frac{1}{\lvert \mathcal{D} \rvert} \sum_{\mathcal{D}} 311 | \min \biggl( \frac{\pi_\theta (a \mid s)}{\pi_\text{old} 312 | (a \mid s)} \hat{A}, \text{clip}\Bigl( \frac{\pi_\theta (a \mid s)}{\pi_\text{old} 313 | (a \mid s)}, 1-\varepsilon, 1+\varepsilon \Bigr) \hat{A} \biggr)$ 314 | 315 | Args: 316 | params: Network parameters, 317 | data: Transition that with leading dimension [B, T]. extra fields required 318 | are ['state_extras']['truncation'] ['policy_extras']['raw_action'] 319 | ['policy_extras']['log_prob'] 320 | rng: Random key 321 | ppo_network: PPO networks. 322 | entropy_cost: entropy cost. 323 | discounting: discounting, 324 | reward_scaling: reward multiplier. 325 | gae_lambda: General advantage estimation lambda. 326 | clipping_epsilon: Policy loss clipping epsilon 327 | normalize_advantage: whether to normalize advantage estimate 328 | shared_feature_extractor: Whether networks use a shared feature extractor. 329 | 330 | Returns: 331 | A tuple (loss, metrics) 332 | """ 333 | parametric_action_distribution = ppo_network.parametric_action_distribution 334 | 335 | policy_apply = ppo_network.policy_network.apply 336 | value_apply = ppo_network.value_network.apply 337 | 338 | # Put the time dimension first. 
339 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) 340 | 341 | hidden = data.observation 342 | hidden_boot = data.next_observation[-1] 343 | if shared_feature_extractor: 344 | feature_extractor_apply = ppo_network.feature_extractor.apply 345 | hidden = feature_extractor_apply(params.feature_extractor, data.observation) 346 | hidden_boot = feature_extractor_apply(params.feature_extractor, 347 | data.next_observation[-1]) 348 | 349 | policy_logits = policy_apply(params.policy, 350 | hidden) 351 | 352 | baseline = value_apply(params.value, hidden) 353 | 354 | 355 | bootstrap_value = value_apply(params.value, 356 | hidden_boot) 357 | 358 | rewards = data.reward * reward_scaling 359 | truncation = data.extras['state_extras']['truncation'] 360 | termination = (1 - data.discount) * (1 - truncation) 361 | 362 | target_action_log_probs = parametric_action_distribution.log_prob( 363 | policy_logits, data.extras['policy_extras']['raw_action']) 364 | behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] 365 | 366 | vs, advantages = compute_gae( 367 | truncation=truncation, 368 | termination=termination, 369 | rewards=rewards, 370 | values=baseline, 371 | bootstrap_value=bootstrap_value, 372 | lambda_=gae_lambda, 373 | discount=discounting) 374 | if normalize_advantage: 375 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 376 | log_ratio = target_action_log_probs - behaviour_action_log_probs 377 | rho_s = jnp.exp(log_ratio) 378 | 379 | surrogate_loss1 = rho_s * advantages 380 | surrogate_loss2 = jnp.clip(rho_s, 1 - clipping_epsilon, 381 | 1 + clipping_epsilon) * advantages 382 | 383 | policy_loss = -jnp.mean(jnp.minimum(surrogate_loss1, surrogate_loss2)) 384 | approx_kl = ((rho_s - 1) - log_ratio).mean() 385 | 386 | # Value function loss 387 | v_error = vs - baseline 388 | v_loss = jnp.mean(v_error * v_error) * 0.5 * vf_cost 389 | 390 | # Entropy reward 391 | entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) 392 | entropy_loss = entropy_cost * -entropy 393 | 394 | total_loss = policy_loss + v_loss + entropy_loss 395 | 396 | metrics = { 397 | 'total_loss': total_loss, 398 | 'policy_loss': policy_loss, 399 | 'value_loss': v_loss, 400 | 'entropy_loss': entropy_loss, 401 | 'entropy': entropy, 402 | 'approx_kl': jax.lax.stop_gradient(approx_kl), 403 | } 404 | 405 | return total_loss, metrics 406 | 407 | 408 | 409 | 410 | def main(_): 411 | run_name = f"Exp_{Config.experiment_name}__{Config.env_id}__{Config.seed}__{int(time.time())}" 412 | 413 | if Config.write_logs_to_file: 414 | from absl import flags 415 | flags.FLAGS.alsologtostderr = True 416 | log_path = f'./training_logs/ppo/{run_name}' 417 | if not os.path.exists(log_path): 418 | os.makedirs(log_path) 419 | logging.get_absl_handler().use_absl_log_file('logs', log_path) 420 | 421 | logging.get_absl_handler().setFormatter(None) 422 | 423 | # jax set up devices 424 | process_count = jax.process_count() 425 | process_id = jax.process_index() 426 | local_device_count = jax.local_device_count() 427 | local_devices_to_use = local_device_count 428 | device_count = local_devices_to_use * process_count 429 | assert Config.num_envs % device_count == 0 430 | 431 | 432 | assert Config.batch_size * Config.num_minibatches % Config.num_envs == 0 433 | # The number of environment steps executed for every training step. 
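# (For example, with hypothetical settings batch_size=256, unroll_length=16 and num_minibatches=4, each training step consumes 256 * 16 * 4 = 16384 environment steps, so num_training_steps = ceil(total_timesteps / 16384).)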
434 | env_step_per_training_step = ( 435 | Config.batch_size * Config.unroll_length * Config.num_minibatches) 436 | num_training_steps = np.ceil(Config.total_timesteps / env_step_per_training_step).astype(int) 437 | 438 | # log hyperparameters 439 | logging.info("|param: value|") 440 | for key, value in vars(Config).items(): 441 | if not key.startswith('__'): 442 | logging.info(f"|{key}: {value}|") 443 | 444 | random.seed(Config.seed) 445 | np.random.seed(Config.seed) 446 | # handle / split random keys 447 | key = jax.random.PRNGKey(Config.seed) 448 | global_key, local_key = jax.random.split(key) 449 | del key 450 | local_key = jax.random.fold_in(local_key, process_id) 451 | local_key, key_envs, eval_key = jax.random.split(local_key, 3) 452 | # key_networks should be global, so that networks are initialized the same 453 | # way for different processes. 454 | key_policy, key_value, key_feature_extractor = jax.random.split(global_key, 3) 455 | del global_key 456 | 457 | is_atari = is_atari_env(Config.env_id) 458 | envs = make_env( 459 | env_id=Config.env_id, 460 | num_envs=Config.num_envs, 461 | parallel=Config.parallel_envs, 462 | clip_actions=Config.clip_actions, 463 | norm_obs=Config.normalize_observations, 464 | norm_reward=Config.normalize_rewards, 465 | clip_obs=Config.clip_observations, 466 | clip_rewards=Config.clip_rewards, 467 | is_atari=is_atari, 468 | ) 469 | 470 | discrete_action_space = has_discrete_action_space(envs) 471 | envs.seed(int(key_envs[0])) 472 | env_state = envs.reset() 473 | 474 | 475 | if discrete_action_space: 476 | action_size = envs.action_space.n 477 | else: 478 | action_size = np.prod(envs.action_space.shape) # flatten action size for nested spaces 479 | if is_atari: 480 | observation_shape = env_state.obs.shape[-3:] 481 | else: 482 | observation_shape = env_state.obs.shape[-1] 483 | 484 | ppo_network = make_ppo_networks( 485 | observation_size=observation_shape, # NOTE only works with flattened observation space 486 | action_size=action_size, # flatten action size for nested spaces 487 | policy_hidden_layer_sizes=Config.policy_hidden_layer_sizes, 488 | value_hidden_layer_sizes=Config.value_hidden_layer_sizes, 489 | activation=Config.activation, 490 | sqash_distribution=Config.squash_distribution, 491 | discrete_policy=discrete_action_space, 492 | shared_feature_extractor=is_atari, 493 | feature_extractor_dense_hidden_layer_sizes=Config.atari_dense_layer_sizes, 494 | ) 495 | make_policy = make_inference_fn(ppo_network) 496 | if is_atari: 497 | make_feature_extractor = make_feature_extraction_fn(ppo_network) 498 | 499 | # create optimizer 500 | if Config.anneal_lr: 501 | learning_rate = optax.linear_schedule( 502 | Config.learning_rate, 503 | Config.learning_rate * 0.01, # 0 504 | transition_steps=Config.total_timesteps, 505 | ) 506 | else: 507 | learning_rate = Config.learning_rate 508 | optimizer = optax.chain( 509 | optax.clip_by_global_norm(Config.max_grad_norm), 510 | optax.adam(learning_rate), 511 | ) 512 | 513 | # create loss function via functools.partial 514 | loss_fn = partial( 515 | compute_ppo_loss, 516 | ppo_network=ppo_network, 517 | vf_cost=Config.vf_cost, 518 | entropy_cost=Config.entropy_cost, 519 | discounting=Config.gamma, 520 | reward_scaling=Config.reward_scaling, 521 | gae_lambda=Config.gae_lambda, 522 | clipping_epsilon=Config.clip_eps, 523 | normalize_advantage=Config.normalize_advantages, 524 | shared_feature_extractor=is_atari, 525 | ) 526 | 527 | 528 | def loss_and_pgrad(loss_fn: Callable[..., float], 529 | pmap_axis_name: 
Optional[str], 530 | has_aux: bool = False): 531 | g = jax.value_and_grad(loss_fn, has_aux=has_aux) 532 | 533 | def h(*args, **kwargs): 534 | value, grad = g(*args, **kwargs) 535 | return value, jax.lax.pmean(grad, axis_name=pmap_axis_name) 536 | 537 | return g if pmap_axis_name is None else h 538 | 539 | 540 | def gradient_update_fn(loss_fn: Callable[..., float], 541 | optimizer: optax.GradientTransformation, 542 | pmap_axis_name: Optional[str], 543 | has_aux: bool = False): 544 | """Wrapper of the loss function that apply gradient updates. 545 | 546 | Args: 547 | loss_fn: The loss function. 548 | optimizer: The optimizer to apply gradients. 549 | pmap_axis_name: If relevant, the name of the pmap axis to synchronize 550 | gradients. 551 | has_aux: Whether the loss_fn has auxiliary data. 552 | 553 | Returns: 554 | A function that takes the same argument as the loss function plus the 555 | optimizer state. The output of this function is the loss, the new parameter, 556 | and the new optimizer state. 557 | """ 558 | loss_and_pgrad_fn = loss_and_pgrad( 559 | loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) 560 | 561 | def f(*args, optimizer_state): 562 | value, grads = loss_and_pgrad_fn(*args) 563 | params_update, optimizer_state = optimizer.update(grads, optimizer_state) 564 | params = optax.apply_updates(args[0], params_update) 565 | return value, params, optimizer_state 566 | 567 | return f 568 | 569 | gradient_update_fn = gradient_update_fn( 570 | loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True 571 | ) 572 | 573 | # minibatch training step 574 | def minibatch_step(carry, data: Transition,): 575 | optimizer_state, params, key = carry 576 | key, key_loss = jax.random.split(key) 577 | (_, metrics), params, optimizer_state = gradient_update_fn( 578 | params, 579 | data, 580 | key_loss, 581 | optimizer_state=optimizer_state) 582 | 583 | return (optimizer_state, params, key), metrics 584 | 585 | 586 | # sgd step 587 | def sgd_step(carry, unused_t, data: Transition): 588 | optimizer_state, params, key = carry 589 | key, key_perm, key_grad = jax.random.split(key, 3) 590 | 591 | def convert_data(x: jnp.ndarray): 592 | x = jax.random.permutation(key_perm, x) 593 | x = jnp.reshape(x, (Config.num_minibatches, -1) + x.shape[1:]) 594 | return x 595 | 596 | shuffled_data = jax.tree_util.tree_map(convert_data, data) 597 | (optimizer_state, params, _), metrics = jax.lax.scan( 598 | minibatch_step, 599 | (optimizer_state, params, key_grad), 600 | shuffled_data, 601 | length=Config.num_minibatches) 602 | return (optimizer_state, params, key), metrics 603 | 604 | 605 | # learning 606 | def learn( 607 | data: Transition, 608 | training_state: TrainingState, 609 | key_sgd: jnp.ndarray, 610 | ): 611 | (optimizer_state, params, _), metrics = jax.lax.scan( 612 | partial( 613 | sgd_step, data=data), 614 | (training_state.optimizer_state, training_state.params, key_sgd), (), 615 | length=Config.update_epochs) 616 | 617 | new_training_state = TrainingState( 618 | optimizer_state=optimizer_state, 619 | params=params, 620 | env_steps=training_state.env_steps + env_step_per_training_step) 621 | 622 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 623 | return new_training_state, metrics 624 | 625 | learn = jax.pmap(learn, axis_name=_PMAP_AXIS_NAME) 626 | 627 | 628 | # initialize params & training state 629 | if is_atari: 630 | init_params = AtariPPONetworkParams( 631 | feature_extractor=ppo_network.feature_extractor.init(key_feature_extractor), 632 | 
policy=ppo_network.policy_network.init(key_policy), 633 | value=ppo_network.value_network.init(key_value)) 634 | else: 635 | init_params = PPONetworkParams( 636 | policy=ppo_network.policy_network.init(key_policy), 637 | value=ppo_network.value_network.init(key_value)) 638 | training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray 639 | optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars 640 | params=init_params, 641 | env_steps=0) 642 | training_state = jax.device_put_replicated( 643 | training_state, 644 | jax.local_devices()[:local_devices_to_use]) 645 | 646 | 647 | # create eval env 648 | if Config.eval_env: 649 | eval_env = make_env( 650 | env_id=Config.env_id, 651 | num_envs=1, 652 | parallel=False, 653 | norm_obs=False, 654 | norm_reward=False, 655 | clip_obs=Config.clip_observations, 656 | clip_rewards=Config.clip_rewards, 657 | evaluate=True, 658 | capture_video=Config.capture_video, 659 | ) 660 | eval_env.seed(int(eval_key[0])) 661 | eval_state = eval_env.reset() 662 | 663 | 664 | # initialize metrics 665 | global_step = 0 666 | start_time = time.time() 667 | training_walltime = 0 668 | scores = [] 669 | 670 | # training loop 671 | for training_step in range(1, num_training_steps + 1): 672 | update_time_start = time.time() 673 | 674 | new_key, local_key = jax.random.split(local_key) 675 | training_state, env_state = _strip_weak_type((training_state, env_state)) 676 | key_sgd, key_generate_unroll = jax.random.split(new_key, 2) 677 | 678 | if is_atari: 679 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 680 | policy = make_policy(_unpmap(training_state.params.policy)) 681 | 682 | data = [] 683 | for step in range(Config.batch_size * Config.num_minibatches // Config.num_envs): 684 | transitions = [] 685 | for unroll_step in range(Config.unroll_length): 686 | current_key, key_generate_unroll = jax.random.split(key_generate_unroll) 687 | obs = env_state.obs 688 | if is_atari: 689 | obs = feature_extractor(env_state.obs) 690 | actions, policy_extras = policy(obs, current_key) 691 | actions = np.asarray(actions) 692 | nstate = envs.step(actions) 693 | # NOTE: info transformed: Array[Dict] --> Dict[Array] 694 | state_extras = {'truncation': jnp.array([info['truncation'] for info in nstate.info])} 695 | transition = Transition( 696 | observation=env_state.obs, 697 | action=actions, 698 | reward=nstate.reward, 699 | discount=1 - nstate.done, 700 | next_observation=nstate.obs, 701 | extras={ 702 | 'policy_extras': policy_extras, 703 | 'state_extras': state_extras 704 | }) 705 | transitions.append(transition) 706 | env_state = nstate 707 | data.append(jax.tree_util.tree_map(lambda *x: np.stack(x), *transitions)) 708 | data = jax.tree_util.tree_map(lambda *x: np.stack(x), *data) 709 | 710 | epoch_rollout_time = time.time() - update_time_start 711 | update_time_start = time.time() 712 | 713 | # Have leading dimensions (batch_size * num_minibatches, unroll_length) 714 | data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) 715 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), 716 | data) 717 | assert data.discount.shape[1:] == (Config.unroll_length,) 718 | 719 | data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (local_devices_to_use, -1,) + x.shape[1:]), 720 | data) 721 | 722 | # as function 723 | keys_sgd = jax.random.split(key_sgd, local_devices_to_use) 724 | new_training_state, metrics = learn(data=data, 
training_state=training_state, key_sgd=keys_sgd) 725 | 726 | # logging 727 | training_state, env_state, metrics = _strip_weak_type((new_training_state, env_state, metrics)) 728 | metrics = jax.tree_util.tree_map(jnp.mean, metrics) 729 | jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) 730 | epoch_update_time = time.time() - update_time_start 731 | training_walltime = time.time() - start_time # += epoch_update_time + epoch_rollout_time 732 | sps = env_step_per_training_step / (epoch_update_time + epoch_rollout_time) 733 | global_step += env_step_per_training_step 734 | 735 | current_step = int(_unpmap(training_state.env_steps)) 736 | 737 | metrics = { 738 | 'training/total_steps': current_step, 739 | 'training/updates': training_step, 740 | 'training/sps': np.round(sps, 3), 741 | 'training/walltime': np.round(training_walltime, 3), 742 | 'training/rollout_time': np.round(epoch_rollout_time, 3), 743 | 'training/update_time': np.round(epoch_update_time, 3), 744 | **{f'training/{name}': float(value) for name, value in metrics.items()} 745 | } 746 | 747 | logging.info(metrics) 748 | 749 | # run eval 750 | if process_id == 0 and Config.eval_env and training_step % Config.eval_every == 0: 751 | eval_start_time = time.time() 752 | eval_steps = 0 753 | if is_atari: 754 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 755 | policy_params = _unpmap(training_state.params.policy) 756 | policy = make_policy(policy_params, deterministic=Config.deterministic_eval) 757 | while True: 758 | eval_steps += 1 759 | 760 | # run eval episode & record scores + lengths 761 | current_key, eval_key = jax.random.split(eval_key) 762 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 763 | if is_atari: 764 | obs = feature_extractor(obs) 765 | actions, policy_extras = policy(obs, current_key) 766 | actions = np.asarray(actions) 767 | eval_state = eval_env.step(actions) 768 | if len(eval_env.returns) >= Config.num_eval_episodes: 769 | eval_returns, eval_ep_lengths = eval_env.evaluate() 770 | break 771 | eval_state = eval_env.reset() 772 | eval_time = time.time() - eval_start_time 773 | # compute mean + std & record 774 | eval_metrics = { 775 | 'eval/num_episodes': len(eval_returns), 776 | 'eval/num_steps': eval_steps, 777 | 'eval/mean_score': np.round(np.mean(eval_returns), 3), 778 | 'eval/std_score': np.round(np.std(eval_returns), 3), 779 | 'eval/mean_episode_length': np.mean(eval_ep_lengths), 780 | 'eval/std_episode_length': np.round(np.std(eval_ep_lengths), 3), 781 | 'eval/eval_time': eval_time, 782 | } 783 | logging.info(eval_metrics) 784 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), metrics['training/approx_kl'])) 785 | 786 | logging.info('TRAINING END: training duration: %s', time.time() - start_time) 787 | 788 | # final eval 789 | if process_id == 0 and Config.eval_env: 790 | eval_steps = 0 791 | if is_atari: 792 | feature_extractor = make_feature_extractor(_unpmap(training_state.params.feature_extractor)) 793 | policy_params = _unpmap(training_state.params.policy) 794 | policy = make_policy(policy_params, deterministic=True) 795 | while True: 796 | eval_steps += 1 797 | 798 | # run eval episode & record scores + lengths 799 | current_key, eval_key = jax.random.split(eval_key) 800 | obs = envs.normalize_obs(eval_state.obs) if Config.normalize_observations else eval_state.obs 801 | if is_atari: 802 | obs = feature_extractor(obs) 803 | actions, policy_extras =
policy(obs, current_key) 804 | actions = np.asarray(actions) 805 | eval_state = eval_env.step(actions) 806 | if len(eval_env.returns) >= Config.num_eval_episodes: 807 | eval_returns, eval_ep_lengths = eval_env.evaluate() 808 | break 809 | eval_state = eval_env.reset() 810 | # compute mean + std & record 811 | eval_metrics = { 812 | 'final_eval/num_episodes': len(eval_returns), 813 | 'final_eval/num_steps': eval_steps, 814 | 'final_eval/mean_score': np.mean(eval_returns), 815 | 'final_eval/std_score': np.std(eval_returns), 816 | 'final_eval/mean_episode_length': np.mean(eval_ep_lengths), 817 | 'final_eval/std_episode_length': np.std(eval_ep_lengths), 818 | } 819 | logging.info(eval_metrics) 820 | scores.append((global_step, np.mean(eval_returns), np.mean(eval_ep_lengths), None)) 821 | 822 | # save scores 823 | run_dir = os.path.join('experiments', run_name) 824 | if not os.path.exists(run_dir): 825 | os.makedirs(run_dir) 826 | with open(os.path.join(run_dir, "scores.pkl"), "wb") as f: 827 | pickle.dump(scores, f) 828 | 829 | if Config.save_model: 830 | model_path = f"weights/{run_name}.params" 831 | with open(model_path, "wb") as f: 832 | f.write( 833 | flax.serialization.to_bytes( 834 | [ 835 | vars(Config), 836 | [ 837 | training_state.params.policy, 838 | training_state.params.value, 839 | # agent_state.params.feature_extractor, 840 | ], 841 | ] 842 | ) 843 | ) 844 | print(f"model saved to {model_path}") 845 | 846 | envs.close() 847 | 848 | 849 | if __name__ == "__main__": 850 | app.run(main) --------------------------------------------------------------------------------
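As a usage note for the compute_gae function in ppo.py above: the following is a minimal sketch of how it can be called, assuming only jax.numpy and the function itself are in scope; the shapes (T=3 timesteps, B=2 environments), constant rewards and zero value estimates are hypothetical, chosen so the recurrence can be checked by hand.

import jax.numpy as jnp
# compute_gae as defined in ppo.py above is assumed to be in scope.

# Hypothetical rollout: 3 timesteps, 2 parallel environments.
T, B = 3, 2
rewards = jnp.ones((T, B))           # reward of 1 at every step
values = jnp.zeros((T, B))           # value estimates set to 0 for illustration
bootstrap_value = jnp.zeros((B,))    # V(s_T) used to bootstrap beyond the unroll
truncation = jnp.zeros((T, B))       # no truncated episodes
termination = jnp.zeros((T, B))      # no terminated episodes

vs, advantages = compute_gae(
    truncation=truncation,
    termination=termination,
    rewards=rewards,
    values=values,
    bootstrap_value=bootstrap_value,
    lambda_=0.95,
    discount=0.99,
)
# Both outputs have shape (T, B). With zero values the reverse scan gives
# vs[2] = 1.0, vs[1] = 1 + 0.99 * 0.95 * vs[2] = 1.9405, vs[0] ~= 2.825,
# and the last-step advantage is exactly its reward: advantages[2] = 1.0.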
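Similarly, the clipped surrogate used in compute_ppo_loss can be sanity-checked in isolation; the log-ratios, advantages and clipping_epsilon below are hypothetical numbers, not values produced by the training script.

import jax.numpy as jnp

clipping_epsilon = 0.3                         # mirrors the default in compute_ppo_loss
log_ratio = jnp.array([0.0, 0.5, -0.5])        # log pi_theta(a|s) - log pi_old(a|s)
advantages = jnp.array([1.0, 1.0, 1.0])        # assume positive advantages
rho = jnp.exp(log_ratio)                       # ~[1.0, 1.65, 0.61]

surrogate1 = rho * advantages
surrogate2 = jnp.clip(rho, 1 - clipping_epsilon, 1 + clipping_epsilon) * advantages
policy_loss = -jnp.mean(jnp.minimum(surrogate1, surrogate2))
# For the second sample rho ~= 1.65 exceeds 1 + epsilon = 1.3, so the clipped
# term is the minimum and that sample contributes no policy gradient; for the
# third sample rho ~= 0.61 the unclipped term is smaller, so the gradient still
# pushes the probability of that action up.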