├── figures
│   ├── ant.png
│   ├── hopper.png
│   ├── walker.png
│   ├── humanoid.png
│   ├── reacher.png
│   ├── swimmer.png
│   ├── halfcheetah.png
│   ├── hopper_bad.png
│   ├── hopper_good.png
│   ├── humanoid_bad.png
│   ├── humanoid_good.png
│   ├── invertedpendulum.png
│   ├── all_envs_box_plot.png
│   └── inverteddoublependulum.png
├── LICENSE
├── replay_buffer.py
├── networks.py
├── README.md
├── main.py
└── agent.py

/figures/ant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/ant.png
--------------------------------------------------------------------------------

/figures/hopper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/hopper.png
--------------------------------------------------------------------------------

/figures/walker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/walker.png
--------------------------------------------------------------------------------

/figures/humanoid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/humanoid.png
--------------------------------------------------------------------------------

/figures/reacher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/reacher.png
--------------------------------------------------------------------------------

/figures/swimmer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/swimmer.png
--------------------------------------------------------------------------------

/figures/halfcheetah.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/halfcheetah.png
--------------------------------------------------------------------------------

/figures/hopper_bad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/hopper_bad.png
--------------------------------------------------------------------------------

/figures/hopper_good.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/hopper_good.png
--------------------------------------------------------------------------------

/figures/humanoid_bad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/humanoid_bad.png
--------------------------------------------------------------------------------

/figures/humanoid_good.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/humanoid_good.png
--------------------------------------------------------------------------------

/figures/invertedpendulum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/invertedpendulum.png
--------------------------------------------------------------------------------
/figures/all_envs_box_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/all_envs_box_plot.png
--------------------------------------------------------------------------------

/figures/inverteddoublependulum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/TD3-JAX/HEAD/figures/inverteddoublependulum.png
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Chen Tessler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/replay_buffer.py:
--------------------------------------------------------------------------------
import numpy as np
import jax

from typing import Tuple


class ReplayBuffer(object):
    """A simple container for maintaining the history of the agent."""
    def __init__(
            self,
            state_dim: int,
            action_dim: int,
            max_size: int
    ):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

    def add(
            self,
            state: np.ndarray,
            action: np.ndarray,
            next_state: np.ndarray,
            reward: float,
            done: float
    ) -> None:
        """The memory is built for per-transition interaction; it does not handle batch updates."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(
            self,
            batch_size: int,
            rng: jax.numpy.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Given a JAX PRNG key, sample a batch from memory."""
        ind = jax.random.randint(rng, (batch_size, ), 0, self.size)

        return (
            self.state[ind],
            self.action[ind],
            self.next_state[ind],
            self.reward[ind],
            self.not_done[ind]
        )
--------------------------------------------------------------------------------
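For reference, a minimal usage sketch of the buffer above (illustrative only, not part of the repository; the dimensions are hypothetical). The buffer stores one transition at a time and samples index batches with a JAX PRNG key, so the caller is responsible for splitting a fresh key before every call:

```python
import numpy as np
import jax

from replay_buffer import ReplayBuffer

# Hypothetical dimensions for the sketch.
buffer = ReplayBuffer(state_dim=11, action_dim=3, max_size=1000)

# Add a single transition; `done` is passed as a float.
state, action, next_state = np.ones(11), np.zeros(3), np.ones(11)
buffer.add(state, action, next_state, reward=1.0, done=0.0)

# Sampling consumes a PRNG key, so split one off for every call.
rng = jax.random.PRNGKey(0)
rng, sample_rng = jax.random.split(rng)
states, actions, next_states, rewards, not_dones = buffer.sample(batch_size=4, rng=sample_rng)
```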
/networks.py:
--------------------------------------------------------------------------------
from typing import Any, Tuple
import haiku as hk
import jax
from jax import numpy as jnp
import numpy as np

"""
Actor and Critic networks defined as in the TD3 paper (Fujimoto et al.), https://arxiv.org/abs/1802.09477
"""


class Actor(hk.Module):
    def __init__(self, action_dim: int, max_action: float):
        super(Actor, self).__init__()
        self.action_dim = action_dim
        self.max_action = max_action

    def __call__(self, state: np.ndarray) -> jnp.DeviceArray:
        actor_net = hk.Sequential([
            hk.Flatten(),
            hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform')),
            jax.nn.relu,
            hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform')),
            jax.nn.relu,
            hk.Linear(self.action_dim, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform'))
        ])
        return jnp.tanh(actor_net(state)) * self.max_action


class Critic(hk.Module):
    def __init__(self):
        super(Critic, self).__init__()

    def __call__(self, state_action: np.ndarray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray]:
        def critic_net():
            return hk.Sequential([
                hk.Flatten(),
                hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform')),
                jax.nn.relu,
                hk.Linear(256, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform')),
                jax.nn.relu,
                hk.Linear(1, w_init=hk.initializers.VarianceScaling(scale=2.0, distribution='uniform'))
            ])

        critic_net_1 = critic_net()

        critic_net_2 = critic_net()

        return critic_net_1(state_action), critic_net_2(state_action)
--------------------------------------------------------------------------------
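As a quick illustration (a sketch, not part of the repository), these Haiku modules are consumed through `hk.transform`, mirroring how `agent.py` builds its networks. Note that with the pinned dm-haiku 0.0.1a0, `apply` takes no rng argument as below; newer Haiku releases would need `hk.without_apply_rng` or an explicit key. The dimensions here are hypothetical:

```python
import haiku as hk
import jax
import jax.numpy as jnp

from networks import Actor, Critic

action_dim, max_action = 3, 1.0     # hypothetical action space
state = jnp.zeros((1, 11))          # batch of one 11-dimensional state

actor = hk.transform(lambda x: Actor(action_dim, max_action)(x))
critic = hk.transform(lambda x: Critic()(x))

rng = jax.random.PRNGKey(0)
actor_rng, critic_rng = jax.random.split(rng)

actor_params = actor.init(actor_rng, state)
action = actor.apply(actor_params, state)            # shape (1, 3), bounded by max_action

state_action = jnp.concatenate((state, action), axis=1)
critic_params = critic.init(critic_rng, state_action)
q1, q2 = critic.apply(critic_params, state_action)   # two independent Q estimates
```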
/README.md:
--------------------------------------------------------------------------------
# TD3-JAX
A JAX Implementation of the [Twin Delayed DDPG](https://github.com/sfujim/TD3) Algorithm


## Requirements

Beside each requirement, I have stated the version installed on my system, for reproducibility.

* [JAX](https://github.com/google/jax) - jax 0.1.59, jaxlib 0.1.39
* [Haiku](https://github.com/deepmind/dm-haiku) - dm-haiku 0.0.1a0, dm-sonnet 2.0.0b0
* [RLax](https://github.com/deepmind/rlax) - rlax 0.0.0
* [Gym](https://github.com/openai/gym) - gym 0.15.4
* [MuJoCo](https://github.com/openai/mujoco-py) - mujoco-py 2.0.2.9

## Command line arguments
To run each environment:
```
for seed in {0..9}; do python main.py --env Hopper-v2 --seed $seed; done
```

The default hyperparameters aren't ideal for all domains.
Based on some limited testing and intuition, the following values work better than the defaults.

Environment | Command line addition
--- | ---
Swimmer-v2 | --discount 0.995

## Results

For each seed, we keep the best policy seen during evaluation and re-evaluate it at the end of training.
The results below report the average ± one standard deviation of this metric.
All reported results are based on 10 seeds (0 to 9).

Environment | Best policy per run
--- | ---
Hopper-v2 | 3691.5 ± 61.7
Humanoid-v2 | 5194.0 ± 97.1
Walker2d-v2 | 4328.8 ± 1059.0
Ant-v2 | 3505.4 ± 411.7
HalfCheetah-v2 | 10411.0 ± 1323.5
Swimmer-v2 | 314.1 ± 69.2
InvertedPendulum-v2 | 1000.0 ± 0.0
InvertedDoublePendulum-v2 | 9350.6 ± 26.8
Reacher-v2 | -4.0 ± 0.3

![Box Plots](figures/all_envs_box_plot.png "Box Plots")

![Hopper-v2](figures/hopper.png "Hopper-v2")
![Humanoid-v2](figures/humanoid.png "Humanoid-v2")
![Walker2d-v2](figures/walker.png "Walker2d-v2")
![Ant-v2](figures/ant.png "Ant-v2")
![HalfCheetah-v2](figures/halfcheetah.png "HalfCheetah-v2")
![Swimmer-v2](figures/swimmer.png "Swimmer-v2")
![InvertedPendulum-v2](figures/invertedpendulum.png "InvertedPendulum-v2")
![InvertedDoublePendulum-v2](figures/inverteddoublependulum.png "InvertedDoublePendulum-v2")
![Reacher-v2](figures/reacher.png "Reacher-v2")

The code for reproducing the figures, including a per-seed representation for each environment, is provided in [plot_results.ipynb](plot_results.ipynb).

Based on the per-seed analysis, it seems that with some hyperparameter tuning the results of TD3 can improve dramatically.
In particular, in some domains the algorithm takes a while to start learning, likely a result of a low learning rate, a large experience replay, or an unsuitable discount factor.

For instance:

![A good humanoid run](figures/humanoid_good.png "A good humanoid run")

versus

![A bad humanoid run](figures/humanoid_bad.png "A bad humanoid run")

and

![A good hopper run](figures/hopper_good.png "A good hopper run")

versus

![A bad hopper run](figures/hopper_bad.png "A bad hopper run")
--------------------------------------------------------------------------------
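To turn the saved evaluation curves into a table like the one above, something along the following lines works (a sketch, assuming the `./results/TD3/<env>_<idx>.npy` layout written by `main.py`; the actual plotting code lives in `plot_results.ipynb`). The final entry of each saved array is the re-evaluation of the best policy:

```python
import glob

import numpy as np

env = "Hopper-v2"
final_scores = []
for path in glob.glob(f"./results/TD3/{env}_*.npy"):
    evaluations = np.load(path)
    # The last entry is the re-evaluation of the best policy seen during training.
    final_scores.append(evaluations[-1])

print(f"{env}: {np.mean(final_scores):.1f} +/- {np.std(final_scores):.1f}")
```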
/main.py:
--------------------------------------------------------------------------------
"""
Credits: https://github.com/sfujim/TD3
"""

import argparse
from typing import Any, Tuple
import numpy as np

import gym
import jax

from replay_buffer import ReplayBuffer
from agent import Agent

import os

OptState = Any


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="TD3", choices=['TD3', 'DDPG'])  # Policy name (TD3, DDPG)
    parser.add_argument("--env", default="InvertedPendulum-v2")  # OpenAI Gym environment name
    parser.add_argument("--seed", type=int, required=True)  # Sets the Gym and JAX random seeds
    parser.add_argument("--start_timesteps", default=25000, type=int)  # Time steps for which the initial random policy is used
    parser.add_argument("--eval_freq", default=5e3, type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1e6, type=int)  # Max time steps to run the environment
    parser.add_argument("--replay_size", default=200000, type=int)  # Size of the replay buffer
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)  # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate (note: soft_update in agent.py currently uses its own default of 0.005)
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to the target policy during the critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)  # Range to clip target policy noise
    parser.add_argument("--lr", default=3e-4, type=float)  # Learning rate for both the actor and critic optimizers
    parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
    # TODO: Model saving and loading is not supported yet.
    # parser.add_argument("--save_model", action="store_true")  # Save model and optimizer parameters
    # parser.add_argument("--load_model", default="")  # Model load file name; "" doesn't load, "default" uses file_name
    args = parser.parse_args()

    return args


def eval_policy(agent: Agent, env_name: str, eval_episodes: int = 10, max_steps: int = 0) -> float:
    eval_env = gym.make(env_name)

    avg_reward = 0.
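    # Roll out `eval_episodes` full episodes with the deterministic (noise-free) policy
    # and report the mean undiscounted return.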
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        remaining_steps = max_steps * 1.0  # tracked for bookkeeping; not currently used to truncate episodes

        while not done:
            action = agent.policy(agent.actor_params, state)
            state, reward, done, _ = eval_env.step(action)

            remaining_steps -= 1

            avg_reward += reward

    avg_reward /= eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


def main():
    args = parse_arguments()

    idx = 0
    file_name = f"{args.env}_{idx}"
    # For easy extraction of the data, we save all runs using a serially increasing index.
    while os.path.exists('./results/' + args.policy + '/' + file_name + '.npy'):
        idx += 1
        file_name = f"{args.env}_{idx}"

    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    if not os.path.exists("./results/" + args.policy):
        os.makedirs("./results/" + args.policy)

    # if args.save_model and not os.path.exists("./models/" + args.policy):
    #     os.makedirs("./models/" + args.policy)

    env = gym.make(args.env)
    env.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    state, done = env.reset(), False
    episode_reward = 0
    episode_timesteps = 0
    episode_num = 0

    rng = jax.random.PRNGKey(args.seed)
    rng, actor_rng, critic_rng = jax.random.split(rng, 3)

    agent = Agent(args.policy,
                  action_dim,
                  max_action,
                  args.lr,
                  args.discount,
                  args.noise_clip,
                  args.policy_noise,
                  args.policy_freq,
                  actor_rng,
                  critic_rng,
                  state)

    replay_buffer = ReplayBuffer(state_dim, action_dim, max_size=args.replay_size)

    # Evaluate the untrained policy.
    # We evaluate for 100 episodes, as 10 episodes provide a very noisy estimate in some domains.
    evaluations = [eval_policy(agent, args.env, max_steps=env._max_episode_steps, eval_episodes=100)]
    np.save(f"./results/{args.policy}/{file_name}", evaluations)
    best_performance = evaluations[-1]
    best_actor_params = agent.actor_params
    # if args.save_model: agent.save(f"./models/{args.policy}/{file_name}")

    for t in range(int(args.max_timesteps)):

        episode_timesteps += 1

        # Select the action randomly or according to the policy
        if t < args.start_timesteps:
            action = env.action_space.sample()
        else:
            rng, noise_rng = jax.random.split(rng)
            action = (
                agent.policy(agent.actor_params, state)
                + jax.random.normal(noise_rng, (action_dim, )) * max_action * args.expl_noise
            ).clip(-max_action, max_action)

        # Perform the action
        next_state, reward, done, _ = env.step(action)
        # This 'trick' converts the finite-horizon task into an infinite-horizon one. It does change the problem
        # we are solving; however, it has been observed empirically to work well.
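        # Concretely: when the episode ends only because the time limit was reached, done_bool is
        # forced to 0, so the stored not_done flag stays 1 and the critic keeps bootstrapping
        # through the timeout transition.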
        done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0

        # Store data in the replay buffer
        replay_buffer.add(state, action, next_state, reward, done_bool)

        state = next_state
        episode_reward += reward

        # Train the agent after collecting sufficient data
        if t >= args.start_timesteps:
            rng, update_rng = jax.random.split(rng)
            agent.update(replay_buffer, args.batch_size, update_rng)

        if done:
            # +1 to account for 0 indexing. +0 on episode_timesteps since it will increment +1 even if done=True
            print(f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
            # Reset environment
            state, done = env.reset(), False
            episode_reward = 0
            episode_timesteps = 0
            episode_num += 1

        # Evaluate episode
        if (t + 1) % args.eval_freq == 0:
            evaluations.append(eval_policy(agent, args.env, max_steps=env._max_episode_steps, eval_episodes=100))
            np.save(f"./results/{args.policy}/{file_name}", evaluations)
            if evaluations[-1] > best_performance:
                best_performance = evaluations[-1]
                best_actor_params = agent.actor_params
                # if args.save_model: agent.save(f"./models/{args.policy}/{file_name}")

    # At the end, re-evaluate the policy which is presumed to be the best. This ensures an unbiased estimate when
    # reporting the average best result across runs.
    agent.actor_params = best_actor_params
    evaluations.append(eval_policy(agent, args.env, max_steps=env._max_episode_steps, eval_episodes=100))
    np.save(f"./results/{args.policy}/{file_name}", evaluations)
    print(f"Selected policy has an average score of: {evaluations[-1]:.3f}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------

/agent.py:
--------------------------------------------------------------------------------
from typing import Any, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import optix
import rlax
import numpy as np
from networks import Actor, Critic
import functools
import replay_buffer as rb

OptState = Any


# Perform Polyak averaging given two sets of network parameters and the averaging value tau.
@jax.jit
def soft_update(target_params: hk.Params, online_params: hk.Params, tau: float = 0.005) -> hk.Params:
    return jax.tree_multimap(lambda x, y: (1 - tau) * x + tau * y, target_params, online_params)


class Agent(object):
    """Agent class for the TD3 algorithm.
    Combines both the agent and the learner functions."""
    def __init__(
            self,
            policy: str,
            action_dim: int,
            max_action: float,
            lr: float,
            discount: float,
            noise_clip: float,
            policy_noise: float,
            policy_freq: int,
            actor_rng: jnp.ndarray,
            critic_rng: jnp.ndarray,
            sample_state: np.ndarray
    ):
        self.discount = discount
        self.noise_clip = noise_clip
        self.policy_noise = policy_noise
        self.policy_freq = policy_freq
        self.max_action = max_action
        self.td3_update = policy == 'TD3'

        self.actor = hk.transform(lambda x: Actor(action_dim, max_action)(x))
        actor_opt_init, self.actor_opt_update = optix.adam(lr)

        self.critic = hk.transform(lambda x: Critic()(x))
        critic_opt_init, self.critic_opt_update = optix.adam(lr)

        self.actor_params = self.target_actor_params = self.actor.init(actor_rng, sample_state)
        self.actor_opt_state = actor_opt_init(self.actor_params)

        action = self.actor.apply(self.actor_params, sample_state)

        self.critic_params = self.target_critic_params = self.critic.init(critic_rng, jnp.concatenate((sample_state, action), 0))
        self.critic_opt_state = critic_opt_init(self.critic_params)

        self.updates = 0

    def update(self, replay_buffer: rb.ReplayBuffer, batch_size: int, rng: jnp.ndarray) -> None:
        """
        Sample a batch of transitions and update both the policy and critic networks.
        As this function contains a conditional branch (the actor is only updated periodically), we do not jit-compile it.
        """
        self.updates += 1

        # Provide each element an independent rng sample.
        replay_rand, critic_rand = jax.random.split(rng)

        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size, replay_rand)

        self.critic_params, self.critic_opt_state = self.update_critic(self.critic_params, self.target_critic_params,
                                                                       self.target_actor_params, self.critic_opt_state,
                                                                       state, action, next_state, reward, not_done,
                                                                       critic_rand)

        if self.updates % self.policy_freq == 0:
            self.actor_params, self.actor_opt_state = self.update_actor(self.actor_params, self.critic_params,
                                                                        self.actor_opt_state, state)

            self.target_actor_params = soft_update(self.target_actor_params, self.actor_params)
            self.target_critic_params = soft_update(self.target_critic_params, self.critic_params)

    @functools.partial(jax.jit, static_argnums=0)
    def critic_1(
            self,
            critic_params: hk.Params,
            state_action: np.ndarray
    ) -> jnp.DeviceArray:
        """Retrieves the result from a single critic network.
        Relevant for the actor update rule."""
        return self.critic.apply(critic_params, state_action)[0].squeeze(-1)

    @functools.partial(jax.jit, static_argnums=0)
    def actor_loss(
            self,
            actor_params: hk.Params,
            critic_params: hk.Params,
            state: np.ndarray
    ) -> jnp.DeviceArray:
        """Standard DDPG update rule based on the gradient through a single critic network."""
        action = self.actor.apply(actor_params, state)
        return -jnp.mean(self.critic_1(critic_params, jnp.concatenate((state, action), 1)))

    @functools.partial(jax.jit, static_argnums=0)
    def update_actor(
            self,
            actor_params: hk.Params,
            critic_params: hk.Params,
            actor_opt_state: OptState,
            state: np.ndarray
    ) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        _, gradient = jax.value_and_grad(self.actor_loss)(actor_params, critic_params, state)
        updates, opt_state = self.actor_opt_update(gradient, actor_opt_state)
        new_params = optix.apply_updates(actor_params, updates)
        return new_params, opt_state

    @functools.partial(jax.jit, static_argnums=0)
    def critic_loss(
            self,
            critic_params: hk.Params,
            target_critic_params: hk.Params,
            target_actor_params: hk.Params,
            state: np.ndarray,
            action: np.ndarray,
            next_state: np.ndarray,
            reward: np.ndarray,
            not_done: np.ndarray,
            rng: jnp.ndarray
    ) -> jnp.DeviceArray:
        """
        TD3 adds truncated Gaussian noise to the target policy while training the critic.
        This can be seen as a form of 'Exploration Consciousness' (https://arxiv.org/abs/1812.05551) or simply as a
        regularization scheme.
        As this helps stabilize the critic, we also use it for the DDPG update rule.
        """
        noise = (
            jax.random.normal(rng, shape=action.shape) * self.policy_noise
        ).clip(-self.noise_clip, self.noise_clip)

        # Make sure the noisy action is within the valid bounds.
        next_action = (
            self.actor.apply(target_actor_params, next_state) + noise
        ).clip(-self.max_action, self.max_action)

        next_q_1, next_q_2 = self.critic.apply(target_critic_params, jnp.concatenate((next_state, next_action), 1))
        if self.td3_update:
            next_q = jax.lax.min(next_q_1, next_q_2)
        else:
            # Since the actor is trained with Q_1, setting it as the target for the critic update is sufficient to
            # obtain an equivalent DDPG update.
            next_q = next_q_1
        # Cut the gradient from flowing through the target critic. This is computationally more efficient.
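        # Target: y = r + discount * not_done * min(Q'_1, Q'_2)(s', a'), with a' = clip(pi'(s') + clipped noise).
        # The DDPG branch above replaces the minimum with Q'_1 alone.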
        target_q = jax.lax.stop_gradient(reward + self.discount * next_q * not_done)
        q_1, q_2 = self.critic.apply(critic_params, jnp.concatenate((state, action), 1))

        return jnp.mean(rlax.l2_loss(q_1, target_q) + rlax.l2_loss(q_2, target_q))

    @functools.partial(jax.jit, static_argnums=0)
    def update_critic(
            self,
            critic_params: hk.Params,
            target_critic_params: hk.Params,
            target_actor_params: hk.Params,
            critic_opt_state: OptState,
            state: np.ndarray,
            action: np.ndarray,
            next_state: np.ndarray,
            reward: np.ndarray,
            not_done: np.ndarray,
            rng: jnp.ndarray
    ) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        _, gradient = jax.value_and_grad(self.critic_loss)(critic_params, target_critic_params, target_actor_params,
                                                           state, action, next_state, reward, not_done, rng)
        updates, opt_state = self.critic_opt_update(gradient, critic_opt_state)
        new_params = optix.apply_updates(critic_params, updates)
        return new_params, opt_state

    @functools.partial(jax.jit, static_argnums=0)
    def policy(self, actor_params: hk.Params, state: np.ndarray) -> jnp.DeviceArray:
        return self.actor.apply(actor_params, state)
--------------------------------------------------------------------------------
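Putting the pieces together, a minimal smoke test along these lines should exercise the whole pipeline (an illustrative sketch only, not part of the repository; it assumes a working MuJoCo installation and the pinned library versions listed in the README):

```python
import gym
import jax

from agent import Agent
from replay_buffer import ReplayBuffer

env = gym.make("InvertedPendulum-v2")
env.seed(0)
state = env.reset()

rng = jax.random.PRNGKey(0)
rng, actor_rng, critic_rng = jax.random.split(rng, 3)

# Argument order follows Agent.__init__:
# (policy, action_dim, max_action, lr, discount, noise_clip, policy_noise, policy_freq, actor_rng, critic_rng, sample_state)
agent = Agent("TD3", env.action_space.shape[0], float(env.action_space.high[0]),
              3e-4, 0.99, 0.5, 0.2, 2, actor_rng, critic_rng, state)
buffer = ReplayBuffer(env.observation_space.shape[0], env.action_space.shape[0], max_size=1000)

# Collect a few random transitions, then run a single update and query the policy.
for _ in range(300):
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    buffer.add(state, action, next_state, reward, float(done))
    state = env.reset() if done else next_state

rng, update_rng = jax.random.split(rng)
agent.update(buffer, 256, update_rng)
print(agent.policy(agent.actor_params, state))
```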