├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── a2c ├── a2c.py ├── a2c_multi.py └── baseline.py ├── ddpg ├── ddpg_jax.py ├── ddpg_jax_profile.py ├── ddpg_param_tuning.py └── ddpg_td3.py ├── demos ├── .DS_Store ├── bipedal │ ├── BipedalWalker-v3_ddpg_4.08.png │ ├── BipedalWalker-v3_td3_ddpg_310.10.png │ ├── BipedalWalker-v3_td3_ddpg_97.04.png │ ├── BipedalWalker-v3_td3_ddpg_97.28.mp4 │ └── td3_tensorflow.png ├── ddpg_ant │ ├── AntBulletEnv-v0_td3_ddpg_207.26.mp4 │ └── AntBulletEnv-v0_td3_ddpg_81.77.mp4 ├── ddpg_cheeta │ ├── .DS_Store │ ├── Screen Shot 2021-07-23 at 8.33.21 AM.png │ ├── Screen Shot 2021-07-23 at 8.33.28 AM.png │ ├── backwards │ │ └── HalfCheetahBulletEnv-v0_backwards_td3_ddpg_612.73.mp4 │ └── forwards │ │ ├── .DS_Store │ │ ├── HalfCheetahBulletEnv-v0_td3_ddpg_167.27.mp4 │ │ ├── HalfCheetahBulletEnv-v0_td3_ddpg_492.81.mp4 │ │ └── HalfCheetahBulletEnv-v0_td3_ddpg_94.63.mp4 ├── pendulum │ ├── Pendulum-v0_ddpg_-3.32.png │ └── Pendulum-v0_initmodel_-1897.19.png └── pong │ └── PongNoFrameskip-v4_13.00_PPO.mp4 ├── maml ├── 3maml.py ├── 3maml_dist.py ├── 4maml.py ├── 4maml_dist.py ├── env.py ├── maml.py ├── maml_no_val.py ├── maml_rl.py ├── maml_wave.py └── utils.py ├── ppo ├── env.py ├── ppo.py ├── ppo_brax.py ├── ppo_disc.py ├── ppo_multi.py ├── ppo_multi_disc.py └── tmp.py ├── qlearn.py ├── reinforce ├── cartpole.png ├── cartpole.sh ├── env.py ├── jax2.py ├── policy_grad.py ├── reinforce_cont.py ├── reinforce_jax.py ├── reinforce_linear_baseline.py ├── reinforce_torchVSjax.py ├── tmp.png └── tmp.py ├── tmp.md └── trpo ├── cont.py ├── debug.py └── trpo.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | trpo/pytorch-trpo/ 2 | ROM/ 3 | tmp/ 4 | runs/ 5 | models/ 6 | baselines/ 7 | *.rar 8 | *.pdf 9 | *.db 10 | **events.out.tf** 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Brennan Gebotys 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL Implementations in JAX 2 | 3 | Single-file implementations focused on clarity rather than proper code standards :) 4 | 5 | | Algo | Path | Discrete Actions | Continuous Actions | Multi-CPU | Other | 6 | |-----------|------------|------------------|--------------------|------------|--------------------------------------------------------------------| 7 | | TRPO | trpo/ | trpo.py | cont.py | | | 8 | | PPO | ppo/ | ppo_disc.py | ppo.py | *_multi.py | | 9 | | MAML | maml/ | | | | *SineWave* = maml_wave.py | 10 | | DQN | | dqn.py | | | | 11 | | REINFORCE | reinforce/ | reinforce_jax.py | reinforce_cont.py | | *Pytorch* = policy_grad.py
*Time Comparison* = reinforce_torchVSjax.py | 12 | | DDPG | ddpg/ | | ddpg_jax.py | | *TD3_DDPG* = ddpg_td3.py | 13 | | A2C | a2c/ | a2c.py | | *_multi.py | | 14 | 15 | For a better understanding of TRPO optimization check out [Natural Gradient Descent without the Tears](https://gebob19.github.io/natural-gradient/) 16 | -------------------------------------------------------------------------------- /a2c/a2c.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import matplotlib.pyplot as plt 6 | import gym 7 | import haiku as hk 8 | import random 9 | import optax 10 | 11 | #%% 12 | env = gym.make('CartPole-v0') 13 | 14 | n_actions = env.action_space.n 15 | obs_dim = env.observation_space.shape[0] 16 | 17 | #%% 18 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 19 | 20 | def _policy_value(obs): 21 | pi = hk.Sequential([ 22 | hk.Linear(64), jax.nn.relu, 23 | hk.Linear(64), jax.nn.relu, 24 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 25 | ])(obs) 26 | 27 | v = hk.Sequential([ 28 | hk.Linear(64), jax.nn.relu, 29 | hk.Linear(64), jax.nn.relu, 30 | hk.Linear(1, w_init=init_final), 31 | ])(obs) 32 | return pi, v 33 | 34 | policy_value = hk.transform(_policy_value) 35 | policy_value = hk.without_apply_rng(policy_value) 36 | pv_frwd = jax.jit(policy_value.apply) # forward fcn 37 | 38 | #%% 39 | seed = onp.random.randint(1e5) # seed=81705 works 40 | print(f'[LOGGING] seed={seed}') 41 | rng = jax.random.PRNGKey(seed) 42 | onp.random.seed(seed) 43 | random.seed(seed) 44 | 45 | #%% 46 | class Categorical: # similar to pytorch categorical 47 | def __init__(self, probs): 48 | self.probs = probs 49 | def sample(self): 50 | # https://stackoverflow.com/questions/46539431/np-random-choice-probabilities-do-not-sum-to-1 51 | p = onp.asanyarray(self.probs) 52 | p = p / p.sum() 53 | a = onp.random.choice(onp.arange(len(self.probs)), p=p) 54 | return a 55 | def log_prob(self, i): return np.log(self.probs[i]) 56 | def entropy(self): return -(self.probs * np.log(self.probs)).sum() 57 | 58 | def discount_cumsum(l, discount): 59 | l = onp.array(l) 60 | for i in range(len(l) - 1)[::-1]: 61 | l[i] = l[i] + discount * l[i+1] 62 | return l 63 | 64 | global_env_count = 0 65 | def rollout_v(step_i, params, env, max_n_steps=200): 66 | global global_env_count 67 | 68 | obs = env.reset() 69 | # obs, obs2 + a, r, done, 70 | v_buffer = onp.zeros((max_n_steps, 2 * obs_dim + 3)) 71 | 72 | for i in range(max_n_steps): 73 | a_probs, v_s = pv_frwd(params, obs) 74 | a_dist = Categorical(a_probs) 75 | a = a_dist.sample() 76 | 77 | entropy = a_dist.entropy().item() 78 | writer.add_scalar('policy/entropy', entropy, global_env_count) 79 | writer.add_scalar('policy/value', v_s.item(), global_env_count) 80 | 81 | obs2, r, done, _ = env.step(a) 82 | v_buffer[i] = onp.array([*obs, a, r, *obs2, float(done)]) 83 | 84 | global_env_count += 1 85 | obs = obs2 86 | if done: break 87 | 88 | v_buffer = v_buffer[:i+1] 89 | obs, a, r, obs2, done = onp.split(v_buffer, [obs_dim, obs_dim+1, obs_dim+2, obs_dim*2+2], axis=-1) 90 | writer.add_scalar('rollout/total_reward', r.sum(), step_i) 91 | 92 | r = discount_cumsum(r, discount=0.99) 93 | 94 | return obs, a, r, obs2, done 95 | 96 | from functools import partial 97 | 98 | # obs, a, r, obs2, done 99 | def policy_loss(params, obs, a, r): 100 | a_probs, v_s = pv_frwd(params, obs) 101 | a_dist = Categorical(a_probs) 102 | 103 | log_prob = a_dist.log_prob(a.astype(np.int32)) 104 
| advantage = jax.lax.stop_gradient(r - v_s) 105 | policy_loss = -(log_prob * advantage).sum() 106 | 107 | entropy_loss = -0.001 * a_dist.entropy() 108 | return policy_loss + entropy_loss 109 | 110 | def critic_loss(params, obs, r): 111 | _, v_s = pv_frwd(params, obs) 112 | return ((v_s - r) ** 2).sum() 113 | 114 | def a2c_loss(params, sample): 115 | obs, a, r, _, _ = sample 116 | ploss = policy_loss(params, obs, a, r) 117 | vloss = critic_loss(params, obs, r) 118 | loss = ploss + 0.25 * vloss 119 | return loss, ploss, vloss 120 | 121 | def batch_a2c_loss(params, samples): 122 | loss, ploss, vloss = jax.vmap(partial(a2c_loss, params))(samples) 123 | return loss.mean(), (ploss.mean(), vloss.mean()) 124 | 125 | @jax.jit 126 | def a2c_step(samples, params, opt_state): 127 | (loss, (ploss, vloss)), grad = jax.value_and_grad(batch_a2c_loss, has_aux=True)(params, samples) 128 | grad, opt_state = optim.update(grad, opt_state) 129 | params = optax.apply_updates(params, grad) 130 | return loss, ploss, vloss, opt_state, params, grad 131 | 132 | #%% 133 | from torch.utils.tensorboard import SummaryWriter 134 | writer = SummaryWriter(comment='a2c_test') 135 | 136 | #%% 137 | obs = env.reset() # dummy input 138 | a = np.zeros(env.action_space.shape) 139 | params = policy_value.init(rng, obs) 140 | 141 | optim = optax.chain( 142 | optax.clip_by_global_norm(0.5), 143 | optax.scale_by_adam(), 144 | optax.scale(-1e-3), 145 | ) 146 | 147 | opt_state = optim.init(params) 148 | 149 | # from tqdm.notebook import tqdm 150 | from tqdm import tqdm 151 | 152 | n_episodes = 1000 153 | for step_i in tqdm(range(n_episodes)): 154 | 155 | samples = rollout_v(step_i, params, env) 156 | loss, ploss, vloss, opt_state, params, grads = a2c_step(samples, params, opt_state) 157 | 158 | writer.add_scalar('loss/loss', loss.item(), step_i) 159 | writer.add_scalar('loss/policy', ploss.item(), step_i) 160 | writer.add_scalar('loss/critic', vloss.item(), step_i) 161 | 162 | # for i, g in enumerate(jax.tree_leaves(grads)): 163 | # name = 'b' if len(g.shape) == 1 else 'w' 164 | # writer.add_histogram(f'{name}_{i}_grad', onp.array(g), step_i) 165 | 166 | # #%% 167 | # #%% -------------------------------------------------------------------------------- /a2c/a2c_multi.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import gym 3 | import ray 4 | import jax 5 | import jax.numpy as np 6 | import numpy as onp 7 | import haiku as hk 8 | import optax 9 | from functools import partial 10 | 11 | from torch.utils.tensorboard import SummaryWriter 12 | from tqdm import tqdm 13 | import cloudpickle 14 | 15 | #%% 16 | ray.init() 17 | 18 | env_name = 'CartPole-v0' 19 | env = gym.make(env_name) 20 | 21 | n_actions = env.action_space.n 22 | obs = env.reset() 23 | obs_dim = obs.shape[0] 24 | 25 | print(f'[LOGGER] obs_dim: {obs_dim} n_actions: {n_actions}') 26 | 27 | def _policy_value(obs): 28 | pi = hk.Sequential([ 29 | hk.Linear(128), jax.nn.relu, 30 | hk.Linear(128), jax.nn.relu, 31 | hk.Linear(n_actions), jax.nn.softmax 32 | ])(obs) 33 | v = hk.Sequential([ 34 | hk.Linear(128), jax.nn.relu, 35 | hk.Linear(128), jax.nn.relu, 36 | hk.Linear(1), 37 | ])(obs) 38 | return pi, v 39 | 40 | policy_value = hk.transform(_policy_value) 41 | policy_value = hk.without_apply_rng(policy_value) 42 | pv_frwd = jax.jit(policy_value.apply) 43 | 44 | class Categorical: # similar to pytorch categorical 45 | def __init__(self, probs): 46 | self.probs = probs 47 | def sample(self): 48 | # 
https://stackoverflow.com/questions/46539431/np-random-choice-probabilities-do-not-sum-to-1 49 | p = onp.asanyarray(self.probs) 50 | p = p / p.sum() 51 | a = onp.random.choice(onp.arange(len(self.probs)), p=p) 52 | return a 53 | def log_prob(self, i): return np.log(self.probs[i]) 54 | def entropy(self): return -(self.probs * np.log(self.probs)).sum() 55 | 56 | @ray.remote 57 | class Worker: 58 | def __init__(self, gamma=0.99): 59 | self.env = gym.make(env_name) 60 | self.obs = self.env.reset() 61 | 62 | self.gamma = gamma 63 | # create jax policy fcn -- need to define in .remote due to pickling 64 | policy_value = hk.transform(_policy_value) 65 | policy_value = hk.without_apply_rng(policy_value) 66 | self.pv_frwd = jax.jit(policy_value.apply) # forward fcn 67 | 68 | def rollout(self, params, n_steps): 69 | # obs, obs2 + a, r, done, 70 | v_buffer = onp.zeros((n_steps, 2 * obs_dim + 3)) 71 | 72 | for i in range(n_steps): 73 | a_probs, _ = self.pv_frwd(params, self.obs) 74 | a = Categorical(a_probs).sample() # stochastic sample 75 | 76 | obs2, r, done, _ = self.env.step(a) 77 | v_buffer[i] = onp.array([*self.obs, a, r, *obs2, float(done)]) 78 | 79 | self.obs = obs2 80 | if done: 81 | self.obs = self.env.reset() 82 | break 83 | 84 | v_buffer = v_buffer[:i+1] 85 | obs, a, r, obs2, done = onp.split(v_buffer, [obs_dim, obs_dim+1, obs_dim+2, obs_dim*2+2], axis=-1) 86 | 87 | for i in range(len(r) - 1)[::-1]: 88 | r[i] = r[i] + self.gamma * r[i + 1] 89 | 90 | return obs, a, r, obs2, done 91 | 92 | #%% 93 | def eval(params, env): 94 | rewards = 0 95 | obs = env.reset() 96 | while True: 97 | a_probs, _ = pv_frwd(params, obs) 98 | a_dist = Categorical(a_probs) 99 | a = a_dist.sample() 100 | obs2, r, done, _ = env.step(a) 101 | obs = obs2 102 | rewards += r 103 | if done: break 104 | return rewards 105 | 106 | def policy_loss(params, obs, a, r): 107 | a_probs, v_s = pv_frwd(params, obs) 108 | a_dist = Categorical(a_probs) 109 | 110 | log_prob = a_dist.log_prob(a.astype(np.int32)) 111 | advantage = jax.lax.stop_gradient(r - v_s) 112 | policy_loss = -(log_prob * advantage).sum() 113 | 114 | entropy_loss = -0.001 * a_dist.entropy() 115 | return policy_loss + entropy_loss 116 | 117 | def critic_loss(params, obs, r): 118 | _, v_s = pv_frwd(params, obs) 119 | return ((v_s - r) ** 2).sum() 120 | 121 | def a2c_loss(params, sample): 122 | obs, a, r, _, _ = sample 123 | ploss = policy_loss(params, obs, a, r) 124 | vloss = critic_loss(params, obs, r) 125 | loss = ploss + 0.25 * vloss 126 | return loss, ploss, vloss 127 | 128 | def batch_a2c_loss(params, samples): 129 | loss, ploss, vloss = jax.vmap(partial(a2c_loss, params))(samples) 130 | return loss.mean(), (ploss.mean(), vloss.mean()) 131 | 132 | @jax.jit 133 | def a2c_step(samples, params, opt_state): 134 | (loss, (ploss, vloss)), grad = jax.value_and_grad(batch_a2c_loss, has_aux=True)(params, samples) 135 | grad, opt_state = optim.update(grad, opt_state) 136 | params = optax.apply_updates(params, grad) 137 | return loss, ploss, vloss, opt_state, params, grad 138 | 139 | #%% 140 | seed = onp.random.randint(1e5) 141 | rng = jax.random.PRNGKey(seed) 142 | 143 | import random 144 | onp.random.seed(seed) 145 | random.seed(seed) 146 | 147 | obs = env.reset() # dummy input 148 | a = np.zeros(env.action_space.shape) 149 | params = policy_value.init(rng, obs) 150 | 151 | optim = optax.chain( 152 | optax.clip_by_global_norm(0.5), 153 | optax.scale_by_adam(), 154 | optax.scale(-1e-3), 155 | ) 156 | opt_state = optim.init(params) 157 | 158 | n_envs = 16 159 | n_steps = 
20 160 | print(f'[LOGGER] using batchsize = {n_envs * n_steps}') 161 | 162 | workers = [Worker.remote() for _ in range(n_envs)] 163 | 164 | writer = SummaryWriter(comment=f'{env_name}_n-envs{n_envs}_seed{seed}') 165 | max_reward = -float('inf') 166 | 167 | import pathlib 168 | model_path = pathlib.Path(f'./models/a2c/{env_name}') 169 | model_path.mkdir(exist_ok=True, parents=True) 170 | 171 | for step_i in tqdm(range(1000)): 172 | rollouts = ray.get([worker.rollout.remote(params, n_steps) for worker in workers]) 173 | samples = jax.tree_multimap(lambda *a: np.concatenate(a), *rollouts, is_leaf=lambda node: hasattr(node, 'shape')) 174 | 175 | loss, ploss, vloss, opt_state, params, grads = a2c_step(samples, params, opt_state) 176 | writer.add_scalar('loss/policy', ploss.item(), step_i) 177 | writer.add_scalar('loss/critic', vloss.item(), step_i) 178 | writer.add_scalar('loss/total', loss.item(), step_i) 179 | writer.add_scalar('loss/batch_size', samples[0].shape[0], step_i) 180 | 181 | obs, a, r, done, _ = samples 182 | a_probs, v_s = jax.vmap(lambda o: pv_frwd(params, o))(obs) 183 | mean_entropy = jax.vmap(lambda p: Categorical(p).entropy())(a_probs).mean() 184 | mean_value = v_s.mean() 185 | writer.add_scalar('policy/mean_entropy', mean_entropy.item(), step_i) 186 | writer.add_scalar('critic/mean_value', mean_value.item(), step_i) 187 | 188 | eval_r = eval(params, env) 189 | writer.add_scalar('rollout/eval_reward', eval_r, step_i) 190 | 191 | if eval_r > max_reward: 192 | max_reward = eval_r 193 | model_save_path = str(model_path/f'params_{max_reward:.2f}') 194 | print(f'saving model to... {model_save_path}') 195 | with open(model_save_path, 'wb') as f: 196 | cloudpickle.dump(params, f) 197 | 198 | #%% 199 | #%% -------------------------------------------------------------------------------- /a2c/baseline.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import gym 3 | from stable_baselines3 import A2C, PPO 4 | from stable_baselines3.common.env_util import make_vec_env 5 | 6 | # Parallel environments 7 | env = make_vec_env('Pendulum-v0', n_envs=4) 8 | # model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./tmp/") 9 | model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./tmp/") 10 | 11 | model.learn(total_timesteps=1000*200) 12 | 13 | #%% 14 | #%% -------------------------------------------------------------------------------- /ddpg/ddpg_jax.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import haiku as hk 6 | 7 | import optax 8 | import gym 9 | import copy 10 | import matplotlib.pyplot as plt 11 | from torch.utils.tensorboard import SummaryWriter 12 | import cloudpickle 13 | 14 | # from tqdm.notebook import tqdm 15 | from tqdm import tqdm 16 | 17 | import collections 18 | import random 19 | from functools import partial 20 | 21 | import pybullet as p 22 | import pybullet_envs 23 | from numpngw import write_apng 24 | 25 | from jax.config import config 26 | config.update("jax_debug_nans", True) # break on nans 27 | 28 | #%% 29 | # env_name = 'AntBulletEnv-v0' 30 | # env_name = 'CartPoleContinuousBulletEnv-v0' 31 | # env_name = 'Pendulum-v0' ## hyperparams work for this env with correct seed 32 | # env_name = 'BipedalWalker-v3' 33 | env_name = 'HalfCheetahBulletEnv-v0' 34 | 35 | env = gym.make(env_name) 36 | n_actions = env.action_space.shape[0] 37 | obs_dim = env.observation_space.shape[0] 38 | 39 | a_high = 
env.action_space.high[0] 40 | a_low = env.action_space.low[0] 41 | assert -a_high == a_low 42 | 43 | #%% 44 | class FanIn_Uniform(hk.initializers.Initializer): 45 | def __call__(self, shape, dtype): 46 | bound = 1/(shape[0] ** .5) 47 | return hk.initializers.RandomUniform(-bound, bound)(shape, dtype) 48 | 49 | init_other = FanIn_Uniform() 50 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 51 | 52 | def middle_layer(d): 53 | return hk.Sequential([ 54 | hk.Linear(d, w_init=init_other), jax.nn.relu, 55 | ]) 56 | 57 | def _policy_fcn(s): 58 | policy = hk.Sequential([ 59 | middle_layer(256), 60 | middle_layer(256), 61 | hk.Linear(n_actions, w_init=init_final), 62 | jax.nn.tanh, 63 | ]) 64 | a = policy(s) * a_high # scale to action range 65 | return a 66 | 67 | def _q_fcn(s, a): 68 | z1 = middle_layer(256)(s) 69 | z1 = middle_layer(256)(z1) 70 | 71 | z2 = middle_layer(256)(a) 72 | z = np.concatenate([z1, z2]) 73 | z = middle_layer(256)(z) 74 | z = middle_layer(256)(z) 75 | q_sa = hk.Linear(1, w_init=init_final)(z) 76 | return q_sa 77 | 78 | #%% 79 | class ReplayBuffer(object): # clean code BUT extremely slow 80 | def __init__(self, capacity): 81 | self.buffer = collections.deque(maxlen=int(capacity)) 82 | 83 | def push(self, sample): 84 | self.buffer.append(sample) 85 | 86 | def sample(self, batch_size): 87 | samples = zip(*random.sample(self.buffer, batch_size)) 88 | samples = tuple(map(lambda x: np.stack(x).astype(np.float32), samples)) 89 | return samples 90 | 91 | def is_ready(self, batch_size): 92 | return batch_size <= len(self.buffer) 93 | 94 | class Vector_ReplayBuffer: 95 | def __init__(self, buffer_capacity): 96 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 97 | self.buffer_counter = 0 98 | # obs, obs2, a, r, done 99 | self.buffer = onp.zeros((buffer_capacity, 2 * obs_dim + n_actions + 2)) 100 | 101 | def push(self, sample): 102 | i = self.buffer_counter % self.buffer_capacity 103 | (obs, a, obs2, r, done) = sample 104 | self.buffer[i] = onp.array([*obs, *onp.array(a), onp.array(r), *obs2, float(done)]) 105 | self.buffer_counter += 1 106 | 107 | def sample(self, batch_size): 108 | record_range = min(self.buffer_counter, self.buffer_capacity) 109 | idxs = onp.random.choice(record_range, batch_size) 110 | batch = self.buffer[idxs] 111 | obs, a, r, obs2, done = onp.split(batch, [obs_dim, obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions], axis=-1) 112 | assert obs.shape[-1] == obs_dim and obs2.shape[-1] == obs_dim and a.shape[-1] == n_actions \ 113 | and r.shape[-1] == 1, (obs.shape, a.shape, r.shape, obs2.shape, r.shape, done.shape) 114 | return (obs, a, obs2, r, done) 115 | 116 | def is_ready(self, batch_size): 117 | return self.buffer_counter >= batch_size 118 | 119 | #%% 120 | def critic_loss(q_params, target_params, sample): 121 | p_params_t, q_params_t = target_params 122 | obs, a, obs2, r, done = sample 123 | # pred 124 | q = q_frwd(q_params, obs, a) 125 | # target 126 | a_t = p_frwd(p_params_t, obs2) 127 | q_t = q_frwd(q_params_t, obs2, a_t) 128 | y = r + (1 - done) * gamma * q_t 129 | y = jax.lax.stop_gradient(y) 130 | 131 | loss = (q - y) ** 2 132 | return loss 133 | 134 | def policy_loss(p_params, q_params, sample): 135 | obs, _, _, _, _ = sample 136 | a = p_frwd(p_params, obs) 137 | return -q_frwd(q_params, obs, a) 138 | 139 | def batch_critic_loss(q_params, target_params, batch): 140 | return jax.vmap(partial(critic_loss, q_params, target_params))(batch).mean() 141 | 142 | def batch_policy_loss(p_params, q_params, batch): 143 | return 
jax.vmap(partial(policy_loss, p_params, q_params))(batch).mean() 144 | 145 | def eval(p_params, env, name): 146 | rewards = 0 147 | imgs = [] 148 | obs = env.reset() 149 | while True: 150 | img = env.render(mode='rgb_array') 151 | imgs.append(img) 152 | 153 | a = p_frwd(p_params, obs) 154 | obs2, r, done, _ = env.step(a) 155 | obs = obs2 156 | rewards += r 157 | if done: break 158 | 159 | print(f'writing len {len(imgs)} total reward {rewards}...') 160 | write_apng(f'{name}_{rewards:.2f}.png', imgs, delay=20) 161 | 162 | return imgs, rewards 163 | 164 | #%% 165 | @jax.jit 166 | def ddpg_step(params, opt_states, batch): 167 | p_params, q_params, p_params_t, q_params_t = params 168 | p_opt_state, q_opt_state = opt_states 169 | 170 | # update q/critic 171 | target_params = (p_params_t, q_params_t) 172 | q_loss, q_grad = jax.value_and_grad(batch_critic_loss)(q_params, target_params, batch) 173 | q_grad, q_opt_state = q_optim.update(q_grad, q_opt_state) 174 | q_params = optax.apply_updates(q_params, q_grad) 175 | 176 | # update policy 177 | p_loss, p_grad = jax.value_and_grad(batch_policy_loss)(p_params, q_params, batch) 178 | p_grad, p_opt_state = p_optim.update(p_grad, p_opt_state) 179 | p_params = optax.apply_updates(p_params, p_grad) 180 | 181 | # slow update targets 182 | polask_avg = lambda target, w: (1 - tau) * target + tau * w 183 | p_params_t = jax.tree_multimap(polask_avg, p_params_t, p_params) 184 | q_params_t = jax.tree_multimap(polask_avg, q_params_t, q_params) 185 | 186 | params = (p_params, q_params, p_params_t, q_params_t) # re-pack with updated q 187 | opt_states = (p_opt_state, q_opt_state) 188 | losses = (p_loss, q_loss) 189 | grads = (p_grad, q_grad) 190 | return params, opt_states, losses, grads 191 | 192 | class OU_Noise: 193 | def __init__(self, shape): 194 | self.shape = shape 195 | self.theta = 0.15 196 | self.dt = 1e-2 197 | self.sigma = 0.2 198 | self.mean = onp.zeros(shape) 199 | self.reset() # set prev 200 | 201 | def sample(self): 202 | noise = onp.random.normal(size=self.shape) 203 | x = ( 204 | self.prev 205 | + self.theta * (self.mean - self.prev) * self.dt 206 | + self.sigma * onp.sqrt(self.dt) * noise 207 | ) 208 | self.prev = x 209 | return x 210 | 211 | def reset(self): self.prev = onp.zeros(self.shape) 212 | 213 | #%% 214 | n_episodes = 1000 215 | batch_size = 64 216 | buffer_size = 1e6 217 | gamma = 0.99 218 | tau = 0.005 # 1e-4 ## very important parameter -- make or break 219 | seed = onp.random.randint(1e5) # 420 works 220 | print(f'[LOGGING] seed={seed}') 221 | 222 | policy_lr = 1e-3 223 | q_lr = 2e-3 224 | 225 | # metric writer 226 | writer = SummaryWriter(comment=f'{env_name}_seed={seed}') 227 | 228 | rng = jax.random.PRNGKey(seed) 229 | onp.random.seed(seed) 230 | random.seed(seed) 231 | 232 | ## model defn 233 | # actor 234 | policy_fcn = hk.transform(_policy_fcn) 235 | policy_fcn = hk.without_apply_rng(policy_fcn) 236 | p_frwd = jax.jit(policy_fcn.apply) 237 | 238 | # critic 239 | q_fcn = hk.transform(_q_fcn) 240 | q_fcn = hk.without_apply_rng(q_fcn) 241 | q_frwd = jax.jit(q_fcn.apply) 242 | 243 | ## optimizers 244 | p_optim = optax.adam(policy_lr) 245 | q_optim = optax.adam(q_lr) 246 | 247 | #%% 248 | eps = 1 249 | eps_decay = 1/n_episodes 250 | 251 | vbuffer = Vector_ReplayBuffer(buffer_size) 252 | action_noise = OU_Noise((n_actions,)) 253 | 254 | # init models + optims 255 | obs = env.reset() # dummy input 256 | a = np.zeros(env.action_space.shape) 257 | p_params = policy_fcn.init(rng, obs) 258 | q_params = q_fcn.init(rng, obs, a) 259 | 260 | # 
target networks 261 | p_params_t = copy.deepcopy(p_params) 262 | q_params_t = copy.deepcopy(q_params) 263 | 264 | p_opt_state = p_optim.init(p_params) 265 | q_opt_state = q_optim.init(q_params) 266 | 267 | # bundle 268 | params = (p_params, q_params, p_params_t, q_params_t) 269 | opt_states = (p_opt_state, q_opt_state) 270 | 271 | import pathlib 272 | model_path = pathlib.Path(f'./models/ddpg/{env_name}') 273 | model_path.mkdir(exist_ok=True, parents=True) 274 | 275 | # #%% 276 | # step_i = 0 277 | # for epi_i in tqdm(range(n_episodes)): 278 | 279 | # action_noise.reset() 280 | # obs = env.reset() 281 | # rewards = [] 282 | # while True: 283 | # # rollout 284 | # p_params = params[0] 285 | # a = p_frwd(p_params, obs) + eps * action_noise.sample() 286 | # a = np.clip(a, a_low, a_high) 287 | 288 | # obs2, r, done, _ = env.step(a) 289 | # vbuffer.push((obs, a, obs2, r, done)) 290 | # obs = obs2 291 | # rewards.append(onp.asanyarray(r)) 292 | 293 | # # update 294 | # if not vbuffer.is_ready(batch_size): continue 295 | # batch = vbuffer.sample(batch_size) 296 | # params, opt_states, losses, grads = ddpg_step(params, opt_states, batch) 297 | 298 | # p_loss, q_loss = losses 299 | # writer.add_scalar('loss/policy', p_loss.item(), step_i) 300 | # writer.add_scalar('loss/q_fcn', q_loss.item(), step_i) 301 | 302 | # step_i += 1 303 | # if done: break 304 | 305 | # writer.add_scalar('rollout/total_reward', sum(rewards), epi_i) 306 | # writer.add_scalar('rollout/length', len(rewards), epi_i) 307 | 308 | # eps -= eps_decay # decay exploration 309 | 310 | # if epi_i == 0 or sum(rewards) > max_reward: 311 | # max_reward = sum(rewards) 312 | # with open(str(model_path/f'params_{max_reward:.2f}'), 'wb') as f: 313 | # cloudpickle.dump((p_params, q_params), f) 314 | 315 | ## loading and evaluating model 316 | # %% 317 | ppath = str(model_path/f'params_-2.11') 318 | ppath = 'models/ddpg_td3/params_854.89' 319 | with open(ppath, 'rb') as f: 320 | p_params, q_params = cloudpickle.load(f) 321 | 322 | eval(p_params, env, f'{env_name}_ddpg') 323 | 324 | # # %% 325 | # obs = env.reset() # dummy input 326 | # a = np.zeros(env.action_space.shape) 327 | # p_params = policy_fcn.init(rng, obs) 328 | # eval(p_params, env, f'{env_name}_initmodel') 329 | 330 | # %% 331 | -------------------------------------------------------------------------------- /ddpg/ddpg_jax_profile.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import haiku as hk 6 | import optax 7 | import gym 8 | import copy 9 | import matplotlib.pyplot as plt 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # from tqdm.notebook import tqdm 13 | from tqdm import tqdm 14 | 15 | import collections 16 | import random 17 | from functools import partial 18 | 19 | import pybullet as p 20 | import pybullet_envs 21 | from numpngw import write_apng 22 | 23 | from jax.config import config 24 | config.update("jax_debug_nans", True) # break on nans 25 | 26 | #%% 27 | # env_name = 'AntBulletEnv-v0' 28 | # env_name = 'CartPoleContinuousBulletEnv-v0' 29 | env_name = 'Pendulum-v0' ## hyperparams work for this env with correct seed 30 | # env_name = 'HalfCheetahBulletEnv-v0' 31 | 32 | env = gym.make(env_name) 33 | n_actions = env.action_space.shape[0] 34 | obs_dim = env.observation_space.shape[0] 35 | 36 | a_high = env.action_space.high[0] 37 | a_low = env.action_space.low[0] 38 | 39 | #%% 40 | class FanIn_Uniform(hk.initializers.Initializer): 41 | 
def __call__(self, shape, dtype): 42 | bound = 1/(shape[0] ** .5) 43 | return hk.initializers.RandomUniform(-bound, bound)(shape, dtype) 44 | 45 | init_other = FanIn_Uniform() 46 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 47 | 48 | def middle_layer(d): 49 | return hk.Sequential([ 50 | hk.Linear(d, w_init=init_other), jax.nn.relu, 51 | ]) 52 | 53 | def _policy_fcn(s): 54 | policy = hk.Sequential([ 55 | middle_layer(256), 56 | middle_layer(256), 57 | hk.Linear(n_actions, w_init=init_final), 58 | jax.nn.tanh, 59 | ]) 60 | a = policy(s) * a_high # scale to action range 61 | return a 62 | 63 | def _q_fcn(s, a): 64 | z1 = middle_layer(16)(s) 65 | z1 = middle_layer(32)(z1) 66 | 67 | z2 = middle_layer(32)(a) 68 | z = np.concatenate([z1, z2]) 69 | z = middle_layer(256)(z) 70 | z = middle_layer(256)(z) 71 | q_sa = hk.Linear(1, w_init=init_final)(z) 72 | return q_sa 73 | 74 | #%% 75 | class ReplayBuffer(object): # clean code BUT extremely slow 76 | def __init__(self, capacity): 77 | self.buffer = collections.deque(maxlen=int(capacity)) 78 | 79 | def push(self, sample): 80 | self.buffer.append(sample) 81 | 82 | def sample(self, batch_size): 83 | samples = zip(*random.sample(self.buffer, batch_size)) 84 | samples = tuple(map(lambda x: np.stack(x).astype(np.float32), samples)) 85 | return samples 86 | 87 | def is_ready(self, batch_size): 88 | return batch_size <= len(self.buffer) 89 | 90 | #%% 91 | class Vector_ReplayBuffer: 92 | def __init__(self, buffer_capacity): 93 | buffer_capacity = int(buffer_capacity) 94 | # Number of "experiences" to store at max 95 | self.buffer_capacity = buffer_capacity 96 | 97 | # Its tells us num of times record() was called. 98 | self.buffer_counter = 0 99 | 100 | # Instead of list of tuples as the exp.replay concept go 101 | # We use different np.arrays for each tuple element 102 | num_states = env.observation_space.shape[0] 103 | num_actions = env.action_space.shape[0] 104 | self.state_buffer = onp.zeros((self.buffer_capacity, num_states)) 105 | self.action_buffer = onp.zeros((self.buffer_capacity, num_actions)) 106 | self.reward_buffer = onp.zeros((self.buffer_capacity, 1)) 107 | self.dones = onp.zeros((self.buffer_capacity, 1)) 108 | self.next_state_buffer = onp.zeros((self.buffer_capacity, num_states)) 109 | self.buffers = [self.state_buffer, self.action_buffer, self.next_state_buffer, self.reward_buffer, self.dones] 110 | 111 | def push(self, obs_tuple): 112 | # (obs, a, obs2, r, done) 113 | index = self.buffer_counter % self.buffer_capacity 114 | 115 | self.state_buffer[index] = obs_tuple[0] 116 | self.action_buffer[index] = obs_tuple[1] 117 | self.next_state_buffer[index] = obs_tuple[2] 118 | self.reward_buffer[index] = obs_tuple[3] 119 | self.dones[index] = float(obs_tuple[4]) # dones T/F -> 1/0 120 | 121 | self.buffer_counter += 1 122 | 123 | def is_ready(self, batch_size): return self.buffer_counter >= batch_size 124 | 125 | def sample(self, batch_size): 126 | record_range = min(self.buffer_counter, self.buffer_capacity) 127 | batch_indices = onp.random.choice(record_range, batch_size) 128 | batch = tuple(b[batch_indices] for b in self.buffers) 129 | return batch 130 | 131 | #%% 132 | def critic_loss(q_params, target_params, sample): 133 | p_params_t, q_params_t = target_params 134 | obs, a, obs2, r, done = sample 135 | # pred 136 | q = q_frwd(q_params, obs, a) 137 | # target 138 | a_t = p_frwd(p_params_t, obs2) 139 | q_t = q_frwd(q_params_t, obs2, a_t) 140 | y = r + (1 - done) * gamma * q_t 141 | y = jax.lax.stop_gradient(y) 142 | 143 | loss = 
(q - y) ** 2 144 | return loss 145 | 146 | def policy_loss(p_params, q_params, sample): 147 | obs, _, _, _, _ = sample 148 | a = p_frwd(p_params, obs) 149 | return -q_frwd(q_params, obs, a) 150 | 151 | def batch_critic_loss(q_params, target_params, batch): 152 | return jax.vmap(partial(critic_loss, q_params, target_params))(batch).mean() 153 | 154 | def batch_policy_loss(p_params, q_params, batch): 155 | return jax.vmap(partial(policy_loss, p_params, q_params))(batch).mean() 156 | 157 | #%% 158 | @jax.jit 159 | def ddpg_step(params, opt_states, batch): 160 | p_params, q_params, p_params_t, q_params_t = params 161 | p_opt_state, q_opt_state = opt_states 162 | 163 | # update q/critic 164 | target_params = (p_params_t, q_params_t) 165 | q_loss, q_grad = jax.value_and_grad(batch_critic_loss)(q_params, target_params, batch) 166 | q_grad, q_opt_state = q_optim.update(q_grad, q_opt_state) 167 | q_params = optax.apply_updates(q_params, q_grad) 168 | 169 | # update policy 170 | p_loss, p_grad = jax.value_and_grad(batch_policy_loss)(p_params, q_params, batch) 171 | p_grad, p_opt_state = p_optim.update(p_grad, p_opt_state) 172 | p_params = optax.apply_updates(p_params, p_grad) 173 | 174 | # slow update targets 175 | polask_avg = lambda target, w: (1 - tau) * target + tau * w 176 | p_params_t = jax.tree_multimap(polask_avg, p_params_t, p_params) 177 | q_params_t = jax.tree_multimap(polask_avg, q_params_t, q_params) 178 | 179 | params = (p_params, q_params, p_params_t, q_params_t) # re-pack with updated q 180 | opt_states = (p_opt_state, q_opt_state) 181 | losses = (p_loss, q_loss) 182 | grads = (p_grad, q_grad) 183 | return params, opt_states, losses, grads 184 | 185 | class OU_Noise: 186 | def __init__(self, shape): 187 | self.shape = shape 188 | self.theta = 0.15 189 | self.dt = 1e-2 190 | self.sigma = 0.2 191 | self.mean = onp.zeros(shape) 192 | self.reset() # set prev 193 | 194 | def sample(self): 195 | noise = onp.random.normal(size=self.shape) 196 | x = ( 197 | self.prev 198 | + self.theta * (self.mean - self.prev) * self.dt 199 | + self.sigma * onp.sqrt(self.dt) * noise 200 | ) 201 | self.prev = x 202 | return x 203 | 204 | def reset(self): self.prev = onp.zeros(self.shape) 205 | 206 | #%% 207 | n_episodes = 100 208 | batch_size = 64 209 | buffer_size = 1e6 210 | gamma = 0.99 211 | tau = 0.005 # 1e-4 ## very important parameter -- make or break 212 | seed = 420 # 420 works 213 | 214 | policy_lr = 1e-3 215 | q_lr = 2e-3 216 | 217 | # metric writer 218 | writer = SummaryWriter(comment='ant') 219 | 220 | rng = jax.random.PRNGKey(seed) 221 | onp.random.seed(seed) 222 | random.seed(seed) 223 | 224 | ## model defn 225 | # actor 226 | policy_fcn = hk.transform(_policy_fcn) 227 | policy_fcn = hk.without_apply_rng(policy_fcn) 228 | p_frwd = jax.jit(policy_fcn.apply) 229 | 230 | # critic 231 | q_fcn = hk.transform(_q_fcn) 232 | q_fcn = hk.without_apply_rng(q_fcn) 233 | q_frwd = jax.jit(q_fcn.apply) 234 | 235 | ## optimizers 236 | p_optim = optax.adam(policy_lr) 237 | q_optim = optax.adam(q_lr) 238 | 239 | #%% 240 | eps = 1 241 | eps_decay = 1/n_episodes 242 | 243 | vbuffer = ReplayBuffer(buffer_size) 244 | action_noise = OU_Noise((n_actions,)) 245 | 246 | # init models + optims 247 | obs = env.reset() # dummy input 248 | a = np.zeros(env.action_space.shape) 249 | p_params = policy_fcn.init(rng, obs) 250 | q_params = q_fcn.init(rng, obs, a) 251 | 252 | # target networks 253 | p_params_t = copy.deepcopy(p_params) 254 | q_params_t = copy.deepcopy(q_params) 255 | 256 | p_opt_state = p_optim.init(p_params) 
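# note: p_params_t / q_params_t above start as deep copies of the live actor/critic;
# ddpg_step then blends them toward the live params on every update via Polyak
# averaging (polask_avg with factor tau). p_opt_state and q_opt_state are the
# separate Adam optimizer states for the actor and the critic.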
257 | q_opt_state = q_optim.init(q_params) 258 | 259 | # bundle 260 | params = (p_params, q_params, p_params_t, q_params_t) 261 | opt_states = (p_opt_state, q_opt_state) 262 | 263 | 264 | while True: 265 | # rollout 266 | p_params = params[0] 267 | a = p_frwd(p_params, obs) + eps * action_noise.sample() 268 | a = np.clip(a, a_low, a_high) 269 | obs2, r, done, _ = env.step(a) 270 | vbuffer.push((obs, a, obs2, r, done)) 271 | obs = obs2 272 | 273 | # update 274 | if not vbuffer.is_ready(batch_size): continue 275 | break 276 | 277 | with jax.profiler.trace('tmp/'): 278 | # rollout 279 | p_params = params[0] 280 | a = p_frwd(p_params, obs) + eps * action_noise.sample() 281 | a = np.clip(a, a_low, a_high) 282 | obs2, r, done, _ = env.step(a) 283 | vbuffer.push((obs, a, obs2, r, done)) 284 | obs = obs2 285 | 286 | batch = vbuffer.sample(batch_size) 287 | params, opt_states, losses, grads = ddpg_step(params, opt_states, batch) 288 | 289 | p_loss, q_loss = losses 290 | p_loss.block_until_ready() 291 | 292 | # %% 293 | # %% 294 | # %% 295 | -------------------------------------------------------------------------------- /ddpg/ddpg_param_tuning.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import haiku as hk 6 | import optax 7 | import gym 8 | import copy 9 | import matplotlib.pyplot as plt 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # from tqdm.notebook import tqdm 13 | from tqdm import tqdm 14 | 15 | import collections 16 | import random 17 | from functools import partial 18 | 19 | import pybullet as p 20 | import pybullet_envs 21 | from numpngw import write_apng 22 | 23 | from jax.config import config 24 | config.update("jax_debug_nans", True) # break on nans 25 | 26 | #%% 27 | env_name = 'CartPoleContinuousBulletEnv-v0' 28 | # env_name = 'Pendulum-v0' 29 | 30 | env = gym.make(env_name) 31 | n_actions = env.action_space.shape[0] 32 | obs_dim = env.observation_space.shape[0] 33 | 34 | a_high = env.action_space.high[0] 35 | a_low = env.action_space.low[0] 36 | 37 | #%% 38 | class FanIn_Uniform(hk.initializers.Initializer): 39 | def __call__(self, shape, dtype): 40 | bound = 1/(shape[0] ** .5) 41 | return hk.initializers.RandomUniform(-bound, bound)(shape, dtype) 42 | 43 | init_other = FanIn_Uniform() 44 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 45 | 46 | def middle_layer(d): 47 | return hk.Sequential([ 48 | hk.Linear(d, w_init=init_other), jax.nn.relu, 49 | ]) 50 | 51 | def _policy_fcn(s): 52 | policy = hk.Sequential([ 53 | middle_layer(256), 54 | middle_layer(256), 55 | hk.Linear(n_actions, w_init=init_final), 56 | jax.nn.tanh, 57 | ]) 58 | a = policy(s) * a_high # scale to action range 59 | return a 60 | 61 | def _q_fcn(s, a): 62 | z1 = middle_layer(16)(s) 63 | z1 = middle_layer(32)(s) 64 | 65 | z2 = middle_layer(32)(a) 66 | z = np.concatenate([z1, z2]) 67 | z = middle_layer(256)(z) 68 | z = middle_layer(256)(z) 69 | q_sa = hk.Linear(1, w_init=init_final)(z) 70 | return q_sa 71 | 72 | #%% 73 | class ReplayBuffer(object): # clean code BUT extremely slow 74 | def __init__(self, capacity): 75 | self.buffer = collections.deque(maxlen=int(capacity)) 76 | 77 | def push(self, sample): 78 | self.buffer.append(sample) 79 | 80 | def sample(self, batch_size): 81 | samples = zip(*random.sample(self.buffer, batch_size)) 82 | samples = tuple(map(lambda x: np.stack(x).astype(np.float32), samples)) 83 | return samples 84 | 85 | def is_ready(self, batch_size): 
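        # sampling is possible once the deque holds at least batch_size transitions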
86 | return batch_size <= len(self.buffer) 87 | 88 | #%% 89 | class Vector_ReplayBuffer: 90 | def __init__(self, buffer_capacity): 91 | buffer_capacity = int(buffer_capacity) 92 | # Number of "experiences" to store at max 93 | self.buffer_capacity = buffer_capacity 94 | 95 | # Its tells us num of times record() was called. 96 | self.buffer_counter = 0 97 | 98 | # Instead of list of tuples as the exp.replay concept go 99 | # We use different np.arrays for each tuple element 100 | num_states = env.observation_space.shape[0] 101 | num_actions = env.action_space.shape[0] 102 | self.state_buffer = onp.zeros((self.buffer_capacity, num_states)) 103 | self.action_buffer = onp.zeros((self.buffer_capacity, num_actions)) 104 | self.reward_buffer = onp.zeros((self.buffer_capacity, 1)) 105 | self.dones = onp.zeros((self.buffer_capacity, 1)) 106 | self.next_state_buffer = onp.zeros((self.buffer_capacity, num_states)) 107 | self.buffers = [self.state_buffer, self.action_buffer, self.next_state_buffer, self.reward_buffer, self.dones] 108 | 109 | def push(self, obs_tuple): 110 | # (obs, a, obs2, r, done) 111 | index = self.buffer_counter % self.buffer_capacity 112 | 113 | self.state_buffer[index] = obs_tuple[0] 114 | self.action_buffer[index] = obs_tuple[1] 115 | self.next_state_buffer[index] = obs_tuple[2] 116 | self.reward_buffer[index] = obs_tuple[3] 117 | self.dones[index] = float(obs_tuple[4]) # dones T/F -> 1/0 118 | 119 | self.buffer_counter += 1 120 | 121 | def is_ready(self, batch_size): return self.buffer_counter >= batch_size 122 | 123 | def sample(self, batch_size): 124 | record_range = min(self.buffer_counter, self.buffer_capacity) 125 | batch_indices = onp.random.choice(record_range, batch_size) 126 | batch = tuple(b[batch_indices] for b in self.buffers) 127 | return batch 128 | 129 | #%% 130 | def critic_loss(q_params, target_params, sample): 131 | p_params_t, q_params_t = target_params 132 | obs, a, obs2, r, done = sample 133 | # pred 134 | q = q_frwd(q_params, obs, a) 135 | # target 136 | a_t = p_frwd(p_params_t, obs2) 137 | q_t = q_frwd(q_params_t, obs2, a_t) 138 | y = r + (1 - done) * gamma * q_t 139 | y = jax.lax.stop_gradient(y) 140 | 141 | loss = (q - y) ** 2 142 | return loss 143 | 144 | def policy_loss(p_params, q_params, sample): 145 | obs, _, _, _, _ = sample 146 | a = p_frwd(p_params, obs) 147 | return -q_frwd(q_params, obs, a) 148 | 149 | def batch_critic_loss(q_params, target_params, batch): 150 | return jax.vmap(partial(critic_loss, q_params, target_params))(batch).mean() 151 | 152 | def batch_policy_loss(p_params, q_params, batch): 153 | return jax.vmap(partial(policy_loss, p_params, q_params))(batch).mean() 154 | 155 | #%% 156 | @jax.jit 157 | def ddpg_step(params, opt_states, batch): 158 | p_params, q_params, p_params_t, q_params_t = params 159 | p_opt_state, q_opt_state = opt_states 160 | 161 | # update q/critic 162 | target_params = (p_params_t, q_params_t) 163 | q_loss, q_grad = jax.value_and_grad(batch_critic_loss)(q_params, target_params, batch) 164 | q_grad, q_opt_state = q_optim.update(q_grad, q_opt_state) 165 | q_params = optax.apply_updates(q_params, q_grad) 166 | 167 | # update policy 168 | p_loss, p_grad = jax.value_and_grad(batch_policy_loss)(p_params, q_params, batch) 169 | p_grad, p_opt_state = p_optim.update(p_grad, p_opt_state) 170 | p_params = optax.apply_updates(p_params, p_grad) 171 | 172 | # slow update targets 173 | polask_avg = lambda target, w: (1 - tau) * target + tau * w 174 | p_params_t = jax.tree_multimap(polask_avg, p_params_t, p_params) 175 | 
q_params_t = jax.tree_multimap(polask_avg, q_params_t, q_params) 176 | 177 | params = (p_params, q_params, p_params_t, q_params_t) # re-pack with updated q 178 | opt_states = (p_opt_state, q_opt_state) 179 | losses = (p_loss, q_loss) 180 | grads = (p_grad, q_grad) 181 | return params, opt_states, losses, grads 182 | 183 | class OU_Noise: 184 | def __init__(self, shape): 185 | self.shape = shape 186 | self.theta = 0.15 187 | self.dt = 1e-2 188 | self.sigma = 0.2 189 | self.mean = onp.zeros(shape) 190 | self.reset() # set prev 191 | 192 | def sample(self): 193 | noise = onp.random.normal(size=self.shape) 194 | x = ( 195 | self.prev 196 | + self.theta * (self.mean - self.prev) * self.dt 197 | + self.sigma * onp.sqrt(self.dt) * noise 198 | ) 199 | self.prev = x 200 | return x 201 | 202 | def reset(self): self.prev = onp.zeros(self.shape) 203 | 204 | #%% 205 | import optuna 206 | 207 | def train(trial, study_name): 208 | n_episodes = 1000 209 | batch_size = trial.suggest_int("batch_size", 32, 256, step=32, log=False) 210 | buffer_size = 1e6 211 | seed = 100 # 420 works 212 | 213 | global tau, gamma 214 | gamma = 0.99 215 | # tau = 0.005 # 1e-4 ## very important parameter -- make or break 216 | tau = trial.suggest_float("tau", 1e-4, 1e-1, log=True) 217 | 218 | # policy_lr = 1e-3 219 | # q_lr = 2e-3 220 | policy_lr = trial.suggest_float("policy_lr", 1e-4, 1e-1, log=True) 221 | q_lr = trial.suggest_float("q_lr", 1e-4, 1e-1, log=True) 222 | 223 | # metric writer 224 | writer = SummaryWriter(comment=study_name) 225 | 226 | rng = jax.random.PRNGKey(seed) 227 | onp.random.seed(seed) 228 | random.seed(seed) 229 | 230 | ## model defn 231 | global p_frwd, q_frwd, p_optim, q_optim 232 | 233 | # actor 234 | policy_fcn = hk.transform(_policy_fcn) 235 | policy_fcn = hk.without_apply_rng(policy_fcn) 236 | p_frwd = jax.jit(policy_fcn.apply) 237 | 238 | # critic 239 | q_fcn = hk.transform(_q_fcn) 240 | q_fcn = hk.without_apply_rng(q_fcn) 241 | q_frwd = jax.jit(q_fcn.apply) 242 | 243 | ## optimizers 244 | p_optim = optax.adam(policy_lr) 245 | q_optim = optax.adam(q_lr) 246 | 247 | #%% 248 | eps = 1 249 | eps_decay = 1/n_episodes 250 | 251 | vbuffer = Vector_ReplayBuffer(buffer_size) 252 | action_noise = OU_Noise((n_actions,)) 253 | 254 | # init models + optims 255 | obs = env.reset() # dummy input 256 | a = np.zeros(env.action_space.shape) 257 | p_params = policy_fcn.init(rng, obs) 258 | q_params = q_fcn.init(rng, obs, a) 259 | 260 | # target networks 261 | p_params_t = copy.deepcopy(p_params) 262 | q_params_t = copy.deepcopy(q_params) 263 | 264 | p_opt_state = p_optim.init(p_params) 265 | q_opt_state = q_optim.init(q_params) 266 | 267 | # bundle 268 | params = (p_params, q_params, p_params_t, q_params_t) 269 | opt_states = (p_opt_state, q_opt_state) 270 | 271 | step_i = 0 272 | total_reward_sum = [] 273 | for epi_i in tqdm(range(n_episodes)): 274 | 275 | action_noise.reset() 276 | obs = env.reset() 277 | rewards = [] 278 | while True: 279 | # rollout 280 | p_params = params[0] 281 | a = p_frwd(p_params, obs) + eps * action_noise.sample() 282 | a = np.clip(a, a_low, a_high) 283 | obs2, r, done, _ = env.step(a) 284 | vbuffer.push((obs, a, obs2, r, done)) 285 | obs = obs2 286 | rewards.append(r) 287 | 288 | # update 289 | if not vbuffer.is_ready(batch_size): continue 290 | batch = vbuffer.sample(batch_size) 291 | params, opt_states, losses, _ = ddpg_step(params, opt_states, batch) 292 | 293 | p_loss, q_loss = losses 294 | writer.add_scalar('loss/policy', p_loss.item(), step_i) 295 | writer.add_scalar('loss/q_fcn', 
q_loss.item(), step_i) 296 | 297 | step_i += 1 298 | if done: break 299 | 300 | writer.add_scalar('rollout/total_reward', sum(rewards), epi_i) 301 | 302 | eps -= eps_decay # decay exploration 303 | total_reward_sum.append(sum(rewards)) 304 | 305 | reward_metric = sum(total_reward_sum[-10:]) 306 | 307 | trial.report(reward_metric, epi_i) 308 | if trial.should_prune(): 309 | raise optuna.TrialPruned() 310 | 311 | return reward_metric # thing to maximize 312 | 313 | import logging 314 | import sys 315 | 316 | #%% 317 | if __name__ == '__main__': 318 | study_name = 'ddpg_cartpole' 319 | 320 | optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) 321 | storage_name = "sqlite:///{}.db".format(study_name) 322 | pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=15) 323 | 324 | # note: will not be persistant without sqlDB storage 325 | study = optuna.create_study( 326 | direction="maximize", 327 | study_name=study_name, 328 | pruner=pruner, 329 | storage=storage_name, 330 | ) 331 | study.optimize(lambda trial: train(trial, study_name), n_trials=10) 332 | 333 | import pprint 334 | best_params = study.best_params 335 | pprint.pprint(best_params) 336 | -------------------------------------------------------------------------------- /demos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/.DS_Store -------------------------------------------------------------------------------- /demos/bipedal/BipedalWalker-v3_ddpg_4.08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/bipedal/BipedalWalker-v3_ddpg_4.08.png -------------------------------------------------------------------------------- /demos/bipedal/BipedalWalker-v3_td3_ddpg_310.10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/bipedal/BipedalWalker-v3_td3_ddpg_310.10.png -------------------------------------------------------------------------------- /demos/bipedal/BipedalWalker-v3_td3_ddpg_97.04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/bipedal/BipedalWalker-v3_td3_ddpg_97.04.png -------------------------------------------------------------------------------- /demos/bipedal/BipedalWalker-v3_td3_ddpg_97.28.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/bipedal/BipedalWalker-v3_td3_ddpg_97.28.mp4 -------------------------------------------------------------------------------- /demos/bipedal/td3_tensorflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/bipedal/td3_tensorflow.png -------------------------------------------------------------------------------- /demos/ddpg_ant/AntBulletEnv-v0_td3_ddpg_207.26.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_ant/AntBulletEnv-v0_td3_ddpg_207.26.mp4 -------------------------------------------------------------------------------- /demos/ddpg_ant/AntBulletEnv-v0_td3_ddpg_81.77.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_ant/AntBulletEnv-v0_td3_ddpg_81.77.mp4 -------------------------------------------------------------------------------- /demos/ddpg_cheeta/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/.DS_Store -------------------------------------------------------------------------------- /demos/ddpg_cheeta/Screen Shot 2021-07-23 at 8.33.21 AM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/Screen Shot 2021-07-23 at 8.33.21 AM.png -------------------------------------------------------------------------------- /demos/ddpg_cheeta/Screen Shot 2021-07-23 at 8.33.28 AM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/Screen Shot 2021-07-23 at 8.33.28 AM.png -------------------------------------------------------------------------------- /demos/ddpg_cheeta/backwards/HalfCheetahBulletEnv-v0_backwards_td3_ddpg_612.73.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/backwards/HalfCheetahBulletEnv-v0_backwards_td3_ddpg_612.73.mp4 -------------------------------------------------------------------------------- /demos/ddpg_cheeta/forwards/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/forwards/.DS_Store -------------------------------------------------------------------------------- /demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_167.27.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_167.27.mp4 -------------------------------------------------------------------------------- /demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_492.81.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_492.81.mp4 -------------------------------------------------------------------------------- /demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_94.63.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/ddpg_cheeta/forwards/HalfCheetahBulletEnv-v0_td3_ddpg_94.63.mp4 
-------------------------------------------------------------------------------- /demos/pendulum/Pendulum-v0_ddpg_-3.32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/pendulum/Pendulum-v0_ddpg_-3.32.png -------------------------------------------------------------------------------- /demos/pendulum/Pendulum-v0_initmodel_-1897.19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/pendulum/Pendulum-v0_initmodel_-1897.19.png -------------------------------------------------------------------------------- /demos/pong/PongNoFrameskip-v4_13.00_PPO.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/demos/pong/PongNoFrameskip-v4_13.00_PPO.mp4 -------------------------------------------------------------------------------- /maml/3maml.py: -------------------------------------------------------------------------------- 1 | # #%% 2 | # %load_ext autoreload 3 | # %autoreload 2 4 | 5 | import jax 6 | import jax.numpy as np 7 | import numpy as onp 8 | import optax 9 | import gym 10 | from functools import partial 11 | from env import Navigation2DEnv, Navigation2DEnv_Disc 12 | import cloudpickle 13 | import pathlib 14 | import haiku as hk 15 | from tqdm import tqdm 16 | 17 | from jax.config import config 18 | config.update("jax_enable_x64", True) 19 | config.update("jax_debug_nans", True) # break on nans 20 | 21 | #%% 22 | from utils import gaussian_log_prob, gaussian_sample 23 | from utils import cont_policy as policy 24 | from utils import eval, init_policy_fcn, Cont_Vector_Buffer, discount_cumsum, \ 25 | tree_mean, mean_vmap_jit, sum_vmap_jit, optim_update_fcn 26 | 27 | env_name = 'Navigation2D' 28 | env = Navigation2DEnv(max_n_steps=200) # maml debug env 29 | 30 | n_actions = env.action_space.shape[0] 31 | obs_dim = env.observation_space.shape[0] 32 | 33 | a_high = env.action_space.high[0] 34 | a_low = env.action_space.low[0] 35 | clip_range = [a_low, a_high] 36 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim} action_clip_range: {clip_range}') 37 | 38 | # value function / baseline 39 | # https://github.com/rll/rllab/blob/master/rllab/baselines/linear_feature_baseline.py 40 | def v_features(obs): 41 | o = np.clip(obs, -10, 10) 42 | l = len(o) 43 | al = np.arange(l).reshape(-1, 1) / 100.0 44 | return np.concatenate([o, o ** 2, al, al ** 2, al ** 3, np.ones((l, 1))], axis=1) 45 | 46 | def v_fit(trajectories, feature_fcn=v_features, reg_coeff=1e-5): 47 | featmat = np.concatenate([feature_fcn(traj['obs']) for traj in trajectories]) 48 | r = np.concatenate([traj['r'] for traj in trajectories]) 49 | for _ in range(5): 50 | # solve argmin_x (F x = R) <-- unsolvable (F non-sqr) 51 | # == argmin_x (F^T F x = F^T R) <-- solvable (sqr F^T F) 52 | # where F = Features, x = Weights, R = rewards 53 | _coeffs = np.linalg.lstsq( 54 | featmat.T.dot(featmat) + reg_coeff * np.identity(featmat.shape[1]), 55 | featmat.T.dot(r) 56 | )[0] 57 | if not np.any(np.isnan(_coeffs)): 58 | return _coeffs, 0 # succ 59 | reg_coeff *= 10 60 | return np.zeros_like(_coeffs), 1 # err 61 | 62 | def sample_trajectory(traj, p): 63 | traj_len = int(traj[0].shape[0] * p) 64 | idxs = onp.random.choice(traj_len, size=traj_len, 
replace=False) 65 | sampled_traj = jax.tree_map(lambda x: x[idxs], traj) 66 | return sampled_traj 67 | 68 | def rollout(env, p_params, rng): 69 | buffer = Cont_Vector_Buffer(n_actions, obs_dim, max_n_steps) 70 | obs = env.reset() 71 | for _ in range(max_n_steps): 72 | rng, subkey = jax.random.split(rng, 2) 73 | a, log_prob = policy(p_frwd, p_params, obs, subkey, clip_range, False) 74 | 75 | a = jax.lax.stop_gradient(a) 76 | log_prob = jax.lax.stop_gradient(log_prob) 77 | a = onp.array(a) 78 | 79 | obs2, r, done, _ = env.step(a) 80 | buffer.push((obs, a, r, obs2, done, log_prob)) 81 | 82 | obs = obs2 83 | if done: break 84 | 85 | trajectory = buffer.contents() 86 | return trajectory 87 | 88 | #%% 89 | # inner optim 90 | @jax.jit 91 | def _reinforce_loss(p_params, obs, a, adv): 92 | mu, std = p_frwd(p_params, obs) 93 | log_prob = gaussian_log_prob(a, mu, std) 94 | loss = -(log_prob * adv).sum() 95 | return loss 96 | 97 | reinforce_loss = sum_vmap_jit(_reinforce_loss, (None, 0, 0, 0)) 98 | reinforce_loss_grad = jax.jit(jax.value_and_grad(reinforce_loss)) 99 | 100 | @jax.jit 101 | def _ppo_loss(p_params, obs, a, adv, old_log_prob): 102 | ## policy losses 103 | mu, std = p_frwd(p_params, obs) 104 | 105 | # policy gradient 106 | log_prob = gaussian_log_prob(a, mu, std) 107 | 108 | approx_kl = (old_log_prob - log_prob).sum() 109 | ratio = np.exp(log_prob - old_log_prob) 110 | p_loss1 = ratio * adv 111 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * adv 112 | policy_loss = -np.fmin(p_loss1, p_loss2).sum() 113 | 114 | clipped_mask = ((ratio > 1+eps) | (ratio < 1-eps)).astype(np.float32) 115 | clip_frac = clipped_mask.mean() 116 | 117 | loss = policy_loss 118 | info = dict(ploss=policy_loss, approx_kl=approx_kl, cf=clip_frac) 119 | 120 | return loss, info 121 | 122 | ppo_loss = mean_vmap_jit(_ppo_loss, (None, 0, 0, 0, 0)) 123 | 124 | @jax.jit 125 | def sgd_step_int(params, grads, alpha): 126 | sgd_update = lambda param, grad: param - alpha * grad 127 | return jax.tree_multimap(sgd_update, params, grads) 128 | 129 | @jax.jit 130 | def sgd_step_tree(params, grads, alphas): 131 | sgd_update = lambda param, grad, alpha: param - alpha * grad 132 | return jax.tree_multimap(sgd_update, params, grads, alphas) 133 | 134 | def sgd_step(params, grads, alpha): 135 | step_fcn = sgd_step_int if type(alpha) in [int, float] else sgd_step_tree 136 | return step_fcn(params, grads, alpha) 137 | 138 | #%% 139 | seed = onp.random.randint(1e5) 140 | epochs = 500 141 | eval_every = 1 142 | max_n_steps = 100 # env._max_episode_steps 143 | ## PPO 144 | eps = 0.2 145 | gamma = 0.99 146 | lmbda = 0.95 147 | lr = 1e-3 148 | ## MAML 149 | task_batch_size = 40 150 | train_n_traj = 20 151 | eval_n_traj = 40 152 | alpha = 0.1 153 | damp_lambda = 0.01 154 | 155 | rng = jax.random.PRNGKey(seed) 156 | p_frwd, p_params = init_policy_fcn('continuous', env, rng) 157 | 158 | p_update_fcn, p_opt_state = optim_update_fcn(optax.adam(lr), p_params) 159 | 160 | #%% 161 | task = env.sample_tasks(1)[0] 162 | env.reset_task(task) 163 | 164 | #%% 165 | @jax.jit 166 | def compute_advantage(W, traj): 167 | # linear fcn predict 168 | v_obs = v_features(traj['obs']) @ W 169 | # baseline 170 | adv = traj['r'] - v_obs 171 | # normalize 172 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 173 | return adv.squeeze() 174 | 175 | def maml_inner(p_params, env, rng, n_traj, alpha): 176 | subkeys = jax.random.split(rng, n_traj) 177 | trajectories = [] 178 | for i in range(n_traj): 179 | traj = rollout(env, p_params, subkeys[i]) 180 | traj['r'] = 
discount_cumsum(traj['r'], discount=gamma) 181 | trajectories.append(traj) 182 | 183 | W = v_fit(trajectories)[0] 184 | for i in range(len(trajectories)): 185 | trajectories[i]['adv'] = compute_advantage(W, trajectories[i]) 186 | 187 | gradients = [] 188 | for traj in trajectories: 189 | _, grad = reinforce_loss_grad(p_params, traj['obs'], traj['a'], traj['adv']) 190 | gradients.append(grad) 191 | grads = jax.tree_multimap(lambda *x: np.stack(x).sum(0), *gradients) 192 | inner_params_p = sgd_step(p_params, grads, alpha) 193 | 194 | return inner_params_p, W 195 | 196 | def maml_outer(p_params, env, rng): 197 | subkeys = jax.random.split(rng, 3) 198 | inner_p_params, W = maml_inner(p_params, env, subkeys[0], train_n_traj, alpha) 199 | 200 | traj = rollout(env, inner_p_params, subkeys[1]) 201 | traj['adv'] = compute_advantage(W, traj) 202 | loss, info = ppo_loss(inner_p_params, traj['obs'], traj['a'], traj['adv'], traj['log_prob']) 203 | return loss, (info, traj) 204 | 205 | def maml_eval(env, p_params, rng, n_steps=1): 206 | rewards = [] 207 | 208 | rng, subkey = jax.random.split(rng, 2) 209 | reward_0step = eval(p_frwd, policy, p_params, env, subkey, clip_range, True) 210 | rewards.append(reward_0step) 211 | 212 | eval_alpha = alpha 213 | for _ in range(n_steps): 214 | rng, *subkeys = jax.random.split(rng, 3) 215 | 216 | inner_p_params, _ = maml_inner(p_params, env, subkeys[0], eval_n_traj, eval_alpha) 217 | r = eval(p_frwd, policy, inner_p_params, env, subkeys[1], clip_range, True) 218 | 219 | rewards.append(r) 220 | eval_alpha = alpha / 2 221 | p_params = inner_p_params 222 | 223 | return rewards 224 | 225 | #%% 226 | env.seed(0) 227 | n_tasks = 2 228 | task = env.sample_tasks(1)[0] ## only two tasks 229 | assert n_tasks in [1, 2] 230 | if n_tasks == 1: 231 | tasks = [task] * task_batch_size 232 | elif n_tasks == 2: 233 | task2 = {'goal': -task['goal'].copy()} 234 | tasks = [task, task2] * (task_batch_size//2) 235 | 236 | for task in tasks[:n_tasks]: 237 | env.reset_task(task) 238 | # log max reward 239 | goal = env._task['goal'] 240 | reward = 0 241 | step_count = 0 242 | obs = env.reset() 243 | while True: 244 | a = goal - obs 245 | obs2, r, done, _ = env.step(a) 246 | reward += r 247 | step_count += 1 248 | if done: break 249 | obs = obs2 250 | print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 251 | 252 | #%% 253 | from torch.utils.tensorboard import SummaryWriter 254 | writer = SummaryWriter(comment=f'maml_{n_tasks}task_test_seed={seed}') 255 | 256 | #%% 257 | maml_grad = jax.value_and_grad(maml_outer, has_aux=True) 258 | 259 | step_count = 0 260 | for e in tqdm(range(1, epochs+1)): 261 | # training 262 | # tasks = env.sample_tasks(task_batch_size) 263 | 264 | gradients = [] 265 | mean_loss = 0 266 | for task_i, task in enumerate(tqdm(tasks)): 267 | env.reset_task(task) 268 | rng, subkey = jax.random.split(rng, 2) 269 | (loss, (info, traj)), grads = maml_grad(p_params, env, rng) 270 | 271 | gradients.append(grads) 272 | mean_loss += loss 273 | step_count += 1 274 | 275 | mean_loss /= len(tasks) 276 | writer.add_scalar(f'loss/mean_task_loss', mean_loss.item(), step_count) 277 | 278 | # update 279 | gradients = jax.tree_multimap(lambda *x: np.stack(x).mean(0), *gradients) 280 | p_params, p_opt_state = p_update_fcn(p_params, p_opt_state, gradients) 281 | 282 | # eval 283 | if e % eval_every == 0: 284 | eval_tasks = tasks[:n_tasks] 285 | task_rewards = [] 286 | for task_i, eval_task in enumerate(eval_tasks): 287 | env.reset_task(eval_task) 288 | 289 | rng, subkey = 
jax.random.split(rng, 2) 290 | rewards = maml_eval(env, p_params, subkey, n_steps=3) 291 | task_rewards.append(rewards) 292 | 293 | for step_i, r in enumerate(rewards): 294 | writer.add_scalar(f'task{task_i}/reward_{step_i}step', r, e) 295 | 296 | mean_rewards=[] 297 | for step_i in range(len(task_rewards[0])): 298 | mean_r = sum([task_rewards[j][step_i] for j in range(len(task_rewards))]) / 2 299 | writer.add_scalar(f'mean_task/reward_{step_i}step', mean_r, e) 300 | mean_rewards.append(mean_r) 301 | 302 | 303 | #%% 304 | #%% -------------------------------------------------------------------------------- /maml/4maml.py: -------------------------------------------------------------------------------- 1 | #%% 2 | # maml: 3 | # step1: 4 | # sample train episodes 5 | # update model 6 | # sample test episodes 7 | # return (train, test), new_params 8 | 9 | # step2: 10 | # compute loss0 with params 11 | # task_i: update model on new_params, compute loss 12 | # compute mean over all tasks 13 | # compute gradients on mean_loss_tasks 14 | # TRPO on gradients 15 | 16 | #%% 17 | # %load_ext autoreload 18 | # %autoreload 2 19 | 20 | import jax 21 | import jax.numpy as np 22 | import numpy as onp 23 | import optax 24 | import gym 25 | from functools import partial 26 | from env import Navigation2DEnv, Navigation2DEnv_Disc 27 | import cloudpickle 28 | import pathlib 29 | import haiku as hk 30 | from tqdm import tqdm 31 | 32 | from jax.config import config 33 | config.update("jax_enable_x64", True) 34 | config.update("jax_debug_nans", True) # break on nans 35 | 36 | #%% 37 | from utils import gaussian_log_prob, gaussian_sample 38 | from utils import cont_policy as policy 39 | from utils import eval, init_policy_fcn, Cont_Vector_Buffer, discount_cumsum, \ 40 | tree_mean, mean_vmap_jit, sum_vmap_jit, optim_update_fcn 41 | 42 | env_name = 'Navigation2D' 43 | env = Navigation2DEnv(max_n_steps=200) # maml debug env 44 | 45 | n_actions = env.action_space.shape[0] 46 | obs_dim = env.observation_space.shape[0] 47 | 48 | a_high = env.action_space.high[0] 49 | a_low = env.action_space.low[0] 50 | clip_range = [a_low, a_high] 51 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim} action_clip_range: {clip_range}') 52 | 53 | # value function / baseline 54 | # https://github.com/rll/rllab/blob/master/rllab/baselines/linear_feature_baseline.py 55 | def v_features(obs): 56 | o = np.clip(obs, -10, 10) 57 | l = len(o) 58 | al = np.arange(l).reshape(-1, 1) / 100.0 59 | return np.concatenate([o, o ** 2, al, al ** 2, al ** 3, np.ones((l, 1))], axis=1) 60 | 61 | def v_fit(featmat, r, reg_coeff=1e-5): 62 | for _ in range(5): 63 | # solve argmin_x (F x = R) <-- unsolvable (F non-sqr) 64 | # == argmin_x (F^T F x = F^T R) <-- solvable (sqr F^T F) 65 | # where F = Features, x = Weights, R = rewards 66 | _coeffs = np.linalg.lstsq( 67 | featmat.T.dot(featmat) + reg_coeff * np.identity(featmat.shape[1]), 68 | featmat.T.dot(r) 69 | )[0] 70 | if not np.any(np.isnan(_coeffs)): 71 | return _coeffs, 0 # succ 72 | reg_coeff *= 10 73 | return np.zeros_like(_coeffs), 1 # err 74 | 75 | def rollout(env, p_params, rng): 76 | buffer = Cont_Vector_Buffer(n_actions, obs_dim, max_n_steps) 77 | obs = env.reset() 78 | for _ in range(max_n_steps): 79 | rng, subkey = jax.random.split(rng, 2) 80 | a, log_prob = policy(p_frwd, p_params, obs, subkey, clip_range, False) 81 | 82 | a = jax.lax.stop_gradient(a) 83 | log_prob = jax.lax.stop_gradient(log_prob) 84 | a = onp.array(a) 85 | 86 | obs2, r, done, _ = env.step(a) 87 | buffer.push((obs, a, r, 
obs2, done, log_prob)) 88 | 89 | obs = obs2 90 | if done: break 91 | 92 | trajectory = buffer.contents() 93 | return trajectory 94 | 95 | # inner optim 96 | @jax.jit 97 | def _reinforce_loss(p_params, obs, a, adv): 98 | mu, std = p_frwd(p_params, obs) 99 | log_prob = gaussian_log_prob(a, mu, std) 100 | loss = -(log_prob * adv).sum() 101 | return loss 102 | 103 | reinforce_loss = sum_vmap_jit(_reinforce_loss, (None, 0, 0, 0)) 104 | reinforce_loss_grad = jax.jit(jax.value_and_grad(reinforce_loss)) 105 | 106 | @jax.jit 107 | def _ppo_loss(p_params, obs, a, adv, old_log_prob): 108 | ## policy losses 109 | mu, std = p_frwd(p_params, obs) 110 | 111 | # policy gradient 112 | log_prob = gaussian_log_prob(a, mu, std) 113 | 114 | approx_kl = (old_log_prob - log_prob).sum() 115 | ratio = np.exp(log_prob - old_log_prob) 116 | p_loss1 = ratio * adv 117 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * adv 118 | policy_loss = -np.fmin(p_loss1, p_loss2).sum() 119 | 120 | clipped_mask = ((ratio > 1+eps) | (ratio < 1-eps)).astype(np.float32) 121 | clip_frac = clipped_mask.mean() 122 | 123 | loss = policy_loss 124 | info = dict(ploss=policy_loss, approx_kl=approx_kl, cf=clip_frac) 125 | 126 | return loss, info 127 | 128 | ppo_loss = mean_vmap_jit(_ppo_loss, (None, 0, 0, 0, 0)) 129 | 130 | @jax.jit 131 | def sgd_step_int(params, grads, alpha): 132 | sgd_update = lambda param, grad: param - alpha * grad 133 | return jax.tree_multimap(sgd_update, params, grads) 134 | 135 | @jax.jit 136 | def sgd_step_tree(params, grads, alphas): 137 | sgd_update = lambda param, grad, alpha: param - alpha * grad 138 | return jax.tree_multimap(sgd_update, params, grads, alphas) 139 | 140 | def sgd_step(params, grads, alpha): 141 | step_fcn = sgd_step_int if type(alpha) in [int, float] else sgd_step_tree 142 | return step_fcn(params, grads, alpha) 143 | 144 | #%% 145 | seed = onp.random.randint(1e5) 146 | epochs = 500 147 | eval_every = 1 148 | max_n_steps = 100 # env._max_episode_steps 149 | ## PPO 150 | eps = 0.2 151 | gamma = 0.99 152 | lmbda = 0.95 153 | lr = 1e-3 154 | ## MAML 155 | task_batch_size = 40 156 | train_n_traj = 20 157 | eval_n_traj = 40 158 | alpha = 0.1 159 | damp_lambda = 0.01 160 | 161 | rng = jax.random.PRNGKey(seed) 162 | p_frwd, p_params = init_policy_fcn('continuous', env, rng) 163 | 164 | p_update_fcn, p_opt_state = optim_update_fcn(optax.adam(lr), p_params) 165 | 166 | #%% 167 | @jax.jit 168 | def compute_advantage(W, obs, r): 169 | # linear fcn predict 170 | v_obs = v_features(obs) @ W 171 | # baseline 172 | adv = r - v_obs 173 | # normalize 174 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 175 | return adv.squeeze() 176 | 177 | def n_rollouts(p_params, env, rng, n_traj): 178 | subkeys = jax.random.split(rng, n_traj) 179 | trajectories = [] 180 | for i in range(n_traj): 181 | traj = rollout(env, p_params, subkeys[i]) 182 | (obs, a, r, obs2, done, log_prob) = traj 183 | r = discount_cumsum(r, discount=gamma) 184 | traj = (obs, a, r, obs2, done, log_prob) 185 | trajectories.append(traj) 186 | return trajectories 187 | 188 | @jax.jit 189 | def reinforce_step(traj, p_params, alpha): 190 | (obs, a, adv) = traj 191 | _, grads = reinforce_loss_grad(p_params, obs, a, adv) 192 | inner_params_p = sgd_step_int(p_params, grads, alpha) 193 | return inner_params_p 194 | 195 | def maml_inner(p_params, env, rng, alpha): 196 | # rollout + post process (train) 197 | train_trajectories = n_rollouts(p_params, env, rng, train_n_traj) 198 | featmat = np.concatenate([v_features(traj[0]) for traj in train_trajectories]) 199 | 
train_trajectories = jax.tree_multimap(lambda *x: np.concatenate(x, 0), *train_trajectories) 200 | (obs, a, r, _, _, log_prob) = train_trajectories 201 | W = v_fit(featmat, r)[0] 202 | adv = compute_advantage(W, obs, r) 203 | train_trajectories = (obs, a, adv, log_prob) 204 | r0 = r # metrics 205 | 206 | # compute gradient + step 207 | inner_params_p = reinforce_step(train_trajectories[:-1], p_params, alpha) 208 | 209 | # rollout again (test) 210 | test_trajectories = n_rollouts(inner_params_p, env, rng, eval_n_traj) 211 | featmat = np.concatenate([v_features(traj[0]) for traj in test_trajectories]) 212 | test_trajectories = jax.tree_multimap(lambda *x: np.concatenate(x, 0), *test_trajectories) 213 | (obs, a, r, _, _, log_prob) = test_trajectories 214 | Wtest = v_fit(featmat, r)[0] 215 | adv = compute_advantage(Wtest, obs, r) 216 | test_trajectories = (obs, a, adv, log_prob) 217 | 218 | return (train_trajectories, test_trajectories), (r0, r) 219 | 220 | @jax.jit 221 | def maml_outter(p_params, inner): 222 | train_trajectories, test_trajectories = inner 223 | # step on train 224 | inner_params_p = reinforce_step(train_trajectories[:-1], p_params, alpha) 225 | # loss on test 226 | loss, _ = ppo_loss(inner_params_p, *test_trajectories) 227 | return loss 228 | 229 | env.seed(0) 230 | n_tasks = 1 231 | task = env.sample_tasks(1)[0] ## only two tasks 232 | assert n_tasks in [1, 2] 233 | if n_tasks == 1: 234 | tasks = [task] * task_batch_size 235 | elif n_tasks == 2: 236 | task2 = {'goal': -task['goal'].copy()} 237 | tasks = [task, task2] * (task_batch_size//2) 238 | 239 | for task in tasks[:n_tasks]: 240 | env.reset_task(task) 241 | # log max reward 242 | goal = env._task['goal'] 243 | reward = 0 244 | step_count = 0 245 | obs = env.reset() 246 | while True: 247 | a = goal - obs 248 | obs2, r, done, _ = env.step(a) 249 | reward += r 250 | step_count += 1 251 | if done: break 252 | obs = obs2 253 | print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 254 | 255 | #%% 256 | from torch.utils.tensorboard import SummaryWriter 257 | writer = SummaryWriter(comment=f'maml4_{n_tasks}task_test_seed={seed}') 258 | 259 | #%% 260 | maml_grads = jax.jit(jax.value_and_grad(maml_outter)) 261 | 262 | # tasks = env.sample_tasks(2) 263 | step_count = 0 264 | for e in tqdm(range(1, epochs+1)): 265 | 266 | inners = [] 267 | reward_step0 = [] 268 | reward_step1 = [] 269 | for i in tqdm(range(len(tasks))): 270 | env.reset_task(tasks[i]) 271 | inner, (r0, r1) = maml_inner(p_params, env, rng, 0.1) 272 | inners.append(inner) 273 | r0, r1 = r0.sum().item(), r1.sum().item() 274 | reward_step0.append(r0) 275 | reward_step1.append(r1) 276 | 277 | writer.add_scalar(f'task{i}/reward_0step', r0, e) 278 | writer.add_scalar(f'task{i}/reward_1step', r1, e) 279 | 280 | print(f'step0 mean reward: {onp.mean(reward_step0)} ({reward_step0})' ) 281 | print(f'step1 mean reward: {onp.mean(reward_step1)} ({reward_step1})' ) 282 | 283 | writer.add_scalar(f'mean_task/reward_0step', onp.mean(reward_step0), e) 284 | writer.add_scalar(f'mean_task/reward_1step', onp.mean(reward_step1), e) 285 | 286 | grads = [maml_grads(p_params, innr)[1] for innr in inners] 287 | grads = jax.tree_multimap(lambda *g: np.stack(g, 0).mean(0), *grads) 288 | p_params, p_opt_state = p_update_fcn(p_params, p_opt_state, grads) 289 | 290 | #%% 291 | #%% 292 | 293 | 294 | #%% 295 | #%% 296 | #%% 297 | #%% 298 | #%% 299 | #%% -------------------------------------------------------------------------------- /maml/env.py: 
-------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import gym 4 | 5 | from gym import spaces 6 | from gym.utils import seeding 7 | 8 | class Navigation2DEnv(gym.Env): 9 | """2D navigation problems, as described in [1]. The code is adapted from 10 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/maml_examples/point_env_randgoal.py 11 | At each time step, the 2D agent takes an action (its velocity, clipped in 12 | [-0.1, 0.1]), and receives a penalty equal to its L2 distance to the goal 13 | position (ie. the reward is `-distance`). The 2D navigation tasks are 14 | generated by sampling goal positions from the uniform distribution 15 | on [-0.5, 0.5]^2. 16 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 17 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 18 | (https://arxiv.org/abs/1703.03400) 19 | """ 20 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 21 | super().__init__() 22 | self.low = low 23 | self.high = high 24 | self.max_n_steps = max_n_steps 25 | self._step_count = 0 26 | 27 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 28 | shape=(2,), dtype=np.float32) 29 | self.action_space = spaces.Box(low=-0.1, high=0.1, 30 | shape=(2,), dtype=np.float32) 31 | 32 | self._task = task 33 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 34 | self._state = np.zeros(2, dtype=np.float32) 35 | self.seed() 36 | 37 | def seed(self, seed=None): 38 | self.np_random, seed = seeding.np_random(seed) 39 | return [seed] 40 | 41 | def sample_tasks(self, num_tasks): 42 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 43 | tasks = [{'goal': goal} for goal in goals] 44 | return tasks 45 | 46 | def reset_task(self, task): 47 | self._task = task 48 | self._goal = task['goal'] 49 | 50 | def reset(self): 51 | self._step_count = 0 52 | self._state = np.zeros(2, dtype=np.float32) 53 | return self._state 54 | 55 | def step(self, action): 56 | action = np.clip(action, -0.1, 0.1) 57 | action = np.array(action) 58 | assert self.action_space.contains(action) 59 | self._state = self._state + action 60 | 61 | diff = self._state - self._goal 62 | reward = -np.sqrt((diff**2).sum()) 63 | done = (np.abs(diff) < 0.01).sum() == 2 64 | 65 | done = done or self._step_count >= self.max_n_steps 66 | self._step_count += 1 67 | 68 | return self._state, reward, done, {'task': self._task} 69 | 70 | class Navigation2DEnv_Disc(gym.Env): 71 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 72 | super().__init__() 73 | self.low = low 74 | self.high = high 75 | self.max_n_steps = max_n_steps 76 | self._step_count = 0 77 | 78 | self.observation_space = spaces.Box(low=self.low, high=self.high, 79 | shape=(2,), dtype=np.float32) 80 | self.action_space = spaces.Discrete(4) # left, right, up, down 81 | 82 | self._task = task 83 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 84 | self._state = np.zeros(2, dtype=np.float32) 85 | self.seed() 86 | 87 | def seed(self, seed=None): 88 | self.np_random, seed = seeding.np_random(seed) 89 | return [seed] 90 | 91 | def sample_tasks(self, num_tasks): 92 | while True: 93 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 94 | if not (goals.sum(0) == 0).any(): break 95 | 96 | goals = np.round_(goals, 1) # discrete them to 0.1 steps 97 | tasks = [{'goal': goal} for goal in goals] 98 | return tasks 99 | 100 | def reset_task(self, task): 101 | self._task = 
task 102 | self._goal = task['goal'] 103 | 104 | def reset(self): 105 | self._step_count = 0 106 | self._state = np.zeros(2, dtype=np.float32) 107 | return self._state 108 | 109 | def step(self, action): 110 | assert self.action_space.contains(action) 111 | # up down left right 112 | step = np.array({ 113 | 0: [0, 0.1], 114 | 1: [0, -0.1], 115 | 2: [-0.1, 0], 116 | 3: [0.1, 0] 117 | }[action]) 118 | 119 | self._state = self._state + step 120 | self._state = np.clip(self._state, self.low, self.high) 121 | 122 | diff = self._state - self._goal 123 | reward = -np.sqrt((diff**2).sum()) 124 | done = (np.abs(diff) < 0.01).sum() == 2 125 | 126 | done = done or self._step_count >= self.max_n_steps 127 | self._step_count += 1 128 | 129 | return self._state, reward, done, {'task': self._task} 130 | -------------------------------------------------------------------------------- /maml/maml.py: -------------------------------------------------------------------------------- 1 | #%% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | import jax 6 | import jax.numpy as np 7 | import numpy as onp 8 | import distrax 9 | import optax 10 | import gym 11 | from functools import partial 12 | from env import Navigation2DEnv, Navigation2DEnv_Disc 13 | import cloudpickle 14 | import pathlib 15 | import haiku as hk 16 | 17 | from jax.config import config 18 | config.update("jax_enable_x64", True) 19 | config.update("jax_debug_nans", True) # break on nans 20 | 21 | #%% 22 | from utils import normal_log_density, sample_gaussian 23 | from utils import disc_policy as policy 24 | from utils import eval, init_policy_fcn, Disc_Vector_Buffer, discount_cumsum, \ 25 | tree_mean, mean_vmap_jit, sum_vmap_jit 26 | 27 | env_name = 'Navigation2D' 28 | env = Navigation2DEnv_Disc(max_n_steps=200) # maml debug env 29 | 30 | n_actions = env.action_space.n 31 | obs_dim = env.observation_space.shape[0] 32 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 33 | 34 | # value function / baseline 35 | # https://github.com/rll/rllab/blob/master/rllab/baselines/linear_feature_baseline.py 36 | def v_features(obs): 37 | o = np.clip(obs, -10, 10) 38 | l = len(o) 39 | al = np.arange(l).reshape(-1, 1) / 100.0 40 | return np.concatenate([o, o ** 2, al, al ** 2, al ** 3, np.ones((l, 1))], axis=1) 41 | 42 | def v_fit(trajectories, feature_fcn=v_features, reg_coeff=1e-5): 43 | featmat = np.concatenate([feature_fcn(traj['obs']) for traj in trajectories]) 44 | r = np.concatenate([traj['r'] for traj in trajectories]) 45 | for _ in range(5): 46 | # solve argmin_x (F x = R) <-- unsolvable (F non-sqr) 47 | # == argmin_x (F^T F x = F^T R) <-- solvable (sqr F^T F) 48 | # where F = Features, x = Weights, R = rewards 49 | _coeffs = np.linalg.lstsq( 50 | featmat.T.dot(featmat) + reg_coeff * np.identity(featmat.shape[1]), 51 | featmat.T.dot(r) 52 | )[0] 53 | if not np.any(np.isnan(_coeffs)): 54 | return _coeffs, 0 # succ 55 | reg_coeff *= 10 56 | return np.zeros_like(_coeffs), 1 # err 57 | 58 | def sample_trajectory(traj, p): 59 | traj_len = int(traj[0].shape[0] * p) 60 | idxs = onp.random.choice(traj_len, size=traj_len, replace=False) 61 | sampled_traj = jax.tree_map(lambda x: x[idxs], traj) 62 | return sampled_traj 63 | 64 | def rollout(env, p_params, rng): 65 | buffer = Disc_Vector_Buffer(obs_dim, max_n_steps) 66 | obs = env.reset() 67 | for _ in range(max_n_steps): 68 | rng, subkey = jax.random.split(rng, 2) 69 | a, log_prob = policy(p_frwd, p_params, obs, subkey, False) 70 | 71 | a = jax.lax.stop_gradient(a) 72 | log_prob = 
jax.lax.stop_gradient(log_prob) 73 | a = a.item() 74 | 75 | obs2, r, done, _ = env.step(a) 76 | buffer.push((obs, a, r, obs2, done, log_prob)) 77 | 78 | obs = obs2 79 | if done: break 80 | 81 | trajectory = buffer.contents() 82 | return trajectory 83 | 84 | #%% 85 | # inner optim 86 | @jax.jit 87 | def _reinforce_loss(p_params, obs, a, adv): 88 | pi = p_frwd(p_params, obs) 89 | log_prob = distrax.Categorical(probs=pi).log_prob(a) 90 | loss = -(log_prob * adv).sum() 91 | return loss 92 | 93 | reinforce_loss = sum_vmap_jit(_reinforce_loss, (None, 0, 0, 0)) 94 | reinforce_loss_grad = jax.jit(jax.value_and_grad(reinforce_loss)) 95 | 96 | @jax.jit 97 | def sgd_step_int(params, grads, alpha): 98 | sgd_update = lambda param, grad: param - alpha * grad 99 | return jax.tree_multimap(sgd_update, params, grads) 100 | 101 | @jax.jit 102 | def sgd_step_tree(params, grads, alphas): 103 | sgd_update = lambda param, grad, alpha: param - alpha * grad 104 | return jax.tree_multimap(sgd_update, params, grads, alphas) 105 | 106 | def sgd_step(params, grads, alpha): 107 | step_fcn = sgd_step_int if type(alpha) in [int, float] else sgd_step_tree 108 | return step_fcn(params, grads, alpha) 109 | 110 | # %% 111 | seed = onp.random.randint(1e5) 112 | epochs = 500 113 | eval_every = 1 114 | max_n_steps = 100 # env._max_episode_steps 115 | ## TRPO 116 | delta = 0.01 117 | n_search_iters = 10 118 | cg_iters = 10 119 | gamma = 0.99 120 | lmbda = 0.95 121 | ## MAML 122 | task_batch_size = 40 123 | train_n_traj = 20 124 | eval_n_traj = 40 125 | alpha = 0.1 126 | damp_lambda = 0.01 127 | 128 | rng = jax.random.PRNGKey(seed) 129 | onp.random.seed(seed) 130 | 131 | ## model init 132 | p_frwd, p_params = init_policy_fcn('discrete', env, rng) 133 | 134 | ## save path 135 | model_path = pathlib.Path(f'./models/maml/{env_name}') 136 | model_path.mkdir(exist_ok=True, parents=True) 137 | 138 | # %% 139 | task = env.sample_tasks(1)[0] 140 | env.reset_task(task) 141 | 142 | # %% 143 | @jax.jit 144 | def compute_advantage(W, traj): 145 | # linear fcn predict 146 | v_obs = v_features(traj['obs']) @ W 147 | # baseline 148 | adv = traj['r'] - v_obs 149 | # normalize 150 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 151 | return adv.squeeze() 152 | 153 | def maml_inner(p_params, env, rng, n_traj, alpha): 154 | subkeys = jax.random.split(rng, n_traj) 155 | trajectories = [] 156 | for i in range(n_traj): 157 | traj = rollout(env, p_params, subkeys[i]) 158 | traj['r'] = discount_cumsum(traj['r'], discount=gamma) 159 | trajectories.append(traj) 160 | 161 | W = v_fit(trajectories)[0] 162 | for i in range(len(trajectories)): 163 | trajectories[i]['adv'] = compute_advantage(W, trajectories[i]) 164 | 165 | gradients = [] 166 | for traj in trajectories: 167 | _, grad = reinforce_loss_grad(p_params, traj['obs'], traj['a'], traj['adv']) 168 | gradients.append(grad) 169 | grads = jax.tree_multimap(lambda *x: np.stack(x).sum(0), *gradients) 170 | inner_params_p = sgd_step(p_params, grads, alpha) 171 | 172 | return inner_params_p, W 173 | 174 | def _trpo_policy_loss(p_params, obs, a, adv, old_log_prob): 175 | pi = p_frwd(p_params, obs) 176 | dist = distrax.Categorical(probs=pi) 177 | ratio = np.exp(dist.log_prob(a) - old_log_prob) 178 | loss = -(ratio * adv).sum() 179 | return loss 180 | 181 | trpo_policy_loss = mean_vmap_jit(_trpo_policy_loss, (None, *([0]*4))) 182 | 183 | def maml_outer(p_params, env, rng): 184 | subkeys = jax.random.split(rng, 3) 185 | newp, W = maml_inner(p_params, env, subkeys[0], train_n_traj, 0.1) 186 | 187 | traj = 
rollout(env, p_params, subkeys[1]) 188 | adv = compute_advantage(W, traj) 189 | loss = trpo_policy_loss(newp, traj['obs'], traj['a'], adv, traj['log_prob']) 190 | return loss, traj 191 | 192 | (loss, traj), grads = jax.value_and_grad(maml_outer, has_aux=True)(p_params, env, rng) 193 | loss 194 | 195 | # grads = grad(maml_outer) 196 | # compute natural gradient 197 | # line search step 198 | 199 | # %% 200 | def _natural_gradient(params, grads, obs): 201 | f = lambda w: p_frwd(w, obs) 202 | rho = D_KL_probs 203 | ngrad, _ = jax.scipy.sparse.linalg.cg( 204 | tree_mvp_dampen(lambda v: gnh_vp(f, rho, params, v), damp_lambda), 205 | grads, maxiter=cg_iters) 206 | 207 | vec = lambda x: x.flatten()[:, None] 208 | mat_mul = lambda x, y: np.sqrt(2 * delta / (vec(x).T @ vec(y)).flatten()) 209 | alpha = jax.tree_multimap(mat_mul, grads, ngrad) 210 | return ngrad, alpha 211 | natural_gradient = mean_vmap_jit(_natural_gradient, (None, 0)) 212 | 213 | ngrad, alpha = natural_gradient(p_params, grads, traj['obs']) 214 | alpha 215 | 216 | # %% 217 | probs = p_frwd(p_params, traj['obs'][0]) 218 | probs 219 | 220 | # %% 221 | jax.hessian(D_KL_probs)(probs, np.array(onp.array([0.55, 0.1, 0.25, 0.1]))) 222 | 223 | # %% 224 | # %% 225 | # %% 226 | #%% 227 | ### TRPO FCNS 228 | from utils import gnh_vp, tree_mvp_dampen 229 | 230 | def D_KL_probs(p1, p2): 231 | d_kl = (p1 * (np.log(p1) - np.log(p2))).sum() 232 | return d_kl 233 | 234 | def D_KL_probs_params(param1, param2, obs): 235 | p1, p2 = p_frwd(param1, obs), p_frwd(param2, obs) 236 | return D_KL_probs(p1, p2) 237 | 238 | def sample(traj, p): 239 | traj_len = int(traj['obs'].shape[0] * p) 240 | idxs = onp.random.choice(traj_len, size=traj_len, replace=False) 241 | sampled_traj = jax.tree_map(lambda x: x[idxs], traj) 242 | return sampled_traj 243 | 244 | import operator 245 | tree_scalar_op = lambda op: lambda tree, arg2: jax.tree_map(lambda x: op(x, arg2), tree) 246 | tree_scalar_divide = tree_scalar_op(operator.truediv) 247 | tree_scalar_mult = tree_scalar_op(operator.mul) 248 | 249 | # backtracking line-search 250 | def line_search(alpha_start, init_loss, p_params, p_ngrad, rollout, n_iters, delta): 251 | obs = rollout[0] 252 | for i in np.arange(n_iters): 253 | alpha = tree_scalar_divide(alpha_start, 2 ** i) 254 | 255 | new_p_params = sgd_step_tree(p_params, p_ngrad, alpha) 256 | new_loss = batch_policy_loss(new_p_params, rollout) 257 | 258 | d_kl = jax.vmap(partial(D_KL_probs_params, new_p_params, p_params))(obs).mean() 259 | 260 | if (new_loss < init_loss) and (d_kl <= delta): 261 | writer.add_scalar('info/line_search_n_iters', i, e) 262 | return new_p_params # new weights 263 | 264 | writer.add_scalar('info/line_search_n_iters', -1, e) 265 | return p_params # no new weights 266 | 267 | # %% 268 | 269 | 270 | # %% 271 | # %% 272 | # %% 273 | -------------------------------------------------------------------------------- /maml/maml_wave.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from functools import partial 3 | import jax 4 | import jax.numpy as np 5 | from jax import random, vmap, jit, grad 6 | from jax.experimental import stax, optimizers 7 | from jax.experimental.stax import Dense, Relu 8 | import matplotlib.pyplot as plt 9 | from tqdm.notebook import tqdm 10 | 11 | #%% 12 | # Use stax to set up network initialization and evaluation functions 13 | net_init, net_apply = stax.serial( 14 | Dense(40), Relu, 15 | Dense(40), Relu, 16 | Dense(1) 17 | ) 18 | in_shape = (-1, 1,) 19 | rng = 
random.PRNGKey(0) 20 | out_shape, params = net_init(rng, in_shape) 21 | 22 | #%% 23 | import numpy as onp 24 | def get_wave(wave_gen, n_samples=100, wave_params=False): 25 | x = wave_gen(n_samples) 26 | amp = onp.random.uniform(low=0.1, high=5.0) 27 | phase = onp.random.uniform(low=0., high=onp.pi) 28 | wave_data = x, onp.sin(x + phase) * amp 29 | 30 | if wave_params: wave_data = (wave_data, (phase, amp)) 31 | return wave_data 32 | 33 | def vis_wave_gen(N): # better for visualization 34 | x = onp.linspace(-5, 5, N).reshape((N, 1)) 35 | return x 36 | 37 | def train_wave_gen(N): # for model training 38 | x = onp.random.uniform(low=-5., high=5., size=(N, 1)) 39 | return x 40 | 41 | def mse(params, batch): 42 | x, y = batch 43 | ypred = net_apply(params, x) 44 | return np.mean((y - ypred)**2) 45 | 46 | #%% 47 | batch = get_wave(vis_wave_gen, 100) 48 | predictions = net_apply(params, batch[0]) 49 | losses = mse(params, batch) 50 | 51 | plt.plot(batch[0], predictions, label='prediction') 52 | plt.plot(*batch, label='target') 53 | plt.legend() 54 | 55 | #%% 56 | opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2) 57 | 58 | @jit 59 | def step(i, opt_state, batch): 60 | params = get_params(opt_state) 61 | g = grad(mse)(params, batch) 62 | return opt_update(i, g, opt_state) 63 | 64 | #%% 65 | out_shape, params = net_init(rng, in_shape) # re-init model 66 | opt_state = opt_init(params) # init optim 67 | 68 | batch = get_wave(vis_wave_gen, 100) 69 | for i in range(200): 70 | opt_state = step(i, opt_state, batch) 71 | params = get_params(opt_state) 72 | 73 | xb, yb = batch 74 | plt.plot(xb, net_apply(params, xb), label='prediction') 75 | plt.plot(xb, yb, label='target') 76 | plt.legend() 77 | 78 | # %% 79 | ### MAML 80 | alpha = 0.1 81 | 82 | # inner loop -- take one gradient step on the data 83 | def inner_update(params, batch): 84 | grads = grad(mse)(params, batch) 85 | sgd_update = lambda param, grad: param - alpha * grad 86 | inner_params = jax.tree_multimap(sgd_update, params, grads) 87 | return inner_params 88 | 89 | # outer loop 90 | def maml_loss(params, train_batch, test_batch): 91 | task_params = inner_update(params, train_batch) 92 | loss = mse(task_params, test_batch) 93 | return loss 94 | 95 | @jit 96 | def maml_step(i, opt_state, train_batch, test_batch): 97 | params = get_params(opt_state) 98 | g = grad(maml_loss)(params, train_batch, test_batch) 99 | return opt_update(i, g, opt_state) 100 | 101 | ## task extractor 102 | def get_task(n_train, n_test, wave_params=False): 103 | if not wave_params: 104 | batch = get_wave(train_wave_gen, n_train + n_test) 105 | else: 106 | batch, wparams = get_wave(train_wave_gen, n_train + n_test, wave_params=True) 107 | 108 | # extract train/test elements from batch=(xb, yb) with treemap :) 109 | train_batch = jax.tree_map(lambda l: l[:n_train], batch, is_leaf=lambda node: hasattr(node, 'shape')) 110 | test_batch = jax.tree_map(lambda l: l[n_train:], batch, is_leaf=lambda node: hasattr(node, 'shape')) 111 | 112 | task = train_batch, test_batch 113 | if wave_params: task = (*task, wparams) 114 | 115 | return task 116 | 117 | # %% 118 | opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3) 119 | out_shape, params = net_init(rng, in_shape) # re-init model 120 | opt_state = opt_init(params) # init optim 121 | 122 | for i in tqdm(range(20000)): 123 | train_batch, test_batch = get_task(20, 1) 124 | opt_state = maml_step(i, opt_state, train_batch, test_batch) 125 | params = get_params(opt_state) 126 | 127 | # %% 128 | train_batch, 
test_batch, wparams = get_task(20, 1, wave_params=True) 129 | 130 | # re-create wave smoother for visualization 131 | phase, amp = wparams 132 | x = vis_wave_gen(100) 133 | y = np.sin(x + phase) * amp 134 | plt.plot(x, y, label='targets') 135 | 136 | step_params = params.copy() 137 | for i in range(5): # visualize wave at each grad step 138 | ypred = net_apply(step_params, x) 139 | plt.plot(x, ypred, label=f'step{i}') 140 | step_params = inner_update(step_params, train_batch) 141 | 142 | plt.legend() 143 | 144 | # %% 145 | task_batch_size = 5 146 | tasks = [get_task(20, 1) for _ in range(task_batch_size)] 147 | train_batch, test_batch = jax.tree_multimap(lambda *b: np.stack(b), *tasks, is_leaf=lambda node: hasattr(node, 'shape')) 148 | 149 | xb, yb = train_batch 150 | for i in range(len(xb)): 151 | plt.scatter(xb[i], yb[i]) 152 | 153 | # %% 154 | def batch_maml_loss(params, train_batch, test_batch): 155 | losses = vmap(partial(maml_loss, params))(train_batch, test_batch) 156 | loss = losses.mean() 157 | return loss 158 | 159 | @jit 160 | def batch_maml_step(i, opt_state, train_batch, test_batch): 161 | params = get_params(opt_state) 162 | g = grad(batch_maml_loss)(params, train_batch, test_batch) 163 | return opt_update(i, g, opt_state) 164 | 165 | # %% 166 | task_batch_size = 4 167 | 168 | opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3) 169 | out_shape, params = net_init(rng, in_shape) # re-init model 170 | opt_state = opt_init(params) # init optim 171 | 172 | for i in tqdm(range(20000)): 173 | # get batch of tasks 174 | tasks = [get_task(20, 1) for _ in range(task_batch_size)] 175 | train_batch, test_batch = jax.tree_multimap(lambda *b: np.stack(b), *tasks, is_leaf=lambda node: hasattr(node, 'shape')) 176 | # take gradient step over the mean 177 | opt_state = batch_maml_step(i, opt_state, train_batch, test_batch) 178 | 179 | params = get_params(opt_state) 180 | 181 | # %% 182 | train_batch, test_batch, wparams = get_task(20, 1, wave_params=True) 183 | 184 | # re-create wave smoother for visualization 185 | phase, amp = wparams 186 | x = vis_wave_gen(100) 187 | y = np.sin(x + phase) * amp 188 | plt.plot(x, y, label='targets') 189 | plt.scatter(*train_batch, label='train') 190 | 191 | step_params = params.copy() 192 | for i in range(5): # visualize wave at each grad step 193 | ypred = net_apply(step_params, x) 194 | plt.plot(x, ypred, label=f'step{i}') 195 | step_params = inner_update(step_params, train_batch) 196 | 197 | plt.legend() 198 | 199 | # %% 200 | -------------------------------------------------------------------------------- /maml/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax.numpy as np 3 | import jax 4 | from functools import partial 5 | import haiku as hk 6 | import optax 7 | 8 | def init_policy_fcn(type, env, rng, nhidden=64, jit=True): 9 | 10 | if type == 'continuous': 11 | n_actions = env.action_space.shape[0] 12 | a_high = env.action_space.high[0] 13 | a_low = env.action_space.low[0] 14 | assert -a_high == a_low 15 | 16 | def _cont_policy_fcn(s): 17 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.ones, dtype=np.float64) 18 | mu = hk.Sequential([ 19 | hk.Linear(nhidden), jax.nn.relu, 20 | hk.Linear(nhidden), jax.nn.relu, 21 | hk.Linear(n_actions), np.tanh 22 | ])(s) * a_high 23 | sig = np.exp(log_std) 24 | return mu, sig 25 | _policy_fcn = _cont_policy_fcn 26 | 27 | elif type == 'discrete': 28 | n_actions = env.action_space.n 29 | 30 | init_final = 
hk.initializers.RandomUniform(-3e-3, 3e-3) 31 | def _disc_policy_fcn(s): 32 | pi = hk.Sequential([ 33 | hk.Linear(nhidden), jax.nn.relu, 34 | hk.Linear(nhidden), jax.nn.relu, 35 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 36 | ])(s) 37 | return pi 38 | _policy_fcn = _disc_policy_fcn 39 | 40 | policy_fcn = hk.transform(_policy_fcn) 41 | policy_fcn = hk.without_apply_rng(policy_fcn) 42 | p_frwd = policy_fcn.apply 43 | if jit: 44 | p_frwd = jax.jit(p_frwd) 45 | 46 | obs = np.zeros(env.observation_space.shape) # dummy input 47 | p_params = policy_fcn.init(rng, obs) 48 | 49 | return p_frwd, p_params 50 | 51 | def optim_update_fcn(optim, params): 52 | opt_state = optim.init(params) 53 | @jax.jit 54 | def update_step(params, opt_state, grads): 55 | grads, opt_state = optim.update(grads, opt_state) 56 | params = optax.apply_updates(params, grads) 57 | return params, opt_state 58 | return update_step, opt_state 59 | 60 | def gaussian_log_prob(x, mean, std): 61 | log_std = np.log(std) 62 | var = np.power(std, 2) 63 | log_density = -np.power(x - mean, 2) / ( 64 | 2 * var) - 0.5 * np.log(2 * np.pi) - log_std 65 | return np.sum(log_density) 66 | 67 | def gaussian_sample(mean, std, rng): 68 | return jax.random.normal(rng) * std + mean 69 | 70 | @partial(jax.jit, static_argnums=(0,)) 71 | def cont_policy(p_frwd, params, obs, rng, clip_range, greedy): 72 | mu, std = p_frwd(params, obs) 73 | a = jax.lax.cond(greedy, lambda _: mu, lambda _: gaussian_sample(mu, std, rng), None) 74 | a = np.clip(a, *clip_range) # [low, high] 75 | log_prob = gaussian_log_prob(a, mu, std) 76 | return a, log_prob 77 | 78 | @partial(jax.jit, static_argnums=(0,)) 79 | def disc_policy(p_frwd, params, obs, rng, greedy): 80 | pi = p_frwd(params, obs) 81 | dist = distrax.Categorical(probs=pi) 82 | a = jax.lax.cond(greedy, lambda _: pi.argmax(), lambda _: dist.sample(seed=rng), None) 83 | a = dist.sample(seed=rng) 84 | log_prob = dist.log_prob(a) 85 | return a, log_prob 86 | 87 | def eval(p_frwd, policy, params, env, rng, clip_range, greedy): 88 | rewards = 0 89 | obs = env.reset() 90 | while True: 91 | rng, subkey = jax.random.split(rng, 2) 92 | a = policy(p_frwd, params, obs, subkey, clip_range, greedy)[0] 93 | a = onp.array(a) 94 | obs2, r, done, _ = env.step(a) 95 | obs = obs2 96 | rewards += r 97 | if done: break 98 | return rewards 99 | 100 | class Cont_Vector_Buffer: 101 | def __init__(self, n_actions, obs_dim, buffer_capacity): 102 | self.obs_dim, self.n_actions = obs_dim, n_actions 103 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 104 | self.i = 0 105 | # obs, a, r, obs2, done 106 | self.splits = [obs_dim, obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions, obs_dim*2+1+n_actions+1] 107 | self.split_names = ['obs', 'a', 'r', 'obs2', 'done', 'log_prob'] 108 | self.clear() 109 | 110 | def push(self, sample): 111 | assert self.i < self.buffer_capacity # dont let it get full 112 | (obs, a, r, obs2, done, log_prob) = sample 113 | self.buffer[self.i] = onp.array([*obs, *onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 114 | self.i += 1 115 | 116 | def contents(self): 117 | contents = onp.split(self.buffer[:self.i], self.splits, axis=-1) 118 | return contents 119 | # d = {} 120 | # for n, c in zip(self.split_names, contents): d[n] = c 121 | # return d 122 | 123 | def clear(self): 124 | self.i = 0 125 | self.buffer = onp.zeros((self.buffer_capacity, 2 * self.obs_dim + self.n_actions + 2 + 1)) 126 | 127 | class Disc_Vector_Buffer(Cont_Vector_Buffer): 128 | def 
__init__(self, obs_dim, buffer_capacity): 129 | super().__init__(1, obs_dim, buffer_capacity) 130 | 131 | def push(self, sample): 132 | assert self.i < self.buffer_capacity # dont let it get full 133 | (obs, a, r, obs2, done, log_prob) = sample 134 | self.buffer[self.i] = onp.array([*obs, onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 135 | self.i += 1 136 | 137 | ## second order stuff (trpo) 138 | def hvp(J, w, v): 139 | return jax.jvp(jax.grad(J), (w,), (v,))[1] 140 | 141 | def gnh_vp(f, rho, w, v): 142 | z, R_z = jax.jvp(f, (w,), (v,)) 143 | R_gz = hvp(lambda z1: rho(z, z1), z, R_z) 144 | _, f_vjp = jax.vjp(f, w) 145 | return f_vjp(R_gz)[0] 146 | 147 | def tree_mvp_dampen(mvp, lmbda=0.1): 148 | dampen_fcn = lambda mvp_, v_: mvp_ + lmbda * v_ 149 | damp_mvp = lambda v: jax.tree_multimap(dampen_fcn, mvp(v), v) 150 | return damp_mvp 151 | 152 | def discount_cumsum(l, discount): 153 | l = onp.array(l) 154 | for i in range(len(l) - 1)[::-1]: 155 | l[i] = l[i] + discount * l[i+1] 156 | return l 157 | 158 | def tree_shape(tree): 159 | for l in jax.tree_leaves(tree): print(l.shape) 160 | 161 | tree_mean = jax.jit(lambda tree: jax.tree_map(lambda x: x.mean(0), tree)) 162 | tree_sum = jax.jit(lambda tree: jax.tree_map(lambda x: x.sum(0), tree)) 163 | 164 | def jit_vmap_tree_op(jit_tree_op, f, *vmap_args): 165 | return lambda *args: jit_tree_op(jax.vmap(f, *vmap_args)(*args)) 166 | 167 | mean_vmap_jit = partial(jit_vmap_tree_op, tree_mean) 168 | sum_vmap_jit = partial(jit_vmap_tree_op, tree_sum) -------------------------------------------------------------------------------- /ppo/env.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import gym 4 | 5 | from gym import spaces 6 | from gym.utils import seeding 7 | 8 | class Navigation2DEnv(gym.Env): 9 | """2D navigation problems, as described in [1]. The code is adapted from 10 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/maml_examples/point_env_randgoal.py 11 | At each time step, the 2D agent takes an action (its velocity, clipped in 12 | [-0.1, 0.1]), and receives a penalty equal to its L2 distance to the goal 13 | position (ie. the reward is `-distance`). The 2D navigation tasks are 14 | generated by sampling goal positions from the uniform distribution 15 | on [-0.5, 0.5]^2. 
16 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 17 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 18 | (https://arxiv.org/abs/1703.03400) 19 | """ 20 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 21 | super(Navigation2DEnv, self).__init__() 22 | self.low = low 23 | self.high = high 24 | self.max_n_steps = max_n_steps 25 | 26 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 27 | shape=(2,), dtype=np.float32) 28 | self.action_space = spaces.Box(low=-0.1, high=0.1, 29 | shape=(2,), dtype=np.float32) 30 | 31 | self._step_count = 0 32 | self._task = task 33 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 34 | self._state = np.zeros(2, dtype=np.float32) 35 | self.seed() 36 | 37 | def seed(self, seed=None): 38 | self.np_random, seed = seeding.np_random(seed) 39 | return [seed] 40 | 41 | def sample_tasks(self, num_tasks): 42 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 43 | tasks = [{'goal': goal} for goal in goals] 44 | return tasks 45 | 46 | def reset_task(self, task): 47 | self._task = task 48 | self._goal = task['goal'] 49 | 50 | def reset(self): 51 | self._step_count = 0 52 | self._state = np.zeros(2, dtype=np.float32) 53 | return self._state 54 | 55 | def step(self, action): 56 | action = np.clip(action, -0.1, 0.1) 57 | assert self.action_space.contains(action) 58 | self._state = self._state + action 59 | 60 | diff = self._state - self._goal 61 | reward = -np.sqrt((diff**2).sum()) 62 | done = (np.abs(diff) < 0.01).sum() == 2 63 | 64 | done = done or self._step_count >= self.max_n_steps 65 | self._step_count += 1 66 | 67 | return self._state, reward, done, {'task': self._task} 68 | 69 | class Navigation2DEnv_Disc(gym.Env): 70 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 71 | super().__init__() 72 | self.low = low 73 | self.high = high 74 | self.max_n_steps = max_n_steps 75 | self._step_count = 0 76 | 77 | self.observation_space = spaces.Box(low=self.low, high=self.high, 78 | shape=(2,), dtype=np.float32) 79 | self.action_space = spaces.Discrete(4) # left, right, up, down 80 | 81 | self._task = task 82 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 83 | self._state = np.zeros(2, dtype=np.float32) 84 | self.seed() 85 | 86 | def seed(self, seed=None): 87 | self.np_random, seed = seeding.np_random(seed) 88 | return [seed] 89 | 90 | def sample_tasks(self, num_tasks): 91 | while True: 92 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 93 | if not (goals.sum(0) == 0).any(): break 94 | 95 | goals = np.round_(goals, 1) # discrete them to 0.1 steps 96 | tasks = [{'goal': goal} for goal in goals] 97 | return tasks 98 | 99 | def reset_task(self, task): 100 | self._task = task 101 | self._goal = task['goal'] 102 | 103 | def reset(self): 104 | self._step_count = 0 105 | self._state = np.zeros(2, dtype=np.float32) 106 | return self._state 107 | 108 | def step(self, action): 109 | assert self.action_space.contains(action) 110 | # up down left right 111 | step = np.array({ 112 | 0: [0, 0.1], 113 | 1: [0, -0.1], 114 | 2: [-0.1, 0], 115 | 3: [0.1, 0] 116 | }[action]) 117 | 118 | self._state = self._state + step 119 | self._state = np.clip(self._state, self.low, self.high) 120 | 121 | diff = self._state - self._goal 122 | reward = -np.sqrt((diff**2).sum()) 123 | done = (np.abs(diff) < 0.01).sum() == 2 124 | 125 | done = done or self._step_count >= self.max_n_steps 126 | self._step_count += 1 127 | 128 | return self._state, reward, 
done, {'task': self._task} 129 | 130 | # %% 131 | 132 | # %% 133 | # %% 134 | # %% 135 | # %% 136 | -------------------------------------------------------------------------------- /ppo/ppo.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | import cloudpickle 10 | import pybullet_envs 11 | 12 | from jax.config import config 13 | config.update("jax_enable_x64", True) 14 | config.update("jax_debug_nans", True) # break on nans 15 | 16 | env_name = 'Pendulum-v0' 17 | make_env = lambda: gym.make(env_name) 18 | 19 | # from env import Navigation2DEnv 20 | # env_name = 'Navigation2D' 21 | # def make_env(): 22 | # env = Navigation2DEnv(max_n_steps=200) 23 | # env.seed(0) 24 | # task = env.sample_tasks(1)[0] 25 | # print(f'[LOGGER]: task = {task}') 26 | # env.reset_task(task) 27 | 28 | # # log max reward 29 | # goal = env._task['goal'] 30 | # reward = 0 31 | # step_count = 0 32 | # obs = env.reset() 33 | # while True: 34 | # a = goal - obs 35 | # obs2, r, done, _ = env.step(a) 36 | # reward += r 37 | # step_count += 1 38 | # if done: break 39 | # obs = obs2 40 | # print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 41 | # return env 42 | 43 | env = make_env() 44 | n_actions = env.action_space.shape[0] 45 | obs_dim = env.observation_space.shape[0] 46 | 47 | a_high = env.action_space.high[0] 48 | a_low = env.action_space.low[0] 49 | 50 | print(f'[LOGGER] a_high: {a_high} a_low: {a_low} n_actions: {n_actions} obs_dim: {obs_dim}') 51 | assert -a_high == a_low 52 | 53 | #%% 54 | import haiku as hk 55 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 56 | 57 | def _policy_fcn(s): 58 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.ones) 59 | mu = hk.Sequential([ 60 | hk.Linear(64), jax.nn.relu, 61 | hk.Linear(64), jax.nn.relu, 62 | hk.Linear(n_actions, w_init=init_final), np.tanh 63 | ])(s) * a_high 64 | sig = np.exp(log_std) 65 | return mu, sig 66 | 67 | def _critic_fcn(s): 68 | v = hk.Sequential([ 69 | hk.Linear(64), jax.nn.relu, 70 | hk.Linear(64), jax.nn.relu, 71 | hk.Linear(1), 72 | ])(s) 73 | return v 74 | 75 | policy_fcn = hk.transform(_policy_fcn) 76 | policy_fcn = hk.without_apply_rng(policy_fcn) 77 | p_frwd = jax.jit(policy_fcn.apply) 78 | 79 | critic_fcn = hk.transform(_critic_fcn) 80 | critic_fcn = hk.without_apply_rng(critic_fcn) 81 | v_frwd = jax.jit(critic_fcn.apply) 82 | 83 | # %% 84 | @jax.jit 85 | def policy(params, obs, rng): 86 | mu, sig = p_frwd(params, obs) 87 | dist = distrax.MultivariateNormalDiag(mu, sig) 88 | a = dist.sample(seed=rng) 89 | a = np.clip(a, a_low, a_high) 90 | log_prob = dist.log_prob(a) 91 | return a, log_prob 92 | 93 | @jax.jit 94 | def eval_policy(params, obs, _): 95 | a, _ = p_frwd(params, obs) 96 | a = np.clip(a, a_low, a_high) 97 | return a, None 98 | 99 | def eval(params, env, rng): 100 | rewards = 0 101 | obs = env.reset() 102 | while True: 103 | rng, subrng = jax.random.split(rng) 104 | # a = eval_policy(params, obs, subrng)[0] 105 | a = policy(params, obs, subrng)[0] 106 | a = onp.array(a) 107 | obs2, r, done, _ = env.step(a) 108 | obs = obs2 109 | rewards += r 110 | if done: break 111 | return rewards 112 | 113 | class Vector_ReplayBuffer: 114 | def __init__(self, buffer_capacity): 115 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 116 | self.i = 0 117 | # obs, obs2, a, r, done 118 | self.splits = [obs_dim, 
obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions, obs_dim*2+1+n_actions+1] 119 | self.clear() 120 | 121 | def push(self, sample): 122 | assert self.i < self.buffer_capacity # dont let it get full 123 | (obs, a, r, obs2, done, log_prob) = sample 124 | self.buffer[self.i] = onp.array([*obs, *onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 125 | self.i += 1 126 | 127 | def contents(self): 128 | return onp.split(self.buffer[:self.i], self.splits, axis=-1) 129 | 130 | def clear(self): 131 | self.i = 0 132 | self.buffer = onp.zeros((self.buffer_capacity, 2 * obs_dim + n_actions + 2 + 1)) 133 | 134 | def shuffle_rollout(rollout): 135 | rollout_len = rollout[0].shape[0] 136 | idxs = onp.arange(rollout_len) 137 | onp.random.shuffle(idxs) 138 | rollout = jax.tree_map(lambda x: x[idxs], rollout) 139 | return rollout 140 | 141 | def rollout2batches(rollout, batch_size): 142 | rollout_len = rollout[0].shape[0] 143 | n_chunks = rollout_len // batch_size 144 | # shuffle / de-correlate 145 | rollout = shuffle_rollout(rollout) 146 | if n_chunks == 0: return rollout 147 | # batch 148 | batched_rollout = jax.tree_map(lambda x: np.array_split(x, n_chunks), rollout) 149 | for i in range(n_chunks): 150 | batch = [d[i] for d in batched_rollout] 151 | yield batch 152 | 153 | def discount_cumsum(l, discount): 154 | l = onp.array(l) 155 | for i in range(len(l) - 1)[::-1]: 156 | l[i] = l[i] + discount * l[i+1] 157 | return l 158 | 159 | def compute_advantage_targets(v_params, rollout): 160 | (obs, _, r, obs2, done, _) = rollout 161 | 162 | batch_v_fcn = jax.vmap(partial(v_frwd, v_params)) 163 | v_obs = batch_v_fcn(obs) 164 | v_obs2 = batch_v_fcn(obs2) 165 | 166 | # gae 167 | deltas = (r + (1 - done) * gamma * v_obs2) - v_obs 168 | deltas = jax.lax.stop_gradient(deltas) 169 | adv = discount_cumsum(deltas, discount=gamma * lmbda) 170 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 171 | 172 | # reward2go 173 | v_target = discount_cumsum(r, discount=gamma) 174 | 175 | return adv, v_target 176 | 177 | @jax.jit 178 | def ppo_loss(p_params, v_params, sample): 179 | (obs, a, old_log_prob, v_target, advantages) = sample 180 | 181 | ## critic loss 182 | v_obs = v_frwd(v_params, obs) 183 | critic_loss = (0.5 * ((v_obs - v_target) ** 2)).sum() 184 | 185 | ## policy losses 186 | mu, sig = p_frwd(p_params, obs) 187 | dist = distrax.MultivariateNormalDiag(mu, sig) 188 | # entropy 189 | entropy_loss = -dist.entropy() 190 | # policy gradient 191 | log_prob = dist.log_prob(a) 192 | 193 | approx_kl = (old_log_prob - log_prob).sum() 194 | ratio = np.exp(log_prob - old_log_prob) 195 | p_loss1 = ratio * advantages 196 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * advantages 197 | policy_loss = -np.fmin(p_loss1, p_loss2).sum() 198 | 199 | clipped_mask = ((ratio > 1+eps) | (ratio < 1-eps)).astype(np.float32) 200 | clip_frac = clipped_mask.mean() 201 | 202 | loss = policy_loss + 0.001 * entropy_loss + critic_loss 203 | 204 | info = dict(ploss=policy_loss, entr=-entropy_loss, vloss=critic_loss, 205 | approx_kl=approx_kl, cf=clip_frac) 206 | 207 | return loss, info 208 | 209 | def ppo_loss_batch(p_params, v_params, batch): 210 | out = jax.vmap(partial(ppo_loss, p_params, v_params))(batch) 211 | loss, info = jax.tree_map(lambda x: x.mean(), out) 212 | return loss, info 213 | 214 | ppo_loss_grad = jax.jit(jax.value_and_grad(ppo_loss_batch, argnums=[0,1], has_aux=True)) 215 | 216 | def optim_update_fcn(optim): 217 | @jax.jit 218 | def update_step(params, grads, opt_state): 219 | grads, opt_state = 
optim.update(grads, opt_state) 220 | params = optax.apply_updates(params, grads) 221 | return params, opt_state 222 | return update_step 223 | 224 | @jax.jit 225 | def ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch): 226 | (loss, info), (p_grads, v_grads) = ppo_loss_grad(p_params, v_params, batch) 227 | p_params, p_opt_state = p_update_step(p_params, p_grads, p_opt_state) 228 | v_params, v_opt_state = v_update_step(v_params, v_grads, v_opt_state) 229 | return loss, info, p_params, v_params, p_opt_state, v_opt_state 230 | 231 | class Worker: 232 | def __init__(self, n_steps): 233 | self.n_steps = n_steps 234 | self.buffer = Vector_ReplayBuffer(n_step_rollout) 235 | self.env = make_env() 236 | self.obs = self.env.reset() 237 | 238 | def rollout(self, p_params, v_params, rng): 239 | self.buffer.clear() 240 | 241 | for _ in range(self.n_steps): # rollout 242 | rng, subrng = jax.random.split(rng) 243 | a, log_prob = policy(p_params, self.obs, subrng) 244 | a = onp.array(a) 245 | 246 | obs2, r, done, _ = self.env.step(a) 247 | 248 | self.buffer.push((self.obs, a, r, obs2, done, log_prob)) 249 | self.obs = obs2 250 | if done: 251 | self.obs = self.env.reset() 252 | 253 | # update rollout contents 254 | rollout = self.buffer.contents() 255 | advantages, v_target = compute_advantage_targets(v_params, rollout) 256 | (obs, a, r, _, _, log_prob) = rollout 257 | rollout = (obs, a, log_prob, v_target, advantages) 258 | 259 | return rollout 260 | 261 | #%% 262 | seed = 90897 #onp.random.randint(1e5) # 90897 works very well 263 | gamma = 0.99 264 | lmbda = 0.95 265 | eps = 0.2 266 | batch_size = 128 267 | policy_lr = 1e-3 268 | v_lr = 1e-3 269 | max_n_steps = 1e6 270 | n_step_rollout = 2048 # env._max_episode_steps 271 | save_models = False 272 | 273 | rng = jax.random.PRNGKey(seed) 274 | onp.random.seed(seed) 275 | 276 | obs = env.reset() # dummy input 277 | p_params = policy_fcn.init(rng, obs) 278 | v_params = critic_fcn.init(rng, obs) 279 | 280 | worker = Worker(n_step_rollout) 281 | 282 | ## optimizers 283 | optimizer = lambda lr: optax.chain( 284 | optax.clip_by_global_norm(0.5), 285 | optax.scale_by_adam(), 286 | optax.scale(-lr), 287 | ) 288 | p_optim = optimizer(policy_lr) 289 | v_optim = optimizer(v_lr) 290 | 291 | p_opt_state = p_optim.init(p_params) 292 | v_opt_state = v_optim.init(v_params) 293 | 294 | p_update_step = optim_update_fcn(p_optim) 295 | v_update_step = optim_update_fcn(v_optim) 296 | 297 | import pathlib 298 | model_path = pathlib.Path(f'./models/ppo/{env_name}') 299 | model_path.mkdir(exist_ok=True, parents=True) 300 | 301 | #%% 302 | from torch.utils.tensorboard import SummaryWriter 303 | writer = SummaryWriter(comment=f'ppo_{env_name}_seed={seed}') 304 | 305 | epi_i = 0 306 | step_i = 0 307 | from tqdm import tqdm 308 | pbar = tqdm(total=max_n_steps) 309 | while step_i < max_n_steps: 310 | rng, subkey = jax.random.split(rng, 2) 311 | rollout = worker.rollout(p_params, v_params, subkey) 312 | 313 | for batch in rollout2batches(rollout, batch_size): 314 | loss, info, p_params, v_params, p_opt_state, v_opt_state = \ 315 | ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch) 316 | step_i += 1 317 | pbar.update(1) 318 | writer.add_scalar('loss/loss', loss.item(), step_i) 319 | for k in info.keys(): 320 | writer.add_scalar(f'info/{k}', info[k].item(), step_i) 321 | 322 | rng, subrng = jax.random.split(rng) 323 | reward = eval(p_params, env, subrng) 324 | writer.add_scalar('eval/total_reward', reward.item(), step_i) 325 | 326 | if save_models and (epi_i == 0 or 
reward > max_reward): 327 | max_reward = reward 328 | save_path = str(model_path/f'params_{max_reward:.2f}') 329 | print(f'Saving model {save_path}...') 330 | with open(save_path, 'wb') as f: 331 | cloudpickle.dump((p_params, v_params), f) 332 | 333 | epi_i += 1 334 | 335 | -------------------------------------------------------------------------------- /ppo/ppo_brax.py: -------------------------------------------------------------------------------- 1 | # very slow (even on TPU :( ) 2 | import os 3 | from IPython.display import clear_output 4 | 5 | !pip install distrax optax dm-haiku 6 | clear_output() 7 | 8 | try: 9 | import brax 10 | except ImportError: 11 | !pip install git+https://github.com/google/brax.git@main 12 | clear_output() 13 | import brax 14 | 15 | if 'COLAB_TPU_ADDR' in os.environ: 16 | from jax.tools import colab_tpu 17 | colab_tpu.setup_tpu() 18 | 19 | #%% 20 | import jax 21 | import jax.numpy as np 22 | import numpy as onp 23 | import distrax 24 | import optax 25 | import gym 26 | from functools import partial 27 | import cloudpickle 28 | 29 | from jax.config import config 30 | # config.update("jax_enable_x64", True) 31 | config.update("jax_debug_nans", True) # break on nans 32 | 33 | import brax 34 | from brax.envs import _envs, create_gym_env 35 | 36 | for env_name, env_class in _envs.items(): 37 | env_id = f"brax_{env_name}-v0" 38 | entry_point = partial(create_gym_env, env_name=env_name) 39 | if env_id not in gym.envs.registry.env_specs: 40 | print(f"Registring brax's '{env_name}' env under id '{env_id}'.") 41 | gym.register(env_id, entry_point=entry_point) 42 | 43 | make_env = lambda n_envs: gym.make("brax_halfcheetah-v0", batch_size=n_envs, \ 44 | episode_length=1000) 45 | env = make_env(1) # tmp 46 | 47 | obs_dim = env.observation_space.shape[-1] 48 | n_actions = env.action_space.shape[-1] 49 | 50 | a_low = -1 51 | a_high = 1 52 | assert (env.action_space.low == -1).all() and (env.action_space.high == 1).all() 53 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 54 | 55 | #%% 56 | import haiku as hk 57 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 58 | 59 | def _policy_fcn(s): 60 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.ones) 61 | mu = hk.Sequential([ 62 | hk.Linear(64), jax.nn.relu, 63 | hk.Linear(64), jax.nn.relu, 64 | hk.Linear(n_actions, w_init=init_final), np.tanh 65 | ])(s) 66 | sig = np.exp(log_std) 67 | return mu, sig 68 | 69 | def _critic_fcn(s): 70 | v = hk.Sequential([ 71 | hk.Linear(64), jax.nn.relu, 72 | hk.Linear(64), jax.nn.relu, 73 | hk.Linear(1), 74 | ])(s) 75 | return v 76 | 77 | policy_fcn = hk.transform(_policy_fcn) 78 | policy_fcn = hk.without_apply_rng(policy_fcn) 79 | p_frwd = jax.jit(policy_fcn.apply) 80 | 81 | critic_fcn = hk.transform(_critic_fcn) 82 | critic_fcn = hk.without_apply_rng(critic_fcn) 83 | v_frwd = jax.jit(critic_fcn.apply) 84 | 85 | @jax.jit 86 | def policy(params, obs, rng): 87 | mu, sig = p_frwd(params, obs) 88 | dist = distrax.MultivariateNormalDiag(mu, sig) 89 | a = dist.sample(seed=rng) 90 | a = np.clip(a, a_low, a_high) 91 | log_prob = dist.log_prob(a) 92 | return a, log_prob 93 | 94 | def pmap_policy(params, obs, rng): 95 | keys = jax.random.split(rng, n_envs) 96 | a, log_prob = jax.pmap(policy, in_axes=(None, 0, 0))(params, obs, keys) 97 | return a, log_prob 98 | 99 | def eval(params, env, rng): 100 | rewards = 0 101 | obs = env.reset() 102 | while True: 103 | rng, subrng = jax.random.split(rng) 104 | a = pmap_policy(params, obs, subrng)[0] 105 | obs2, r, done, _ = 
env.step(a) 106 | obs = obs2 107 | rewards += r 108 | if done.any(): break 109 | return rewards.sum() 110 | 111 | class Batch_Vector_ReplayBuffer: 112 | def __init__(self, buffer_capacity, n_envs): 113 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 114 | self.i = 0 115 | # obs, obs2, a, r, done 116 | self.splits = [obs_dim, obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions, obs_dim*2+1+n_actions+1] 117 | self.n_envs = n_envs 118 | self.clear() 119 | 120 | def push(self, sample): 121 | assert self.i < self.buffer_capacity # dont let it get full 122 | (obs, a, r, obs2, done, log_prob) = sample 123 | self.buffer[:, self.i] = onp.concatenate([obs, a, r[:, None], obs2, done.astype(float)[:, None], log_prob[:, None]], -1) 124 | self.i += 1 125 | 126 | def contents(self): 127 | return onp.split(self.buffer[:, :self.i], self.splits, axis=-1) 128 | 129 | def clear(self): 130 | self.i = 0 131 | self.buffer = onp.zeros((self.n_envs, self.buffer_capacity, 2 * obs_dim + n_actions + 2 + 1)) 132 | 133 | def shuffle_rollout(rollout): 134 | rollout_len = rollout[0].shape[0] 135 | idxs = onp.arange(rollout_len) 136 | onp.random.shuffle(idxs) 137 | rollout = jax.tree_map(lambda x: x[idxs], rollout) 138 | return rollout 139 | 140 | def rollout2batches(rollout, batch_size): 141 | rollout_len = rollout[0].shape[0] 142 | n_chunks = rollout_len // batch_size 143 | # shuffle / de-correlate 144 | rollout = shuffle_rollout(rollout) 145 | if n_chunks == 0: return rollout 146 | # batch 147 | batched_rollout = jax.tree_map(lambda x: np.array_split(x, n_chunks), rollout) 148 | for i in range(n_chunks): 149 | batch = [d[i] for d in batched_rollout] 150 | yield batch 151 | 152 | from jax.ops import index, index_add 153 | def discount_cumsum(l, discount): 154 | for i in range(len(l) - 1)[::-1]: 155 | l = index_add(l, index[i], discount * l[i+1]) 156 | return l 157 | 158 | def compute_advantage_targets(v_params, rollout): 159 | (obs, _, r, obs2, done, _) = rollout 160 | 161 | batch_v_fcn = jax.vmap(partial(v_frwd, v_params)) 162 | v_obs = batch_v_fcn(obs) 163 | v_obs2 = batch_v_fcn(obs2) 164 | 165 | # gae 166 | deltas = (r + (1 - done) * gamma * v_obs2) - v_obs 167 | deltas = jax.lax.stop_gradient(deltas) 168 | adv = discount_cumsum(deltas, discount=gamma * lmbda) 169 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 170 | 171 | # reward2go 172 | v_target = discount_cumsum(r, discount=gamma) 173 | 174 | return adv, v_target 175 | 176 | @jax.jit 177 | def ppo_loss(p_params, v_params, sample): 178 | (obs, a, old_log_prob, v_target, advantages) = sample 179 | 180 | ## critic loss 181 | v_obs = v_frwd(v_params, obs) 182 | critic_loss = (0.5 * ((v_obs - v_target) ** 2)).sum() 183 | 184 | ## policy losses 185 | mu, sig = p_frwd(p_params, obs) 186 | dist = distrax.MultivariateNormalDiag(mu, sig) 187 | # entropy 188 | entropy_loss = -dist.entropy() 189 | # policy gradient 190 | log_prob = dist.log_prob(a) 191 | 192 | approx_kl = (old_log_prob - log_prob).sum() 193 | ratio = np.exp(log_prob - old_log_prob) 194 | p_loss1 = ratio * advantages 195 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * advantages 196 | policy_loss = -np.fmin(p_loss1, p_loss2).sum() 197 | 198 | clipped_mask = ((ratio > 1+eps) | (ratio < 1-eps)).astype(np.float32) 199 | clip_frac = clipped_mask.mean() 200 | 201 | loss = policy_loss + 0.001 * entropy_loss + critic_loss 202 | 203 | info = dict(ploss=policy_loss, entr=-entropy_loss, vloss=critic_loss, 204 | approx_kl=approx_kl, cf=clip_frac) 205 | 206 | return loss, info 207 | 208 | def 
ppo_loss_batch(p_params, v_params, batch): 209 | out = jax.pmap(partial(ppo_loss, p_params, v_params))(batch) 210 | loss, info = jax.tree_map(lambda x: x.mean(), out) 211 | return loss, info 212 | 213 | ppo_loss_grad = jax.jit(jax.value_and_grad(ppo_loss_batch, argnums=[0,1], has_aux=True)) 214 | 215 | def optim_update_fcn(optim): 216 | @jax.jit 217 | def update_step(params, grads, opt_state): 218 | grads, opt_state = optim.update(grads, opt_state) 219 | params = optax.apply_updates(params, grads) 220 | return params, opt_state 221 | return update_step 222 | 223 | @jax.jit 224 | def ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch): 225 | (loss, info), (p_grads, v_grads) = ppo_loss_grad(p_params, v_params, batch) 226 | p_params, p_opt_state = p_update_step(p_params, p_grads, p_opt_state) 227 | v_params, v_opt_state = v_update_step(v_params, v_grads, v_opt_state) 228 | return loss, info, p_params, v_params, p_opt_state, v_opt_state 229 | 230 | class Worker: 231 | def __init__(self, n_steps): 232 | self.n_steps = n_steps 233 | self.buffer = Batch_Vector_ReplayBuffer(n_step_rollout, n_envs) 234 | self.env = make_env(n_envs) 235 | # self.env_reset = jax.jit(self.env.reset) 236 | # self.env_step = jax.jit(self.env.step) 237 | 238 | def rollout(self, p_params, v_params, rng): 239 | self.buffer.clear() 240 | obs = self.env.reset() 241 | 242 | for _ in range(n_step_rollout): 243 | a, log_prob = pmap_policy(p_params, obs, rng) 244 | 245 | obs2, r, done, _ = self.env.step(a) 246 | self.buffer.push((obs, a, r, obs2, done, log_prob)) 247 | 248 | obs = obs2 249 | if done.all(): break # .any() == .all() in this case 250 | 251 | rollout = self.buffer.contents() 252 | 253 | advantages, v_target = jax.pmap(compute_advantage_targets, in_axes=(None, 0))(v_params, rollout) 254 | (obs, a, r, _, _, log_prob) = rollout 255 | rollout = (obs, a, log_prob, v_target, advantages) 256 | # (n_envs, length, _) -> (n_envs * length, _) 257 | rollout = jax.tree_map(lambda x: x.reshape(-1, x.shape[-1]), rollout) 258 | 259 | return rollout 260 | 261 | #%% 262 | seed = 90897 #onp.random.randint(1e5) # 90897 works very well 263 | gamma = 0.99 264 | lmbda = 0.95 265 | eps = 0.2 266 | batch_size = 128 267 | policy_lr = 1e-3 268 | v_lr = 1e-3 269 | max_n_steps = 1e6 270 | n_step_rollout = 1000 # env._max_episode_steps 271 | save_models = False 272 | n_envs = 8 273 | 274 | rng = jax.random.PRNGKey(seed) 275 | onp.random.seed(seed) 276 | 277 | obs = np.array(onp.random.randn(obs_dim)) # dummy input 278 | p_params = policy_fcn.init(rng, obs) 279 | v_params = critic_fcn.init(rng, obs) 280 | 281 | worker = Worker(n_step_rollout) 282 | 283 | ## optimizers 284 | optimizer = lambda lr: optax.chain( 285 | optax.clip_by_global_norm(0.5), 286 | optax.scale_by_adam(), 287 | optax.scale(-lr), 288 | ) 289 | p_optim = optimizer(policy_lr) 290 | v_optim = optimizer(v_lr) 291 | 292 | p_opt_state = p_optim.init(p_params) 293 | v_opt_state = v_optim.init(v_params) 294 | 295 | p_update_step = optim_update_fcn(p_optim) 296 | v_update_step = optim_update_fcn(v_optim) 297 | 298 | eval_env = make_env(1) 299 | 300 | import pathlib 301 | model_path = pathlib.Path(f'./models/ppo/{env_name}') 302 | model_path.mkdir(exist_ok=True, parents=True) 303 | 304 | #%% 305 | from torch.utils.tensorboard import SummaryWriter 306 | writer = SummaryWriter(comment=f'ppo_brax_{env_name}_seed={seed}') 307 | 308 | epi_i = 0 309 | step_i = 0 310 | from tqdm.notebook import tqdm 311 | pbar = tqdm(total=max_n_steps) 312 | while step_i < max_n_steps: 313 | rng, 
subkey = jax.random.split(rng, 2) 314 | rollout = worker.rollout(p_params, v_params, subkey) 315 | 316 | for batch in rollout2batches(rollout, batch_size): 317 | loss, info, p_params, v_params, p_opt_state, v_opt_state = \ 318 | ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch) 319 | step_i += 1 320 | pbar.update(1) 321 | writer.add_scalar('loss/loss', loss.item(), step_i) 322 | for k in info.keys(): 323 | writer.add_scalar(f'info/{k}', info[k].item(), step_i) 324 | 325 | rng, subrng = jax.random.split(rng) 326 | reward = eval(p_params, eval_env, subrng) 327 | writer.add_scalar('eval/total_reward', reward.item(), step_i) 328 | 329 | if save_models and (epi_i == 0 or reward > max_reward): 330 | max_reward = reward 331 | save_path = str(model_path/f'params_{max_reward:.2f}') 332 | print(f'Saving model {save_path}...') 333 | with open(save_path, 'wb') as f: 334 | cloudpickle.dump((p_params, v_params), f) 335 | 336 | epi_i += 1 -------------------------------------------------------------------------------- /ppo/ppo_disc.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | import cloudpickle 10 | 11 | from jax.config import config 12 | config.update("jax_enable_x64", True) 13 | config.update("jax_debug_nans", True) # break on nans 14 | 15 | # env_name = 'CartPole-v0' 16 | # env = gym.make(env_name) 17 | 18 | from env import Navigation2DEnv_Disc 19 | env_name = 'Navigation2D' 20 | def make_env(init_task=False): 21 | env = Navigation2DEnv_Disc(max_n_steps=200) 22 | 23 | if init_task: 24 | env.seed(0) 25 | task = env.sample_tasks(1)[0] 26 | print(f'[LOGGER]: task = {task}') 27 | env.reset_task(task) 28 | 29 | # log max reward 30 | goal = env._task['goal'] 31 | x, y = goal 32 | n_right = x / 0.1 33 | n_up = y / 0.1 34 | action_seq = [] 35 | for _ in range(int(abs(n_right))): action_seq.append(3 if n_right > 0 else 2) 36 | for _ in range(int(abs(n_up))): action_seq.append(0 if n_up > 0 else 1) 37 | 38 | reward = 0 39 | step_count = 0 40 | env.reset() 41 | for a in action_seq: 42 | _, r, done, _ = env.step(a) 43 | reward += r 44 | step_count += 1 45 | if done: break 46 | assert done 47 | print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 48 | return env 49 | 50 | env = make_env(init_task=True) 51 | 52 | n_actions = env.action_space.n 53 | obs_dim = env.observation_space.shape[0] 54 | 55 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 56 | 57 | #%% 58 | import haiku as hk 59 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 60 | 61 | def _policy_fcn(s): 62 | pi = hk.Sequential([ 63 | hk.Linear(64), jax.nn.relu, 64 | hk.Linear(64), jax.nn.relu, 65 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 66 | ])(s) 67 | return pi 68 | 69 | def _critic_fcn(s): 70 | v = hk.Sequential([ 71 | hk.Linear(64), jax.nn.relu, 72 | hk.Linear(64), jax.nn.relu, 73 | hk.Linear(1, w_init=init_final), 74 | ])(s) 75 | return v 76 | 77 | policy_fcn = hk.transform(_policy_fcn) 78 | policy_fcn = hk.without_apply_rng(policy_fcn) 79 | p_frwd = jax.jit(policy_fcn.apply) 80 | 81 | critic_fcn = hk.transform(_critic_fcn) 82 | critic_fcn = hk.without_apply_rng(critic_fcn) 83 | v_frwd = jax.jit(critic_fcn.apply) 84 | 85 | # %% 86 | @jax.jit 87 | def policy(params, obs, rng): 88 | pi = p_frwd(params, obs) 89 | dist = distrax.Categorical(probs=pi) 90 | a = dist.sample(seed=rng) 91 | log_prob = 
dist.log_prob(a) 92 | return a, log_prob 93 | 94 | def eval(p_frwd, params, env, rng): 95 | rewards = 0 96 | states = [] 97 | obs = env.reset() 98 | while True: 99 | states.append(obs) 100 | rng, subrng = jax.random.split(rng) 101 | a = policy(params, obs, subrng)[0].item() 102 | obs2, r, done, _ = env.step(a) 103 | obs = obs2 104 | rewards += r 105 | if done: break 106 | states = np.stack(states) 107 | return rewards, states 108 | 109 | class Vector_ReplayBuffer: 110 | def __init__(self, buffer_capacity): 111 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 112 | self.i = 0 113 | # obs, obs2, a, r, done 114 | self.splits = [obs_dim, obs_dim+1, obs_dim+1+1, obs_dim*2+1+1, obs_dim*2+1+1+1] 115 | self.clear() 116 | 117 | def push(self, sample): 118 | assert self.i < self.buffer_capacity # dont let it get full 119 | (obs, a, r, obs2, done, log_prob) = sample 120 | self.buffer[self.i] = onp.array([*obs, onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 121 | self.i += 1 122 | 123 | def contents(self): 124 | return onp.split(self.buffer[:self.i], self.splits, axis=-1) 125 | 126 | def clear(self): 127 | self.i = 0 128 | self.buffer = onp.zeros((self.buffer_capacity, 2 * obs_dim + 1 + 2 + 1)) 129 | 130 | def shuffle_rollout(rollout): 131 | rollout_len = rollout[0].shape[0] 132 | idxs = onp.arange(rollout_len) 133 | onp.random.shuffle(idxs) 134 | rollout = jax.tree_map(lambda x: x[idxs], rollout, is_leaf=lambda x: hasattr(x, 'shape')) 135 | return rollout 136 | 137 | def rollout2batches(rollout, batch_size): 138 | rollout_len = rollout[0].shape[0] 139 | n_chunks = rollout_len // batch_size 140 | # shuffle / de-correlate 141 | rollout = shuffle_rollout(rollout) 142 | if n_chunks == 0: return rollout 143 | # batch 144 | batched_rollout = jax.tree_map(lambda x: np.array_split(x, n_chunks), rollout, is_leaf=lambda x: hasattr(x, 'shape')) 145 | for i in range(n_chunks): 146 | batch = [d[i] for d in batched_rollout] 147 | yield batch 148 | 149 | def discount_cumsum(l, discount): 150 | l = onp.array(l) 151 | for i in range(len(l) - 1)[::-1]: 152 | l[i] = l[i] + discount * l[i+1] 153 | return l 154 | 155 | def compute_advantage_targets(v_params, rollout): 156 | (obs, _, r, obs2, done, _) = rollout 157 | 158 | batch_v_fcn = jax.vmap(partial(v_frwd, v_params)) 159 | v_obs = batch_v_fcn(obs) 160 | v_obs2 = batch_v_fcn(obs2) 161 | 162 | # gae 163 | deltas = (r + (1 - done) * gamma * v_obs2) - v_obs 164 | deltas = jax.lax.stop_gradient(deltas) 165 | adv = discount_cumsum(deltas, discount=gamma * lmbda) 166 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 167 | 168 | # reward2go 169 | v_target = discount_cumsum(r, discount=gamma) 170 | 171 | return adv, v_target 172 | 173 | @jax.jit 174 | def ppo_loss(p_params, v_params, sample): 175 | (obs, a, old_log_prob, v_target, advantages) = sample 176 | 177 | ## critic loss 178 | v_obs = v_frwd(v_params, obs) 179 | critic_loss = 0.5 * ((v_obs - v_target) ** 2) 180 | 181 | ## policy losses 182 | pi = p_frwd(p_params, obs) 183 | dist = distrax.Categorical(probs=pi) 184 | # entropy 185 | entropy_loss = -dist.entropy() 186 | # policy gradient 187 | log_prob = dist.log_prob(a) 188 | 189 | ratio = np.exp(log_prob - old_log_prob) 190 | p_loss1 = ratio * advantages 191 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * advantages 192 | policy_loss = -np.fmin(p_loss1, p_loss2) 193 | 194 | loss = policy_loss + 0.001 * entropy_loss + critic_loss 195 | 196 | return loss.sum() 197 | 198 | def ppo_loss_batch(p_params, v_params, batch): 199 | return 
jax.vmap(partial(ppo_loss, p_params, v_params))(batch).mean() 200 | 201 | ppo_loss_grad = jax.jit(jax.value_and_grad(ppo_loss_batch, argnums=[0,1])) 202 | 203 | def update_step(params, grads, optim, opt_state): 204 | grads, opt_state = optim.update(grads, opt_state) 205 | params = optax.apply_updates(params, grads) 206 | return params, opt_state 207 | 208 | @jax.jit 209 | def ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch): 210 | loss, (p_grads, v_grads) = ppo_loss_grad(p_params, v_params, batch) 211 | p_params, p_opt_state = update_step(p_params, p_grads, p_optim, p_opt_state) 212 | v_params, v_opt_state = update_step(v_params, v_grads, v_optim, v_opt_state) 213 | return loss, p_params, v_params, p_opt_state, v_opt_state 214 | 215 | # can be easily extended for _multi workers 216 | class Worker: 217 | def __init__(self, n_steps): 218 | self.n_steps = n_steps 219 | self.buffer = Vector_ReplayBuffer(1e6) 220 | # import pybullet_envs 221 | self.env = make_env() 222 | # self.env = gym.make(env_name) 223 | self.obs = self.env.reset() 224 | 225 | def rollout(self, p_params, v_params, rng): 226 | self.buffer.clear() 227 | 228 | for _ in range(self.n_steps): # rollout 229 | rng, subrng = jax.random.split(rng) 230 | a, log_prob = policy(p_params, self.obs, subrng) 231 | a = a.item() 232 | 233 | obs2, r, done, _ = self.env.step(a) 234 | 235 | self.buffer.push((self.obs, a, r, obs2, done, log_prob)) 236 | self.obs = obs2 237 | if done: 238 | self.obs = self.env.reset() 239 | 240 | # update rollout contents 241 | rollout = self.buffer.contents() 242 | advantages, v_target = compute_advantage_targets(v_params, rollout) 243 | (obs, a, r, _, _, log_prob) = rollout 244 | rollout = (obs, a, log_prob, v_target, advantages) 245 | 246 | return rollout 247 | 248 | #%% 249 | seed = onp.random.randint(1e5) 250 | 251 | batch_size = 64 252 | policy_lr = 1e-3 253 | v_lr = 1e-3 254 | gamma = 0.99 255 | lmbda = 0.95 256 | eps = 0.2 257 | max_n_steps = 1e6 258 | n_step_rollout = 100 #env._max_episode_steps 259 | 260 | rng = jax.random.PRNGKey(seed) 261 | onp.random.seed(seed) 262 | 263 | obs = env.reset() # dummy input 264 | p_params = policy_fcn.init(rng, obs) 265 | v_params = critic_fcn.init(rng, obs) 266 | 267 | worker = Worker(n_step_rollout) 268 | 269 | ## optimizers 270 | # optax.clip_by_global_norm(0.5), 271 | optimizer = lambda lr: optax.chain( 272 | optax.scale_by_adam(), 273 | optax.scale(-lr), 274 | ) 275 | p_optim = optimizer(policy_lr) 276 | v_optim = optimizer(v_lr) 277 | 278 | # p_optim = optax.sgd(policy_lr) 279 | # v_optim = optax.sgd(v_lr) 280 | 281 | p_opt_state = p_optim.init(p_params) 282 | v_opt_state = v_optim.init(v_params) 283 | 284 | import pathlib 285 | model_path = pathlib.Path(f'./models/ppo/{env_name}') 286 | model_path.mkdir(exist_ok=True, parents=True) 287 | 288 | #%% 289 | from torch.utils.tensorboard import SummaryWriter 290 | writer = SummaryWriter(comment=f'ppo_{env_name}_seed={seed}') 291 | 292 | #%% 293 | epi_i = 0 294 | step_i = 0 295 | from tqdm import tqdm 296 | pbar = tqdm(total=max_n_steps) 297 | while step_i < max_n_steps: 298 | rng, subkey = jax.random.split(rng, 2) 299 | rollout = worker.rollout(p_params, v_params, subkey) 300 | 301 | for batch in rollout2batches(rollout, batch_size): 302 | loss, p_params, v_params, p_opt_state, v_opt_state = \ 303 | ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch) 304 | step_i += 1 305 | pbar.update(1) 306 | writer.add_scalar('loss/loss', loss.item(), step_i) 307 | 308 | rng, subrng = jax.random.split(rng) 309 
| reward, states = eval(p_params, env, subrng) 310 | writer.add_scalar('eval/total_reward', reward, step_i) 311 | 312 | import matplotlib.pyplot as plt 313 | plt.scatter(*env._goal, marker='*') 314 | plt.scatter(states[:, 0], states[:, 1], color='b') 315 | plt.plot(states[:, 0], states[:, 1], color='b') 316 | plt.savefig('tmp.png') 317 | plt.close() 318 | import cv2 319 | img = cv2.imread('tmp.png').transpose(-1, 0, 1) 320 | writer.add_image('eval_rollout', img, step_i) 321 | 322 | if epi_i == 0 or reward > max_reward: 323 | max_reward = reward 324 | with open(str(model_path/f'params_{max_reward:.2f}'), 'wb') as f: 325 | cloudpickle.dump((p_params, v_params), f) 326 | 327 | epi_i += 1 328 | 329 | # %% 330 | # %% 331 | # %% -------------------------------------------------------------------------------- /ppo/ppo_multi.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | import ray 10 | import pybullet_envs 11 | 12 | from jax.config import config 13 | config.update("jax_enable_x64", True) 14 | config.update("jax_debug_nans", True) # break on nans 15 | 16 | ray.init() 17 | 18 | #%% 19 | env_name = 'Pendulum-v0' 20 | # env_name = 'HalfCheetahBulletEnv-v0' 21 | env = gym.make(env_name) 22 | 23 | # from env import Navigation2DEnv 24 | # env_name = 'Navigation2D' 25 | # def make_env(): 26 | # env = Navigation2DEnv() 27 | # env.seed(0) 28 | # task = env.sample_tasks(1)[0] 29 | # env.reset_task(task) 30 | # return env 31 | # env = make_env() 32 | 33 | n_actions = env.action_space.shape[0] 34 | obs_dim = env.observation_space.shape[0] 35 | 36 | a_high = env.action_space.high[0] 37 | a_low = env.action_space.low[0] 38 | 39 | print(f'[LOGGER] a_high: {a_high} a_low: {a_low} n_actions: {n_actions} obs_dim: {obs_dim}') 40 | assert -a_high == a_low 41 | 42 | #%% 43 | import haiku as hk 44 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 45 | 46 | def _policy_fcn(s): 47 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.ones) 48 | mu = hk.Sequential([ 49 | hk.Linear(64), jax.nn.relu, 50 | hk.Linear(64), jax.nn.relu, 51 | hk.Linear(n_actions, w_init=init_final), np.tanh 52 | ])(s) * a_high 53 | sig = np.exp(log_std) 54 | return mu, sig 55 | 56 | def _critic_fcn(s): 57 | v = hk.Sequential([ 58 | hk.Linear(64), jax.nn.relu, 59 | hk.Linear(64), jax.nn.relu, 60 | hk.Linear(1), 61 | ])(s) 62 | return v 63 | 64 | policy_fcn = hk.transform(_policy_fcn) 65 | policy_fcn = hk.without_apply_rng(policy_fcn) 66 | p_frwd = jax.jit(policy_fcn.apply) 67 | 68 | critic_fcn = hk.transform(_critic_fcn) 69 | critic_fcn = hk.without_apply_rng(critic_fcn) 70 | v_frwd = jax.jit(critic_fcn.apply) 71 | 72 | class Vector_ReplayBuffer: 73 | def __init__(self, buffer_capacity): 74 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 75 | self.i = 0 76 | # obs, obs2, a, r, done 77 | self.splits = [obs_dim, obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions, obs_dim*2+1+n_actions+1] 78 | self.clear() 79 | 80 | def push(self, sample): 81 | assert self.i < self.buffer_capacity # dont let it get full 82 | (obs, a, r, obs2, done, log_prob) = sample 83 | self.buffer[self.i] = onp.array([*obs, *onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 84 | self.i += 1 85 | 86 | def contents(self): 87 | return onp.split(self.buffer[:self.i], self.splits, axis=-1) 88 | 89 | def clear(self): 90 
| self.i = 0 91 | self.buffer = onp.zeros((self.buffer_capacity, 2 * obs_dim + n_actions + 2 + 1)) 92 | 93 | # %% 94 | def policy(params, obs, rng): 95 | mu, sig = p_frwd(params, obs) 96 | dist = distrax.MultivariateNormalDiag(mu, sig) 97 | a = dist.sample(seed=rng) 98 | a = np.clip(a, a_low, a_high) 99 | log_prob = dist.log_prob(a) 100 | return a, log_prob 101 | 102 | @jax.jit 103 | def eval_policy(params, obs, _): 104 | a, _ = p_frwd(params, obs) 105 | a = np.clip(a, a_low, a_high) 106 | return a, None 107 | 108 | def eval(params, env, rng): 109 | rewards = 0 110 | obs = env.reset() 111 | while True: 112 | rng, subrng = jax.random.split(rng) 113 | a = policy(params, obs, subrng)[0] 114 | a = onp.array(a) 115 | obs2, r, done, _ = env.step(a) 116 | obs = obs2 117 | rewards += r 118 | if done: break 119 | return rewards 120 | 121 | def shuffle_rollout(rollout): 122 | rollout_len = rollout[0].shape[0] 123 | idxs = onp.arange(rollout_len) 124 | onp.random.shuffle(idxs) 125 | rollout = jax.tree_map(lambda x: x[idxs], rollout, is_leaf=lambda x: hasattr(x, 'shape')) 126 | return rollout 127 | 128 | def rollout2batches(rollout, batch_size): 129 | rollout_len = rollout[0].shape[0] 130 | n_chunks = rollout_len // batch_size 131 | # shuffle / de-correlate 132 | rollout = shuffle_rollout(rollout) 133 | if n_chunks == 0: return rollout 134 | # batch 135 | batched_rollout = jax.tree_map(lambda x: np.array_split(x, n_chunks), rollout, is_leaf=lambda x: hasattr(x, 'shape')) 136 | for i in range(n_chunks): 137 | batch = [d[i] for d in batched_rollout] 138 | yield batch 139 | 140 | def discount_cumsum(l, discount): 141 | l = onp.array(l) 142 | for i in range(len(l) - 1)[::-1]: 143 | l[i] = l[i] + discount * l[i+1] 144 | return l 145 | 146 | @jax.jit 147 | def ppo_loss(p_params, v_params, sample): 148 | (obs, a, old_log_prob, v_target, advantages) = sample 149 | 150 | ## critic loss 151 | v_obs = v_frwd(v_params, obs) 152 | critic_loss = (0.5 * ((v_obs - v_target) ** 2)).sum() 153 | 154 | ## policy losses 155 | mu, sig = p_frwd(p_params, obs) 156 | dist = distrax.MultivariateNormalDiag(mu, sig) 157 | # entropy 158 | entropy_loss = -dist.entropy() 159 | # policy gradient 160 | log_prob = dist.log_prob(a) 161 | 162 | approx_kl = (old_log_prob - log_prob).sum() 163 | ratio = np.exp(log_prob - old_log_prob) 164 | p_loss1 = ratio * advantages 165 | p_loss2 = np.clip(ratio, 1-eps, 1+eps) * advantages 166 | policy_loss = -np.fmin(p_loss1, p_loss2).sum() 167 | 168 | clipped_mask = ((ratio > 1+eps) | (ratio < 1-eps)).astype(np.float32) 169 | clip_frac = clipped_mask.mean() 170 | 171 | loss = policy_loss + 0.001 * entropy_loss + critic_loss 172 | 173 | info = dict(ploss=policy_loss, entr=-entropy_loss, vloss=critic_loss, 174 | approx_kl=approx_kl, cf=clip_frac) 175 | 176 | return loss, info 177 | 178 | def ppo_loss_batch(p_params, v_params, batch): 179 | out = jax.vmap(partial(ppo_loss, p_params, v_params))(batch) 180 | loss, info = jax.tree_map(lambda x: x.mean(), out) 181 | return loss, info 182 | 183 | ppo_loss_grad = jax.jit(jax.value_and_grad(ppo_loss_batch, argnums=[0,1], has_aux=True)) 184 | 185 | def optim_update_fcn(optim): 186 | @jax.jit 187 | def update_step(params, grads, opt_state): 188 | grads, opt_state = optim.update(grads, opt_state) 189 | params = optax.apply_updates(params, grads) 190 | return params, opt_state 191 | return update_step 192 | 193 | @jax.jit 194 | def ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch): 195 | (loss, info), (p_grads, v_grads) = ppo_loss_grad(p_params, 
v_params, batch) 196 | p_params, p_opt_state = p_update_step(p_params, p_grads, p_opt_state) 197 | v_params, v_opt_state = v_update_step(v_params, v_grads, v_opt_state) 198 | return loss, info, p_params, v_params, p_opt_state, v_opt_state 199 | 200 | @ray.remote 201 | class Worker: 202 | def __init__(self, n_steps): 203 | self.n_steps = n_steps 204 | self.p_frwd = jax.jit(policy_fcn.apply) 205 | self.v_frwd = jax.jit(critic_fcn.apply) 206 | 207 | self.buffer = Vector_ReplayBuffer(1e6) 208 | import pybullet_envs 209 | self.env = gym.make(env_name) 210 | # self.env = make_env() 211 | self.obs = self.env.reset() 212 | 213 | def compute_advantage_targets(self, v_params, rollout): 214 | (obs, _, r, obs2, done, _) = rollout 215 | 216 | batch_v_fcn = jax.vmap(partial(self.v_frwd, v_params)) # need in class bc of this line i.e v_frwd 217 | v_obs = batch_v_fcn(obs) 218 | v_obs2 = batch_v_fcn(obs2) 219 | 220 | # gae 221 | deltas = (r + (1 - done) * gamma * v_obs2) - v_obs 222 | deltas = jax.lax.stop_gradient(deltas) 223 | adv = discount_cumsum(deltas, discount=gamma * lmbda) 224 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 225 | 226 | # reward2go 227 | v_target = discount_cumsum(r, discount=gamma) 228 | 229 | return adv, v_target 230 | 231 | def rollout(self, p_params, v_params, rng): 232 | self.buffer.clear() 233 | 234 | for _ in range(self.n_steps): # rollout 235 | rng, subrng = jax.random.split(rng) 236 | 237 | mu, sig = self.p_frwd(p_params, self.obs) 238 | dist = distrax.MultivariateNormalDiag(mu, sig) 239 | a = dist.sample(seed=subrng) 240 | a = np.clip(a, a_low, a_high) 241 | log_prob = dist.log_prob(a) 242 | a = onp.array(a) 243 | 244 | obs2, r, done, _ = self.env.step(a) 245 | 246 | self.buffer.push((self.obs, a, r, obs2, done, log_prob)) 247 | self.obs = obs2 248 | if done: 249 | self.obs = self.env.reset() 250 | 251 | # update rollout contents 252 | rollout = self.buffer.contents() 253 | advantages, v_target = self.compute_advantage_targets(v_params, rollout) 254 | (obs, a, r, _, _, log_prob) = rollout 255 | rollout = (obs, a, log_prob, v_target, advantages) 256 | 257 | return rollout 258 | 259 | #%% 260 | n_envs = 4 261 | 262 | seed = onp.random.randint(1e5) # 90897 works well for pendulum 263 | gamma = 0.99 264 | lmbda = 0.95 265 | eps = 0.2 266 | batch_size = 128 267 | policy_lr = 1e-3 268 | v_lr = 1e-3 269 | max_n_steps = 1e6 270 | n_step_rollout = 2048 #env._max_episode_steps 271 | 272 | rng = jax.random.PRNGKey(seed) 273 | onp.random.seed(seed) 274 | 275 | obs = env.reset() # dummy input 276 | p_params = policy_fcn.init(rng, obs) 277 | v_params = critic_fcn.init(rng, obs) 278 | 279 | ## optimizers 280 | optimizer = lambda lr: optax.chain( 281 | optax.clip_by_global_norm(0.5), 282 | optax.scale_by_adam(), 283 | optax.scale(-lr), 284 | ) 285 | p_optim = optimizer(policy_lr) 286 | v_optim = optimizer(v_lr) 287 | 288 | p_opt_state = p_optim.init(p_params) 289 | v_opt_state = v_optim.init(v_params) 290 | 291 | p_update_step = optim_update_fcn(p_optim) 292 | v_update_step = optim_update_fcn(v_optim) 293 | 294 | #%% 295 | from torch.utils.tensorboard import SummaryWriter 296 | writer = SummaryWriter(comment=f'ppo_multi_{n_envs}_{env_name}_seed={seed}_nrollout={n_step_rollout}') 297 | 298 | #%% 299 | workers = [Worker.remote(n_step_rollout) for _ in range(n_envs)] 300 | 301 | step_i = 0 302 | from tqdm import tqdm 303 | pbar = tqdm(total=max_n_steps) 304 | while step_i < max_n_steps: 305 | ## rollout 306 | rng, *subkeys = jax.random.split(rng, 1+n_envs+1) # +1 for eval rollout 307 | 
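# collect rollouts from all remote workers: ray.get blocks until every worker
# finishes. Each rollout is a tuple (obs, a, log_prob, v_target, adv) of arrays
# with a leading time axis of length n_step_rollout; tree_multimap then
# concatenates matching leaves across workers, e.g. n_envs arrays of shape
# (n_step_rollout, obs_dim) -> one array of shape (n_envs * n_step_rollout, obs_dim).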
rollouts = ray.get([workers[i].rollout.remote(p_params, v_params, subkeys[i]) for i in range(n_envs)]) 308 | rollout = jax.tree_multimap(lambda *a: np.concatenate(a), *rollouts, is_leaf=lambda node: hasattr(node, 'shape')) 309 | 310 | ## update 311 | for batch in rollout2batches(rollout, batch_size): 312 | loss, info, p_params, v_params, p_opt_state, v_opt_state = \ 313 | ppo_step(p_params, v_params, p_opt_state, v_opt_state, batch) 314 | step_i += 1 315 | pbar.update(1) 316 | writer.add_scalar('loss/loss', loss.item(), step_i) 317 | for k in info.keys(): 318 | writer.add_scalar(f'info/{k}', info[k].item(), step_i) 319 | 320 | reward = eval(p_params, env, subkeys[-1]) 321 | writer.add_scalar('eval/total_reward', reward.item(), step_i) 322 | 323 | #%% 324 | #%% 325 | #%% -------------------------------------------------------------------------------- /ppo/tmp.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | import cloudpickle 10 | 11 | from jax.config import config 12 | config.update("jax_enable_x64", True) 13 | config.update("jax_debug_nans", True) # break on nans 14 | 15 | import brax 16 | from brax.envs import _envs, create_gym_env 17 | 18 | for env_name, env_class in _envs.items(): 19 | env_id = f"brax_{env_name}-v0" 20 | entry_point = partial(create_gym_env, env_name=env_name) 21 | if env_id not in gym.envs.registry.env_specs: 22 | print(f"Registring brax's '{env_name}' env under id '{env_id}'.") 23 | gym.register(env_id, entry_point=entry_point) 24 | 25 | #%% 26 | batch_size = 2 27 | env = gym.make("brax_halfcheetah-v0", batch_size=batch_size, \ 28 | episode_length=10) # this is very slow 29 | 1 30 | 31 | #%% 32 | buffer.clear() 33 | obs = env.reset() # this can be relatively slow (~10 secs) 34 | for _ in range(100): 35 | a = onp.random.randn(*env.action_space.shape) 36 | obs2, r, done, _ = env.step(a) 37 | sample = (obs, a, r, obs2, done, r) 38 | buffer.push(sample) 39 | if done.any(): break 40 | 41 | (obs, a, r, obs2, done, r) = buffer.contents() 42 | done 43 | 44 | #%% 45 | #%% 46 | obs_dim = env.observation_space.shape[-1] 47 | n_actions = env.action_space.shape[-1] 48 | 49 | class Vector_ReplayBuffer: 50 | def __init__(self, buffer_capacity): 51 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 52 | self.i = 0 53 | # obs, obs2, a, r, done 54 | self.splits = [obs_dim, obs_dim+n_actions, obs_dim+n_actions+1, obs_dim*2+1+n_actions, obs_dim*2+1+n_actions+1] 55 | self.clear() 56 | 57 | def push(self, sample): 58 | assert self.i < self.buffer_capacity # dont let it get full 59 | (obs, a, r, obs2, done, log_prob) = sample 60 | self.buffer[:, self.i] = onp.concatenate([obs, a, r[:, None], obs2, done.astype(float)[:, None], log_prob[:, None]], -1) 61 | self.i += 1 62 | 63 | def contents(self): 64 | return onp.split(self.buffer[:, :self.i], self.splits, axis=-1) 65 | 66 | def clear(self): 67 | self.i = 0 68 | self.buffer = onp.zeros((batch_size, self.buffer_capacity, 2 * obs_dim + n_actions + 2 + 1)) 69 | 70 | buffer = Vector_ReplayBuffer(1e6) 71 | 72 | #%% 73 | class Worker: 74 | def __init__(self, n_steps): 75 | self.n_steps = n_steps 76 | self.buffer = Vector_ReplayBuffer(n_step_rollout) 77 | self.env = make_env() 78 | self.obs = self.env.reset() 79 | 80 | def rollout(self, p_params, v_params, rng): 81 | self.buffer.clear() 82 | 83 | for _ in range(self.n_steps): # rollout 84 
| rng, subrng = jax.random.split(rng) 85 | a, log_prob = policy(p_params, self.obs, subrng) 86 | a = onp.array(a) 87 | 88 | obs2, r, done, _ = self.env.step(a) 89 | 90 | self.buffer.push((self.obs, a, r, obs2, done, log_prob)) 91 | self.obs = obs2 92 | if done: 93 | self.obs = self.env.reset() 94 | 95 | # update rollout contents 96 | rollout = self.buffer.contents() 97 | advantages, v_target = compute_advantage_targets(v_params, rollout) 98 | (obs, a, r, _, _, log_prob) = rollout 99 | rollout = (obs, a, log_prob, v_target, advantages) 100 | 101 | return rollout -------------------------------------------------------------------------------- /qlearn.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import gym 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib.pyplot as plt 7 | from torch.distributions.categorical import Categorical 8 | import copy 9 | from numpngw import write_apng 10 | 11 | #%% 12 | env = gym.make('CartPole-v1') 13 | # env = gym.make('LunarLander-v2') 14 | 15 | T = 200 16 | num_episodes = 5000 17 | 18 | update_target_every = 2 19 | batch_size = 64 20 | gamma = 0.99 21 | tau = 0.999 22 | max_replay_buffer_size = 50000 23 | 24 | eval_every = 200 25 | 26 | replay_buffer = [] 27 | rollout_rewards = [] 28 | 29 | # loss_func = nn.MSELoss() 30 | loss_func = nn.SmoothL1Loss() 31 | 32 | dim_in = env.observation_space.shape[0] 33 | dim_out = env.action_space.n 34 | qnet = nn.Sequential( 35 | nn.Linear(dim_in, 32), 36 | nn.ReLU(), 37 | nn.Linear(32, 32), 38 | nn.ReLU(), 39 | nn.Linear(32, dim_out), 40 | ) 41 | optim = torch.optim.Adam(qnet.parameters(), lr=1e-3) 42 | 43 | qnet_target = copy.deepcopy(qnet) 44 | 45 | for epi_i in range(num_episodes): 46 | # rollout 47 | obs = env.reset() 48 | for _ in range(T): 49 | # a = env.action_space.sample() # random policy rollout 50 | 51 | tobs = torch.from_numpy(obs).float() 52 | with torch.no_grad(): 53 | a_logits = qnet(tobs) 54 | a = Categorical(logits=a_logits).sample().numpy() 55 | 56 | obs2, r, done, _ = env.step(a) 57 | 58 | t = (obs, a, obs2, r, done) # add to buffer 59 | replay_buffer.append(t) 60 | 61 | if done: break 62 | obs = obs2 63 | 64 | # optim 65 | idxs = np.random.randint(len(replay_buffer), size=(batch_size,)) 66 | trans = [replay_buffer[i] for i in idxs] 67 | 68 | def extract(trans, i): return torch.from_numpy(np.array([t[i] for t in trans])) 69 | obst = extract(trans, 0).float() 70 | at = extract(trans, 1) 71 | obs2t = extract(trans, 2).float() 72 | rt = extract(trans, 3).float() 73 | donet = extract(trans, 4).float() 74 | 75 | for i in range(len(idxs)): 76 | dql = qnet(obst[i]).max(-1)[-1] # idxs of max 77 | qtarget = qnet_target(obs2t[i])[dql] 78 | y = rt[i] + (1 - donet[i]) * gamma * qtarget 79 | y = y.detach() # stop grad 80 | 81 | ypred = qnet(obst[i])[at[i]] 82 | 83 | loss = loss_func(ypred, y) 84 | 85 | optim.zero_grad() 86 | loss.backward() 87 | for param in qnet.parameters(): # gradient clipping 88 | param.grad.data.clamp_(-1, 1) 89 | optim.step() 90 | 91 | # # polyak avging -- worse perf. 
than hard update 92 | # for sp, tp in zip(qnet.parameters(), qnet_target.parameters()): 93 | # tp.data.copy_(tau * sp.data + (1.0 - tau) * tp.data) 94 | 95 | # hard updates 96 | if (epi_i+1) % update_target_every == 0: 97 | qnet_target = copy.deepcopy(qnet) 98 | 99 | if (epi_i+1) % eval_every == 0: 100 | eval_reward = 0 101 | obs = env.reset() 102 | while True: 103 | tobs = torch.from_numpy(obs).float() 104 | a_space = qnet(tobs) 105 | a = a_space.argmax(-1) 106 | 107 | obs2, r, done, _ = env.step(a.numpy()) 108 | 109 | eval_reward += r 110 | obs = obs2 111 | 112 | if done: break 113 | 114 | print(f'Total Reward @ step {epi_i}: {eval_reward}') 115 | rollout_rewards.append(eval_reward) 116 | 117 | 118 | if len(replay_buffer) > max_replay_buffer_size: 119 | print('cleaning up replay buffer...') 120 | cut_idx = int(max_replay_buffer_size * 0.5) # remove half of exps 121 | replay_buffer = replay_buffer[cut_idx:] 122 | 123 | # %% 124 | imgs = [] 125 | eval_reward = 0 126 | obs = env.reset() 127 | while True: 128 | img = env.render(mode='rgb_array') 129 | imgs.append(img) 130 | 131 | tobs = torch.from_numpy(obs).float() 132 | a_space = qnet(tobs) 133 | a = a_space.argmax(-1) 134 | 135 | obs2, r, done, _ = env.step(a.numpy()) 136 | 137 | eval_reward += r 138 | obs = obs2 139 | 140 | if done: break 141 | 142 | print(f'Total Reward: {eval_reward:.2f} #frames: {len(imgs)}') 143 | print('writing...') 144 | write_apng('dqn_cartpole.png', imgs, delay=20) 145 | 146 | # %% 147 | plt.plot(rollout_rewards) 148 | plt.show() 149 | 150 | # %% 151 | # %% 152 | # %% 153 | -------------------------------------------------------------------------------- /reinforce/cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/reinforce/cartpole.png -------------------------------------------------------------------------------- /reinforce/cartpole.sh: -------------------------------------------------------------------------------- 1 | python policy_grad.py --random_seed --exp_name seed1 --runs_folder cartpole_reinforce && 2 | python policy_grad.py --random_seed --exp_name seed2 --runs_folder cartpole_reinforce && 3 | python policy_grad.py --random_seed --exp_name seed3 --runs_folder cartpole_reinforce && 4 | python policy_grad.py --random_seed --exp_name seed4 --runs_folder cartpole_reinforce -------------------------------------------------------------------------------- /reinforce/env.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import gym 4 | 5 | from gym import spaces 6 | from gym.utils import seeding 7 | 8 | class Navigation2DEnv(gym.Env): 9 | """2D navigation problems, as described in [1]. The code is adapted from 10 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/maml_examples/point_env_randgoal.py 11 | At each time step, the 2D agent takes an action (its velocity, clipped in 12 | [-0.1, 0.1]), and receives a penalty equal to its L2 distance to the goal 13 | position (ie. the reward is `-distance`). The 2D navigation tasks are 14 | generated by sampling goal positions from the uniform distribution 15 | on [-0.5, 0.5]^2. 
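    A minimal usage sketch (illustrative only; the action values are arbitrary):

        env = Navigation2DEnv(max_n_steps=100)
        env.reset_task(env.sample_tasks(1)[0])
        obs = env.reset()
        obs2, reward, done, info = env.step(np.array([0.05, -0.05], dtype=np.float32))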
16 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 17 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 18 | (https://arxiv.org/abs/1703.03400) 19 | """ 20 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 21 | super(Navigation2DEnv, self).__init__() 22 | self.low = low 23 | self.high = high 24 | self.max_n_steps = max_n_steps 25 | self._step_count = 0 26 | 27 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 28 | shape=(2,), dtype=np.float32) 29 | self.action_space = spaces.Box(low=-0.1, high=0.1, 30 | shape=(2,), dtype=np.float32) 31 | 32 | self._task = task 33 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 34 | self._state = np.zeros(2, dtype=np.float32) 35 | self.seed() 36 | 37 | def seed(self, seed=None): 38 | self.np_random, seed = seeding.np_random(seed) 39 | return [seed] 40 | 41 | def sample_tasks(self, num_tasks): 42 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 43 | tasks = [{'goal': goal} for goal in goals] 44 | return tasks 45 | 46 | def reset_task(self, task): 47 | self._task = task 48 | self._goal = task['goal'] 49 | 50 | def reset(self): 51 | self._step_count = 0 52 | self._state = np.zeros(2, dtype=np.float32) 53 | return self._state 54 | 55 | def step(self, action): 56 | action = np.clip(action, -0.1, 0.1) 57 | action = np.array(action) 58 | assert self.action_space.contains(action) 59 | self._state = self._state + action 60 | 61 | diff = self._state - self._goal 62 | reward = -np.sqrt((diff**2).sum()) 63 | done = (np.abs(diff) < 0.01).sum() == 2 64 | 65 | done = done or self._step_count >= self.max_n_steps 66 | self._step_count += 1 67 | 68 | return self._state, reward, done, {'task': self._task} 69 | -------------------------------------------------------------------------------- /reinforce/jax2.py: -------------------------------------------------------------------------------- 1 | ## different implementation version of batch REINFORCE (still works and is 2x faster) 2 | 3 | #%% 4 | import jax 5 | import jax.numpy as np 6 | import numpy as onp 7 | import distrax 8 | import optax 9 | import gym 10 | from functools import partial 11 | 12 | from jax.config import config 13 | config.update("jax_enable_x64", True) 14 | config.update("jax_debug_nans", True) # break on nans 15 | 16 | #%% 17 | env_name = 'CartPole-v0' 18 | env = gym.make(env_name) 19 | 20 | n_actions = env.action_space.n 21 | obs_dim = env.observation_space.shape[0] 22 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 23 | 24 | #%% 25 | import haiku as hk 26 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 27 | 28 | def _policy_fcn(obs): 29 | a_probs = hk.Sequential([ 30 | hk.Linear(32), jax.nn.relu, 31 | hk.Linear(32), jax.nn.relu, 32 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 33 | ])(obs) 34 | return a_probs 35 | 36 | policy_fcn = hk.transform(_policy_fcn) 37 | policy_fcn = hk.without_apply_rng(policy_fcn) 38 | p_frwd = jax.jit(policy_fcn.apply) 39 | 40 | @jax.jit 41 | def update_step(params, grads, opt_state): 42 | grads, opt_state = p_optim.update(grads, opt_state) 43 | params = optax.apply_updates(params, grads) 44 | return params, opt_state 45 | 46 | def reward2go(r, gamma=0.99): 47 | for i in range(len(r) - 1)[::-1]: 48 | r[i] = r[i] + gamma * r[i+1] 49 | r = (r - r.mean()) / (r.std() + 1e-8) 50 | return r 51 | 52 | @jax.jit 53 | def policy(p_params, obs, rng): 54 | a_probs = p_frwd(p_params, obs) 55 | dist = distrax.Categorical(probs=a_probs) 56 | a = 
dist.sample(seed=rng) 57 | entropy = dist.entropy() 58 | return a, entropy 59 | 60 | def rollout(p_params, rng): 61 | global step_count 62 | 63 | observ, action, rew = [], [], [] 64 | obs = env.reset() 65 | while True: 66 | rng, subkey = jax.random.split(rng, 2) 67 | a, entropy = policy(p_params, obs, subkey) 68 | a = a.item() 69 | 70 | writer.add_scalar('policy/entropy', entropy.item(), step_count) 71 | 72 | obs2, r, done, _ = env.step(a) 73 | step_count += 1 74 | pbar.update(1) 75 | 76 | observ.append(obs) 77 | action.append(a) 78 | rew.append(r) 79 | 80 | if done: break 81 | obs = obs2 82 | 83 | obs = np.stack(observ) 84 | a = np.stack(action) 85 | r = onp.stack(rew) # 86 | 87 | return obs, a, r 88 | 89 | def reinforce_loss(p_params, obs, a, r): 90 | a_probs = p_frwd(p_params, obs) 91 | log_prob = distrax.Categorical(probs=a_probs).log_prob(a.astype(int)) 92 | loss = -(log_prob * r).sum() 93 | return loss 94 | 95 | from functools import partial 96 | def batch_reinforce_loss(params, batch): 97 | return jax.vmap(partial(reinforce_loss, params))(*batch).sum() 98 | 99 | # %% 100 | seed = onp.random.randint(1e5) 101 | policy_lr = 1e-3 102 | batch_size = 32 103 | max_n_steps = 100000 104 | 105 | rng = jax.random.PRNGKey(seed) 106 | onp.random.seed(seed) 107 | env.seed(seed) 108 | 109 | obs = env.reset() # dummy input 110 | p_params = policy_fcn.init(rng, obs) 111 | 112 | ## optimizers 113 | p_optim = optax.sgd(policy_lr) 114 | p_opt_state = p_optim.init(p_params) 115 | 116 | # %% 117 | from torch.utils.tensorboard import SummaryWriter 118 | writer = SummaryWriter(comment=f'reinforce_{env_name}_seed={seed}') 119 | 120 | # %% 121 | from tqdm import tqdm 122 | 123 | step_count = 0 124 | epi_i = 0 125 | 126 | pbar = tqdm(total=max_n_steps) 127 | loss_grad_fcn = jax.jit(jax.value_and_grad(batch_reinforce_loss)) 128 | 129 | while step_count < max_n_steps: 130 | 131 | trajs = [] 132 | for _ in range(batch_size): 133 | rng, subkey = jax.random.split(rng, 2) 134 | obs, a, r = rollout(p_params, subkey) 135 | writer.add_scalar('rollout/reward', r.sum().item(), epi_i) 136 | r = reward2go(r) 137 | trajs.append((obs, a, r)) 138 | epi_i += 1 139 | 140 | trajs = jax.tree_multimap(lambda *x: np.concatenate(x, 0), *trajs) 141 | loss, grads = loss_grad_fcn(p_params, trajs) 142 | p_params, p_opt_state = update_step(p_params, grads, p_opt_state) 143 | 144 | writer.add_scalar('loss/loss', loss.item(), step_count) 145 | step_count += 1 146 | 147 | # %% 148 | # %% 149 | # %% 150 | -------------------------------------------------------------------------------- /reinforce/policy_grad.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import gym 3 | from numpy.lib.shape_base import column_stack 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib.pyplot as plt 7 | from torch.distributions.categorical import Categorical 8 | from numpngw import write_apng 9 | import numpy as np 10 | from absl import flags, app 11 | from tqdm import tqdm 12 | import wandb 13 | import pprint 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | #%% 17 | FLAGS = flags.FLAGS 18 | flags.DEFINE_integer('T', 200, 'max number of steps per rollout') 19 | flags.DEFINE_integer('n_steps', 1000 * 100, '') # 1k episodes -- 100 steps per episode 20 | flags.DEFINE_integer('batch_size', 64, '') 21 | flags.DEFINE_float('lr', 1e-3, '') 22 | flags.DEFINE_integer('seed', 0, '') 23 | flags.DEFINE_integer('grad_clip_v', -1, '') 24 | 25 | flags.DEFINE_string('exp_name', '', '') 26 | 
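# flag notes: exp_name / runs_folder control the TensorBoard log directory
# (SummaryWriter(f'{runs_folder}/{exp_name}')); grad_clip_v = -1 disables
# gradient clipping, any other value is passed to clip_grad_norm_ in main().
# cartpole.sh shows an example invocation.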
flags.DEFINE_string('runs_folder', 'runs/', '') 27 | 28 | flags.DEFINE_boolean('random_seed', False, '') 29 | flags.DEFINE_boolean('view_result', False, '') 30 | 31 | def main(_): 32 | T = FLAGS.T 33 | n_steps = FLAGS.n_steps 34 | batch_size = FLAGS.batch_size 35 | lr = FLAGS.lr 36 | seed = np.random.randint(low=0, high=100) if FLAGS.random_seed else FLAGS.seed 37 | grad_clip_v = FLAGS.grad_clip_v 38 | 39 | config = { 40 | 'n_steps': n_steps, 41 | 'batch_size': batch_size, 42 | 'n_timesteps_per_rollout': T, 43 | 'seed': seed, 44 | 'lr': lr, 45 | 'grad_clip_value': grad_clip_v, 46 | } 47 | 48 | print('Config: ') 49 | pprint.pprint(config) 50 | 51 | if FLAGS.exp_name: 52 | writer = SummaryWriter(f'{FLAGS.runs_folder}/{FLAGS.exp_name}') 53 | else: 54 | writer = SummaryWriter() 55 | 56 | for c in config: 57 | writer.add_text(c, str(config[c]), 0) 58 | 59 | torch.manual_seed(seed) 60 | np.random.seed(seed) 61 | 62 | env = gym.make('CartPole-v0') 63 | 64 | policy = nn.Sequential( 65 | nn.Linear(env.observation_space.shape[0], 32), 66 | nn.ReLU(), 67 | nn.Linear(32, 32), 68 | nn.ReLU(), 69 | nn.Linear(32, env.action_space.n), 70 | nn.Softmax(dim=-1), 71 | ) 72 | 73 | optim = torch.optim.SGD(policy.parameters(), lr=lr) 74 | optim.zero_grad() 75 | 76 | epi_count = 0 77 | step_count = 0 78 | 79 | pbar = tqdm(total=n_steps) 80 | while step_count < n_steps: 81 | log_action_probs = [] 82 | rewards = [] 83 | 84 | obs = env.reset() 85 | for _ in range(T): 86 | tobs = torch.from_numpy(obs).float() 87 | a_space = policy(tobs) 88 | a_space = Categorical(a_space) 89 | 90 | a = a_space.sample() 91 | prob = a_space.log_prob(a) 92 | 93 | entropy = a_space.entropy() 94 | writer.add_scalar('policy/entropy', entropy.item(), step_count) 95 | 96 | obs2, r, done, _ = env.step(a.numpy()) 97 | step_count += 1 98 | pbar.update(1) 99 | 100 | rewards.append(torch.tensor(r).float()) 101 | log_action_probs.append(prob) 102 | 103 | if done: break 104 | obs = obs2 105 | 106 | metrics = { 107 | 'reward/min': min(rewards), 108 | 'reward/max': max(rewards), 109 | 'reward/total': sum(rewards), 110 | } 111 | 112 | # rewards to go 113 | gamma = 0.99 114 | for i in range(len(rewards) - 1)[::-1]: 115 | rewards[i] = rewards[i] + gamma * rewards[i+1] 116 | rewards = torch.tensor(rewards) 117 | rewards = (rewards - rewards.mean()) / rewards.std() 118 | 119 | metrics['reward/rescaled_min'] = min(rewards) 120 | metrics['reward/rescaled_max'] = max(rewards) 121 | metrics['reward/rescaled_total'] = sum(rewards) 122 | 123 | # record metrics 124 | for m in metrics: 125 | writer.add_scalar(m, metrics[m], epi_count) 126 | 127 | log_action_probs = torch.stack(log_action_probs) 128 | 129 | rollout_loss = -(log_action_probs * rewards).sum() # this is key 130 | rollout_loss.backward() 131 | 132 | if (epi_count+1) % batch_size == 0: 133 | 134 | for i, p in enumerate(policy.parameters()): 135 | writer.add_histogram(f'w{i}/weight', p.data, epi_count) 136 | writer.add_histogram(f'w{i}/grad', p.grad, epi_count) 137 | 138 | if grad_clip_v != -1: 139 | torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_v) 140 | for i, p in enumerate(policy.parameters()): 141 | writer.add_histogram(f'w{i}/clipped_grad', p.grad, epi_count) 142 | 143 | optim.step() 144 | optim.zero_grad() 145 | 146 | epi_count += 1 147 | pbar.close() 148 | 149 | if FLAGS.view_result: ## evaluation 150 | imgs = [] 151 | eval_reward = 0 152 | obs = env.reset() 153 | while True: 154 | img = env.render(mode='rgb_array') 155 | imgs.append(img) 156 | 157 | tobs = 
torch.from_numpy(obs).float() 158 | a_space = policy(tobs) 159 | a_space = Categorical(a_space) 160 | a = a_space.sample() 161 | 162 | obs2, r, done, _ = env.step(a.numpy()) 163 | 164 | eval_reward += r 165 | obs = obs2 166 | 167 | if done: break 168 | 169 | print(f'Total Reward: {eval_reward:.2f} #frames: {len(imgs)}') 170 | print('writing...') 171 | write_apng('cartpole.png', imgs, delay=20) 172 | 173 | if __name__ == '__main__': 174 | app.run(main) 175 | 176 | # %% 177 | #%% 178 | # %% 179 | # %% 180 | -------------------------------------------------------------------------------- /reinforce/reinforce_cont.py: -------------------------------------------------------------------------------- 1 | # doesn't work 2 | 3 | #%% 4 | import jax 5 | import jax.numpy as np 6 | import numpy as onp 7 | import distrax 8 | import optax 9 | import gym 10 | from functools import partial 11 | 12 | from jax.config import config 13 | config.update("jax_enable_x64", True) 14 | config.update("jax_debug_nans", True) # break on nans 15 | 16 | #%% 17 | env_name = 'Pendulum-v0' 18 | env = gym.make(env_name) 19 | 20 | # from env import Navigation2DEnv 21 | # env_name = 'Navigation2D' 22 | # def make_env(): 23 | # env = Navigation2DEnv(max_n_steps=100) 24 | # env.seed(0) 25 | # task = env.sample_tasks(1)[0] 26 | # print(f'LOGGER: task = {task}') 27 | # env.reset_task(task) 28 | 29 | # # log max reward 30 | # goal = env._task['goal'] 31 | # reward = 0 32 | # step_count = 0 33 | # obs = env.reset() 34 | # while True: 35 | # a = goal - obs 36 | # obs2, r, done, _ = env.step(a) 37 | # reward += r 38 | # step_count += 1 39 | # if done: break 40 | # obs = obs2 41 | # print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 42 | # return env 43 | # env = make_env() 44 | 45 | n_actions = env.action_space.shape[0] 46 | obs_dim = env.observation_space.shape[0] 47 | 48 | a_high = env.action_space.high[0] 49 | a_low = env.action_space.low[0] 50 | 51 | print(f'[LOGGER] a_high: {a_high} a_low: {a_low} n_actions: {n_actions} obs_dim: {obs_dim}') 52 | assert -a_high == a_low 53 | 54 | #%% 55 | import haiku as hk 56 | 57 | def _policy_fcn(s): 58 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.ones) 59 | mu = hk.Sequential([ 60 | hk.Linear(64), jax.nn.relu, 61 | hk.Linear(64), jax.nn.relu, 62 | hk.Linear(n_actions), np.tanh 63 | ])(s) * a_high 64 | sig = np.exp(log_std) 65 | return mu, sig 66 | 67 | policy_fcn = hk.transform(_policy_fcn) 68 | policy_fcn = hk.without_apply_rng(policy_fcn) 69 | p_frwd = jax.jit(policy_fcn.apply) 70 | 71 | @jax.jit 72 | def update_step(params, grads, opt_state): 73 | grads, opt_state = p_optim.update(grads, opt_state) 74 | params = optax.apply_updates(params, grads) 75 | return params, opt_state 76 | 77 | def discount_cumsum(l, discount): 78 | l = onp.array(l) 79 | for i in range(len(l) - 1)[::-1]: 80 | l[i] = l[i] + discount * l[i+1] 81 | return l 82 | 83 | @jax.jit 84 | def policy(params, obs, rng): 85 | mu, sig = p_frwd(params, obs) 86 | dist = distrax.MultivariateNormalDiag(mu, sig) 87 | a = dist.sample(seed=rng) 88 | a = np.clip(a, a_low, a_high) 89 | return a 90 | 91 | @jax.jit 92 | def eval_policy(params, obs, _): 93 | a, _ = p_frwd(params, obs) 94 | a = np.clip(a, a_low, a_high) 95 | return a, None 96 | 97 | def eval(params, env, rng): 98 | rewards = 0 99 | obs = env.reset() 100 | while True: 101 | rng, subkey = jax.random.split(rng, 2) 102 | a = eval_policy(params, obs, subkey)[0] 103 | a = onp.array(a) 104 | obs2, r, done, _ = env.step(a) 105 | obs = obs2 106 | 
rewards += r 107 | if done: break 108 | return rewards 109 | 110 | def rollout(p_params, rng): 111 | global step_count 112 | 113 | observ, action, rew = [], [], [] 114 | obs = env.reset() 115 | while True: 116 | rng, subkey = jax.random.split(rng, 2) 117 | a = policy(p_params, obs, subkey) 118 | a = onp.array(a) 119 | 120 | obs2, r, done, _ = env.step(a) 121 | 122 | observ.append(obs) 123 | action.append(a) 124 | rew.append(r) 125 | 126 | if done: break 127 | obs = obs2 128 | 129 | obs = np.stack(observ) 130 | a = np.stack(action) 131 | r = onp.stack(rew) # 132 | 133 | return obs, a, r 134 | 135 | def reinforce_loss(p_params, obs, a, r): 136 | mu, sig = p_frwd(p_params, obs) 137 | log_prob = distrax.MultivariateNormalDiag(mu, sig).log_prob(a) 138 | loss = -(log_prob * r).sum() 139 | return loss 140 | 141 | from functools import partial 142 | def batch_reinforce_loss(params, batch): 143 | return jax.vmap(partial(reinforce_loss, params))(*batch).sum() 144 | 145 | # %% 146 | seed = onp.random.randint(1e5) 147 | policy_lr = 1e-3 148 | batch_size = 32 149 | max_n_steps = 1000 150 | gamma = 0.99 151 | 152 | rng = jax.random.PRNGKey(seed) 153 | onp.random.seed(seed) 154 | 155 | obs = env.reset() # dummy input 156 | p_params = policy_fcn.init(rng, obs) 157 | 158 | ## optimizers 159 | p_optim = optax.chain( 160 | optax.clip_by_global_norm(0.5), 161 | optax.scale_by_adam(), 162 | optax.scale(-policy_lr), 163 | ) 164 | 165 | p_opt_state = p_optim.init(p_params) 166 | 167 | # %% 168 | from torch.utils.tensorboard import SummaryWriter 169 | writer = SummaryWriter(comment=f'reinforce_cont_{env_name}_seed={seed}') 170 | 171 | # %% 172 | from tqdm import tqdm 173 | 174 | step_count = 0 175 | epi_i = 0 176 | 177 | pbar = tqdm(total=max_n_steps) 178 | gradients = [] 179 | loss_grad_fcn = jax.jit(jax.value_and_grad(batch_reinforce_loss)) 180 | 181 | env_step_count = 0 182 | while step_count < max_n_steps: 183 | rng, subkey = jax.random.split(rng, 2) 184 | obs, a, r = rollout(p_params, subkey) 185 | writer.add_scalar('rollout/reward', r.sum().item(), epi_i) 186 | 187 | r = discount_cumsum(r, discount=gamma) 188 | r = (r - r.mean()) / (r.std() + 1e-8) 189 | loss, grad = loss_grad_fcn(p_params, (obs, a, r)) 190 | gradients.append(grad) 191 | writer.add_scalar('loss/loss', loss.item(), epi_i) 192 | 193 | epi_i += 1 194 | if epi_i % batch_size == 0: 195 | # update 196 | grad = jax.tree_multimap(lambda *x: np.stack(x).mean(0), *gradients) 197 | p_params, p_opt_state = update_step(p_params, grad, p_opt_state) 198 | step_count += 1 199 | pbar.update(1) 200 | gradients = [] 201 | 202 | # eval 203 | rng, subkey = jax.random.split(rng, 2) 204 | r = eval(p_params, env, subkey) 205 | writer.add_scalar('eval/total_reward', r, epi_i) 206 | 207 | # %% 208 | # %% 209 | # %% 210 | -------------------------------------------------------------------------------- /reinforce/reinforce_jax.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | 10 | from jax.config import config 11 | config.update("jax_enable_x64", True) 12 | config.update("jax_debug_nans", True) # break on nans 13 | 14 | #%% 15 | env_name = 'CartPole-v0' 16 | env = gym.make(env_name) 17 | 18 | n_actions = env.action_space.n 19 | obs_dim = env.observation_space.shape[0] 20 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 21 | 22 | #%% 23 | import haiku as hk 24 
| init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 25 | 26 | def _policy_fcn(obs): 27 | a_probs = hk.Sequential([ 28 | hk.Linear(32), jax.nn.relu, 29 | hk.Linear(32), jax.nn.relu, 30 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 31 | ])(obs) 32 | return a_probs 33 | 34 | policy_fcn = hk.transform(_policy_fcn) 35 | policy_fcn = hk.without_apply_rng(policy_fcn) 36 | p_frwd = jax.jit(policy_fcn.apply) 37 | 38 | @jax.jit 39 | def update_step(params, grads, opt_state): 40 | grads, opt_state = p_optim.update(grads, opt_state) 41 | params = optax.apply_updates(params, grads) 42 | return params, opt_state 43 | 44 | def reward2go(r, gamma=0.99): 45 | for i in range(len(r) - 1)[::-1]: 46 | r[i] = r[i] + gamma * r[i+1] 47 | r = (r - r.mean()) / (r.std() + 1e-8) 48 | return r 49 | 50 | @jax.jit 51 | def policy(p_params, obs, rng): 52 | a_probs = p_frwd(p_params, obs) 53 | dist = distrax.Categorical(probs=a_probs) 54 | a = dist.sample(seed=rng) 55 | entropy = dist.entropy() 56 | return a, entropy 57 | 58 | def rollout(p_params, rng): 59 | global step_count 60 | 61 | observ, action, rew = [], [], [] 62 | obs = env.reset() 63 | while True: 64 | rng, subkey = jax.random.split(rng, 2) 65 | a, entropy = policy(p_params, obs, subkey) 66 | a = a.item() 67 | 68 | writer.add_scalar('policy/entropy', entropy.item(), step_count) 69 | 70 | obs2, r, done, _ = env.step(a) 71 | step_count += 1 72 | pbar.update(1) 73 | 74 | observ.append(obs) 75 | action.append(a) 76 | rew.append(r) 77 | 78 | if done: break 79 | obs = obs2 80 | 81 | obs = np.stack(observ) 82 | a = np.stack(action) 83 | r = onp.stack(rew) # 84 | 85 | return obs, a, r 86 | 87 | def reinforce_loss(p_params, obs, a, r): 88 | a_probs = p_frwd(p_params, obs) 89 | log_prob = distrax.Categorical(probs=a_probs).log_prob(a.astype(int)) 90 | loss = -(log_prob * r).sum() 91 | return loss 92 | 93 | from functools import partial 94 | def batch_reinforce_loss(params, batch): 95 | return jax.vmap(partial(reinforce_loss, params))(*batch).sum() 96 | 97 | # %% 98 | seed = onp.random.randint(1e5) 99 | policy_lr = 1e-3 100 | batch_size = 32 101 | max_n_steps = 100000 102 | 103 | rng = jax.random.PRNGKey(seed) 104 | onp.random.seed(seed) 105 | env.seed(seed) 106 | 107 | obs = env.reset() # dummy input 108 | p_params = policy_fcn.init(rng, obs) 109 | 110 | ## optimizers 111 | p_optim = optax.sgd(policy_lr) 112 | p_opt_state = p_optim.init(p_params) 113 | 114 | # %% 115 | from torch.utils.tensorboard import SummaryWriter 116 | writer = SummaryWriter(comment=f'reinforce_{env_name}_seed={seed}') 117 | 118 | # %% 119 | from tqdm import tqdm 120 | 121 | step_count = 0 122 | epi_i = 0 123 | 124 | pbar = tqdm(total=max_n_steps) 125 | gradients = [] 126 | loss_grad_fcn = jax.jit(jax.value_and_grad(batch_reinforce_loss)) 127 | 128 | while step_count < max_n_steps: 129 | rng, subkey = jax.random.split(rng, 2) 130 | obs, a, r = rollout(p_params, subkey) 131 | writer.add_scalar('rollout/reward', r.sum().item(), epi_i) 132 | r = reward2go(r) 133 | loss, grad = loss_grad_fcn(p_params, (obs, a, r)) 134 | 135 | writer.add_scalar('loss/loss', loss.item(), step_count) 136 | gradients.append(grad) 137 | 138 | epi_i += 1 139 | if epi_i % batch_size == 0: 140 | grads = jax.tree_multimap(lambda *x: np.stack(x).sum(0), *gradients) 141 | p_params, p_opt_state = update_step(p_params, grads, p_opt_state) 142 | 143 | for i, g in enumerate(jax.tree_leaves(grads)): 144 | name = 'b' if len(g.shape) == 1 else 'w' 145 | writer.add_histogram(f'{name}_{i}_grad', onp.array(g), epi_i) 146 | 
147 | gradients = [] 148 | 149 | # %% 150 | # %% 151 | # %% 152 | -------------------------------------------------------------------------------- /reinforce/reinforce_linear_baseline.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | 10 | from jax.config import config 11 | config.update("jax_enable_x64", True) 12 | config.update("jax_debug_nans", True) # break on nans 13 | 14 | #%% 15 | env_name = 'CartPole-v0' 16 | env = gym.make(env_name) 17 | 18 | # from tmp import Navigation2DEnv_Disc 19 | # env_name = 'Navigation2D' 20 | # def make_env(init_task=False): 21 | # env = Navigation2DEnv_Disc(max_n_steps=200) 22 | 23 | # if init_task: 24 | # env.seed(0) 25 | # task = env.sample_tasks(1)[0] 26 | # print(f'[LOGGER]: task = {task}') 27 | # env.reset_task(task) 28 | 29 | # # log max reward 30 | # goal = env._task['goal'] 31 | # x, y = goal 32 | # n_right = x / 0.1 33 | # n_up = y / 0.1 34 | # action_seq = [] 35 | # for _ in range(int(abs(n_right))): action_seq.append(3 if n_right > 0 else 2) 36 | # for _ in range(int(abs(n_up))): action_seq.append(0 if n_up > 0 else 1) 37 | 38 | # reward = 0 39 | # step_count = 0 40 | # env.reset() 41 | # for a in action_seq: 42 | # _, r, done, _ = env.step(a) 43 | # reward += r 44 | # step_count += 1 45 | # if done: break 46 | # assert done 47 | # print(f'[LOGGER]: MAX_REWARD={reward} IN {step_count} STEPS') 48 | # return env 49 | 50 | # env = make_env(init_task=True) 51 | 52 | n_actions = env.action_space.n 53 | obs_dim = env.observation_space.shape[0] 54 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 55 | 56 | #%% 57 | import haiku as hk 58 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 59 | 60 | def _policy_fcn(obs): 61 | a_probs = hk.Sequential([ 62 | hk.Linear(32), jax.nn.relu, 63 | hk.Linear(32), jax.nn.relu, 64 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 65 | ])(obs) 66 | return a_probs 67 | 68 | policy_fcn = hk.transform(_policy_fcn) 69 | policy_fcn = hk.without_apply_rng(policy_fcn) 70 | p_frwd = jax.jit(policy_fcn.apply) 71 | 72 | @jax.jit 73 | def update_step(params, grads, opt_state): 74 | grads, opt_state = p_optim.update(grads, opt_state) 75 | params = optax.apply_updates(params, grads) 76 | return params, opt_state 77 | 78 | def reward2go(r, gamma=0.99): 79 | for i in range(len(r) - 1)[::-1]: 80 | r[i] = r[i] + gamma * r[i+1] 81 | r = (r - r.mean()) / (r.std() + 1e-8) 82 | return r 83 | 84 | @jax.jit 85 | def policy(p_params, obs, rng): 86 | a_probs = p_frwd(p_params, obs) 87 | dist = distrax.Categorical(probs=a_probs) 88 | a = dist.sample(seed=rng) 89 | entropy = dist.entropy() 90 | return a, entropy 91 | 92 | def rollout(p_params, rng): 93 | global step_count 94 | 95 | observ, action, rew = [], [], [] 96 | obs = env.reset() 97 | while True: 98 | rng, subkey = jax.random.split(rng, 2) 99 | a, entropy = policy(p_params, obs, subkey) 100 | a = a.item() 101 | 102 | writer.add_scalar('policy/entropy', entropy.item(), step_count) 103 | 104 | obs2, r, done, _ = env.step(a) 105 | step_count += 1 106 | pbar.update(1) 107 | 108 | observ.append(obs) 109 | action.append(a) 110 | rew.append(r) 111 | 112 | if done: break 113 | obs = obs2 114 | 115 | obs = np.stack(observ) 116 | a = np.stack(action) 117 | r = onp.stack(rew) # 118 | 119 | return obs, a, r 120 | 121 | def reinforce_loss(p_params, obs, a, r): 122 | a_probs = 
p_frwd(p_params, obs) 123 | log_prob = distrax.Categorical(probs=a_probs).log_prob(a.astype(int)) 124 | loss = -(log_prob * r).sum() 125 | return loss 126 | 127 | from functools import partial 128 | def batch_reinforce_loss(params, batch): 129 | return jax.vmap(partial(reinforce_loss, params))(*batch).sum() 130 | 131 | # https://github.com/rll/rllab/blob/master/rllab/baselines/linear_feature_baseline.py 132 | def v_features(obs): 133 | o = np.clip(obs, -10, 10) 134 | l = len(o) 135 | al = np.arange(l).reshape(-1, 1) / 100.0 136 | return np.concatenate([o, o ** 2, al, al ** 2, al ** 3, np.ones((l, 1))], axis=1) 137 | 138 | def v_fit(obs_l, rew_l, feature_fcn=v_features, reg_coeff=1e-5): 139 | featmat = np.concatenate([feature_fcn(obs) for obs in obs_l]) 140 | r = np.concatenate(rew_l) 141 | for _ in range(5): 142 | # solve argmin_x (F x = R) 143 | _coeffs = np.linalg.lstsq( 144 | featmat.T.dot(featmat) + reg_coeff * np.identity(featmat.shape[1]), 145 | featmat.T.dot(r) 146 | )[0] 147 | if not np.any(np.isnan(_coeffs)): 148 | return _coeffs, 0 # succ 149 | reg_coeff *= 10 150 | return np.zeros_like(_coeffs), 1 # err 151 | 152 | # %% 153 | seed = onp.random.randint(1e5) 154 | policy_lr = 1e-3 155 | batch_size = 32 156 | max_n_steps = 100000 157 | 158 | rng = jax.random.PRNGKey(seed) 159 | onp.random.seed(seed) 160 | 161 | obs = env.reset() # dummy input 162 | p_params = policy_fcn.init(rng, obs) 163 | 164 | ## optimizers 165 | p_optim = optax.sgd(policy_lr) 166 | p_opt_state = p_optim.init(p_params) 167 | 168 | # %% 169 | from torch.utils.tensorboard import SummaryWriter 170 | writer = SummaryWriter(comment=f'reinforce_LinearBaseline_{env_name}_seed={seed}') 171 | 172 | # %% 173 | from tqdm import tqdm 174 | 175 | step_count = 0 176 | epi_i = 0 177 | 178 | pbar = tqdm(total=max_n_steps) 179 | loss_grad_fcn = jax.jit(jax.value_and_grad(batch_reinforce_loss)) 180 | 181 | while step_count < max_n_steps: 182 | obs_l, act_l, rew_l = [], [], [] 183 | for _ in range(batch_size): 184 | rng, subkey = jax.random.split(rng, 2) 185 | obs, a, r = rollout(p_params, subkey) 186 | writer.add_scalar('rollout/reward', r.sum().item(), epi_i) 187 | 188 | obs_l.append(obs) 189 | act_l.append(a) 190 | rew_l.append(r) 191 | epi_i += 1 192 | 193 | W = v_fit(obs_l, [reward2go(r) for r in rew_l])[0] 194 | 195 | gradients = [] 196 | for i in range(batch_size): 197 | obs, a, r = obs_l[i], act_l[i], rew_l[i] 198 | r = r - (v_features(obs) @ W) # subtract baseline 199 | 200 | loss, grad = loss_grad_fcn(p_params, (obs, a, r)) 201 | writer.add_scalar('loss/loss', loss.item(), step_count) 202 | gradients.append(grad) 203 | 204 | grads = jax.tree_multimap(lambda *x: np.stack(x).sum(0), *gradients) 205 | p_params, p_opt_state = update_step(p_params, grads, p_opt_state) 206 | 207 | for i, g in enumerate(jax.tree_leaves(grads)): 208 | name = 'b' if len(g.shape) == 1 else 'w' 209 | writer.add_histogram(f'{name}_{i}_grad', onp.array(g), epi_i) 210 | -------------------------------------------------------------------------------- /reinforce/reinforce_torchVSjax.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as jnp 4 | import haiku as hk 5 | import distrax 6 | 7 | import torch 8 | from torch.distributions.categorical import Categorical 9 | 10 | import gym 11 | import numpy as np 12 | 13 | from functools import partial 14 | import time 15 | 16 | #%% 17 | env = gym.make('CartPole-v0') 18 | 19 | import torch.nn as nn 20 | policy = nn.Sequential( 21 | 
nn.Linear(env.observation_space.shape[0], 32), 22 | nn.ReLU(), 23 | nn.Linear(32, 32), 24 | nn.ReLU(), 25 | nn.Linear(32, env.action_space.n), 26 | nn.Softmax(dim=-1), 27 | ) 28 | optim = torch.optim.SGD(policy.parameters(), lr=1e-3) 29 | optim.zero_grad() 30 | 31 | def _policy_fcn(obs): 32 | a_probs = hk.Sequential([ 33 | hk.Linear(32), jax.nn.relu, 34 | hk.Linear(32), jax.nn.relu, 35 | hk.Linear(env.action_space.n), jax.nn.softmax 36 | ])(obs) 37 | return a_probs 38 | 39 | policy_fcn = hk.transform(_policy_fcn) 40 | policy_fcn = hk.without_apply_rng(policy_fcn) 41 | p_frwd = jax.jit(policy_fcn.apply) 42 | 43 | seed = 0 44 | rng = jax.random.PRNGKey(seed) 45 | obs = env.reset() # dummy input 46 | p_params = policy_fcn.init(rng, obs) 47 | 48 | #%% 49 | def torch_policy(obs): 50 | a_space = policy(torch.from_numpy(obs).float()) 51 | a_space = Categorical(a_space) 52 | a = a_space.sample() 53 | log_prob = a_space.log_prob(a) 54 | return a, log_prob 55 | 56 | def torch_rollout(): 57 | np.random.seed(seed) 58 | env.seed(seed) 59 | obs = env.reset() 60 | 61 | log_probs = [] 62 | rewards = [] 63 | while True: 64 | ## -- 65 | a, log_prob = torch_policy(obs) 66 | a = a.numpy() 67 | ## -- 68 | 69 | a = np.random.choice(env.action_space.n) 70 | obs2, r, done, _ = env.step(a) 71 | if done: break 72 | obs = obs2 73 | 74 | log_probs.append(log_prob) 75 | rewards.append(r) 76 | 77 | ## -- 78 | log_probs = torch.stack(log_probs) 79 | r = torch.tensor(rewards) 80 | loss = -(log_probs * r).sum() 81 | ## -- 82 | return loss 83 | 84 | @jax.jit 85 | def jax_policy(p_params, obs, key): 86 | a_probs = p_frwd(p_params, obs) 87 | a_probs = distrax.Categorical(probs=a_probs) 88 | a, log_prob = a_probs.sample_and_log_prob(seed=key) 89 | return a, log_prob 90 | 91 | def jax_rollout(p_params, rng): 92 | np.random.seed(seed) 93 | env.seed(seed) 94 | obs = env.reset() 95 | 96 | log_probs = [] 97 | rewards = [] 98 | while True: 99 | ## -- 100 | rng, key = jax.random.split(rng, 2) 101 | a, log_prob = jax_policy(p_params, obs, key) 102 | a = a.astype(int) 103 | ## -- 104 | 105 | a = np.random.choice(env.action_space.n) 106 | obs2, r, done, _ = env.step(a) 107 | if done: break 108 | obs = obs2 109 | 110 | log_probs.append(log_prob) 111 | rewards.append(r) 112 | 113 | ## -- 114 | log_prob = jnp.stack(log_probs) 115 | r = np.stack(rewards) 116 | loss = -(log_prob * r).sum() 117 | ## -- 118 | return loss 119 | 120 | ## only sample policy (log_prob is computed in loss) 121 | @jax.jit 122 | def jax_policy2(p_params, obs, key): 123 | a_probs = p_frwd(p_params, obs) 124 | a_probs = distrax.Categorical(probs=a_probs) 125 | a = a_probs.sample(seed=key) 126 | return a 127 | 128 | def jax_rollout2(p_params, rng): 129 | np.random.seed(seed) 130 | env.seed(seed) 131 | obs = env.reset() 132 | 133 | observ, action, rew = [], [], [] 134 | while True: 135 | ## -- 136 | rng, key = jax.random.split(rng, 2) 137 | a = jax_policy2(p_params, obs, key) 138 | a = a.astype(int) 139 | ## -- 140 | 141 | a = np.random.choice(env.action_space.n) 142 | obs2, r, done, _ = env.step(a) 143 | 144 | observ.append(obs) 145 | action.append(a) 146 | rew.append(r) 147 | 148 | if done: break 149 | obs = obs2 150 | 151 | obs = jnp.stack(observ) 152 | a = jnp.stack(action) 153 | r = jnp.stack(rew) 154 | return obs, a, r 155 | 156 | def jax_loss(p_params, obs, a, r): 157 | a_probs = p_frwd(p_params, obs) 158 | log_prob = distrax.Categorical(probs=a_probs).log_prob(a.astype(int)) 159 | loss = -(log_prob * r).sum() 160 | return loss 161 | 162 | def 
batch_jax_loss(params, obs, a, r): 163 | return jax.vmap(partial(jax_loss, params))(obs, a, r).sum() 164 | 165 | rng = jax.random.PRNGKey(seed) 166 | 167 | #%% 168 | #### PYTORCH 169 | times = [] 170 | for _ in range(50): 171 | start = time.time() 172 | 173 | loss = torch_rollout() 174 | loss.backward() 175 | 176 | times.append(time.time() - start) 177 | 178 | # 0.03423449039459228 179 | print(f'PYTORCH TIME: {np.mean(times)}') 180 | 181 | #%% 182 | #### JAX 183 | # rollout_loss fcn 184 | jax_rollout_jitgrad = jax.jit(jax.value_and_grad(jax_rollout)) 185 | 186 | times = [] 187 | for _ in range(50): 188 | rng = jax.random.PRNGKey(seed) 189 | 190 | start = time.time() 191 | 192 | rng, key = jax.random.split(rng, 2) 193 | loss, grad = jax_rollout_jitgrad(p_params, key) 194 | loss.block_until_ready() 195 | 196 | times.append(time.time() - start) 197 | 198 | # 0.21324730396270752 199 | print(f'JAX (rollout_loss) TIME: {np.mean(times)}') 200 | 201 | #%% 202 | #### JAX 203 | # rollout fcn & loss fcn 204 | jit_jax_rollout2 = jax.jit(jax_rollout2) 205 | jax_loss_jit = jax.jit(jax.value_and_grad(batch_jax_loss)) 206 | 207 | times = [] 208 | for _ in range(50): 209 | rng = jax.random.PRNGKey(seed) 210 | start = time.time() 211 | 212 | rng, key = jax.random.split(rng, 2) 213 | # batch = jit_jax_rollout2(p_params, key) 214 | batch = jax_rollout2(p_params, key) 215 | 216 | loss, grad = jax_loss_jit(p_params, *batch) 217 | loss.block_until_ready() 218 | 219 | times.append(time.time() - start) 220 | 221 | # 0.10453275203704834 with jit_jax_rollout2 222 | # 0.07715171337127685 with **no-jit**-rollout2 223 | print(f'JAX (rollout -> loss) TIME: {np.mean(times)}') 224 | 225 | #%% -------------------------------------------------------------------------------- /reinforce/tmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gebob19/rl_with_jax/a30df06de3035c460e5339611974664a2130ca6e/reinforce/tmp.png -------------------------------------------------------------------------------- /reinforce/tmp.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import gym 4 | 5 | from gym import spaces 6 | from gym.utils import seeding 7 | 8 | class Navigation2DEnv_Disc(gym.Env): 9 | def __init__(self, task={}, low=-0.5, high=0.5, max_n_steps=100): 10 | super().__init__() 11 | self.low = low 12 | self.high = high 13 | self.max_n_steps = max_n_steps 14 | self._step_count = 0 15 | 16 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 17 | shape=(2,), dtype=np.float32) 18 | self.action_space = spaces.Discrete(4) # left, right, up, down 19 | 20 | self._task = task 21 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 22 | self._state = np.zeros(2, dtype=np.float32) 23 | self.seed() 24 | 25 | def seed(self, seed=None): 26 | self.np_random, seed = seeding.np_random(seed) 27 | return [seed] 28 | 29 | def sample_tasks(self, num_tasks): 30 | while True: 31 | goals = self.np_random.uniform(self.low, self.high, size=(num_tasks, 2)) 32 | if not (goals.sum(0) == 0).any(): break 33 | 34 | goals = np.round_(goals, 1) # discrete them to 0.1 steps 35 | tasks = [{'goal': goal} for goal in goals] 36 | return tasks 37 | 38 | def reset_task(self, task): 39 | self._task = task 40 | self._goal = task['goal'] 41 | 42 | def reset(self): 43 | self._step_count = 0 44 | self._state = np.zeros(2, dtype=np.float32) 45 | return self._state 46 | 47 | def step(self, action): 48 | assert 
self.action_space.contains(action) 49 | # up down left right 50 | step = np.array({ 51 | 0: [0, 0.1], 52 | 1: [0, -0.1], 53 | 2: [-0.1, 0], 54 | 3: [0.1, 0] 55 | }[action]) 56 | 57 | self._state = self._state + step 58 | 59 | diff = self._state - self._goal 60 | reward = -np.sqrt((diff**2).sum()) 61 | done = (np.abs(diff) < 0.01).sum() == 2 62 | 63 | done = done or self._step_count >= self.max_n_steps 64 | self._step_count += 1 65 | 66 | return self._state, reward, done, {'task': self._task} 67 | 68 | # %% 69 | env = Navigation2DEnv_Disc() 70 | task = env.sample_tasks(1)[0] 71 | env.reset_task(task) 72 | print(task) 73 | 74 | # %% 75 | x, y = env._goal 76 | n_right = x / 0.1 77 | n_up = y / 0.1 78 | action_seq = [] 79 | for _ in range(int(abs(n_right)+.5)): action_seq.append(3 if n_right > 0 else 2) 80 | for _ in range(int(abs(n_up)+.5)): action_seq.append(0 if n_up > 0 else 1) 81 | action_seq 82 | 83 | # %% 84 | states = [] 85 | obs = env.reset() 86 | states.append(obs) 87 | for a in [2, 3] + action_seq: 88 | obs, r, done, _ = env.step(a) 89 | states.append(obs) 90 | print(r, done) 91 | if done: break 92 | states = np.stack(states) 93 | 94 | # %% 95 | import matplotlib.pyplot as plt 96 | plt.scatter(x, y, marker='*') 97 | plt.scatter(states[:, 0], states[:, 1], color='b') 98 | plt.plot(states[:, 0], states[:, 1], color='b') 99 | plt.savefig('tmp.png') 100 | plt.clf() 101 | import cv2 102 | img = cv2.imread('tmp.png') 103 | plt.imshow(img) 104 | 105 | # %% 106 | # %% 107 | # %% 108 | 109 | 110 | %% 111 | -------------------------------------------------------------------------------- /tmp.md: -------------------------------------------------------------------------------- 1 | $$ 2 | \underbrace{\theta'}_{N\times 1} = \underbrace{\theta}_{N\times 1} + \alpha \underbrace{F^{-1}}_{N\times N} \underbrace{\nabla J}_{N\times M} \tag{where M = 1 for DL} 3 | $$ 4 | -------------------------------------------------------------------------------- /trpo/debug.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import distrax 6 | import optax 7 | import gym 8 | from functools import partial 9 | import cloudpickle 10 | import haiku as hk 11 | 12 | from jax.config import config 13 | config.update("jax_enable_x64", True) 14 | config.update("jax_debug_nans", True) # break on nans 15 | 16 | env_name = 'Pendulum-v0' 17 | env = gym.make(env_name) 18 | 19 | n_actions = env.action_space.shape[0] 20 | obs_dim = env.observation_space.shape[0] 21 | 22 | #%% 23 | # https://github.com/ikostrikov/pytorch-trpo 24 | import torch 25 | import torch.autograd as autograd 26 | import torch.nn as nn 27 | 28 | class Policy(nn.Module): 29 | def __init__(self, num_inputs, num_outputs): 30 | super(Policy, self).__init__() 31 | self.affine1 = nn.Linear(num_inputs, 64) 32 | self.affine2 = nn.Linear(64, 64) 33 | 34 | self.action_mean = nn.Linear(64, num_outputs) 35 | self.action_mean.weight.data.mul_(0.1) 36 | self.action_mean.bias.data.mul_(0.0) 37 | 38 | self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs)) 39 | 40 | self.saved_actions = [] 41 | self.rewards = [] 42 | self.final_value = 0 43 | 44 | def forward(self, x): 45 | x = torch.tanh(self.affine1(x)) 46 | x = torch.tanh(self.affine2(x)) 47 | 48 | action_mean = self.action_mean(x) 49 | action_log_std = self.action_log_std.expand_as(action_mean) 50 | action_std = torch.exp(action_log_std) 51 | 52 | return action_mean, action_log_std, action_std 53 | 54 | 
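# The cells below copy the torch-initialized weights into the equivalent haiku policy
# (via CustomWInit / CustomBInit) and then check the JAX TRPO pieces against
# ikostrikov's pytorch-trpo reference linked above:
#   * surrogate loss            (get_loss_torch vs get_loss_jax)
#   * diagonal-Gaussian KL      (get_kl vs D_KL_Gauss)
#   * Fisher-vector product     (Fvp vs pullback_mvp)
#   * conjugate-gradient step direction, step size, and expected improvement rate
# Differences are reported as mean absolute errors in the inline comments.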
policy_net = Policy(obs_dim, n_actions) 55 | 56 | #%% 57 | params = [] 58 | for linear in [policy_net.affine1, policy_net.affine2, policy_net.action_mean]: 59 | w = onp.array(linear.weight.data) 60 | b = onp.array(linear.bias.data) 61 | params.append((w, b)) 62 | 63 | #%% 64 | class CustomWInit(hk.initializers.Initializer): 65 | def __init__(self, i) -> None: 66 | super().__init__() 67 | self.i = i 68 | def __call__(self, shape, dtype): 69 | return params[self.i][0].T 70 | class CustomBInit(hk.initializers.Initializer): 71 | def __init__(self, i) -> None: 72 | super().__init__() 73 | self.i = i 74 | def __call__(self, shape, dtype): 75 | return params[self.i][1].T 76 | 77 | def _policy_fcn(s): 78 | log_std = hk.get_parameter("log_std", shape=[n_actions,], init=np.zeros) 79 | mu = hk.Sequential([ 80 | hk.Linear(64, w_init=CustomWInit(0), b_init=CustomBInit(0)), np.tanh, 81 | hk.Linear(64, w_init=CustomWInit(1), b_init=CustomBInit(1)), np.tanh, 82 | hk.Linear(n_actions, w_init=CustomWInit(2), b_init=CustomBInit(2)) 83 | ])(s) 84 | std = np.exp(log_std) 85 | return mu, log_std, std 86 | 87 | policy_fcn = hk.transform(_policy_fcn) 88 | policy_fcn = hk.without_apply_rng(policy_fcn) 89 | p_frwd = jax.jit(policy_fcn.apply) 90 | 91 | seed = 0 92 | rng = jax.random.PRNGKey(seed) 93 | onp.random.seed(seed) 94 | env.seed(seed) 95 | 96 | obs = env.reset() # dummy input 97 | p_params = policy_fcn.init(rng, obs) 98 | 99 | # %% 100 | t_obs = torch.from_numpy(obs).float()[None] 101 | n_obs = t_obs.numpy() 102 | 103 | p_frwd(p_params, n_obs), policy_net(t_obs) 104 | 105 | # %% 106 | a = env.action_space.sample() 107 | obs2, r, done, _ = env.step(a) 108 | 109 | # %% 110 | ta = torch.from_numpy(a).float() 111 | na = ta.numpy() 112 | 113 | import math 114 | def normal_log_density_jax(x, mean, log_std, std): 115 | var = np.power(std, 2) 116 | log_density = -np.power(x - mean, 2) / ( 117 | 2 * var) - 0.5 * np.log(2 * np.pi) - log_std 118 | return np.sum(log_density, 1, keepdims=True) 119 | 120 | def get_loss_jax(p_params): 121 | action_means, action_log_stds, action_stds = p_frwd(p_params, n_obs) 122 | 123 | log_prob = normal_log_density_jax(na, action_means, action_log_stds, action_stds) 124 | fixed_log_prob = jax.lax.stop_gradient(log_prob) 125 | action_loss = -r * np.exp(log_prob - fixed_log_prob) 126 | return action_loss.mean() 127 | 128 | def normal_log_density(x, mean, log_std, std): 129 | var = std.pow(2) 130 | log_density = -(x - mean).pow(2) / ( 131 | 2 * var) - 0.5 * math.log(2 * math.pi) - log_std 132 | return log_density.sum(1, keepdim=True) 133 | 134 | def get_loss_torch(): 135 | action_means, action_log_stds, action_stds = policy_net(t_obs) 136 | 137 | log_prob = normal_log_density(ta, action_means, action_log_stds, action_stds) 138 | fixed_log_prob = log_prob.detach() 139 | action_loss = -r * torch.exp(log_prob - fixed_log_prob) 140 | return action_loss.mean() 141 | 142 | get_loss_torch(), get_loss_jax(p_params) 143 | 144 | # %% 145 | from torch.autograd import Variable 146 | def get_kl(): 147 | mean1, log_std1, std1 = policy_net(t_obs) 148 | 149 | mean0 = Variable(mean1.data) 150 | log_std0 = Variable(log_std1.data) 151 | std0 = Variable(std1.data) 152 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 153 | return kl.sum(1, keepdim=True) 154 | 155 | def D_KL_Gauss(θ1, θ2): 156 | mu0, _, std0 = θ1 157 | mu1, _, std1 = θ2 158 | 159 | log_std0 = np.log(std0) 160 | log_std1 = np.log(std1) 161 | 162 | d_kl = log_std1 - log_std0 + (np.power(std0, 2) + 
np.power(mu0 - mu1, 2)) / (2.0 * np.power(std1, 2)) - 0.5 163 | return d_kl.sum() # sum over actions 164 | 165 | theta = p_frwd(p_params, n_obs) 166 | get_kl(), D_KL_Gauss(theta, theta) 167 | 168 | # %% 169 | loss = get_loss_torch() 170 | grads = torch.autograd.grad(loss, policy_net.parameters()) 171 | for g in grads: print(g.shape) 172 | 173 | # %% 174 | jax_grads = jax.grad(get_loss_jax)(p_params) 175 | for g in jax.tree_leaves(jax_grads): print(g.shape) 176 | 177 | # %% 178 | def Fvp(v): 179 | kl = get_kl() 180 | kl = kl.mean() 181 | 182 | grads = torch.autograd.grad(kl, policy_net.parameters(), create_graph=True) 183 | flat_grad_kl = torch.cat([grad.view(-1) for grad in grads]) 184 | 185 | kl_v = (flat_grad_kl * Variable(v)).sum() 186 | grads = torch.autograd.grad(kl_v, policy_net.parameters()) 187 | flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data 188 | 189 | return flat_grad_grad_kl 190 | 191 | def unflat(model, flat_params): 192 | prev_ind = 0 193 | params = [] 194 | for param in model.parameters(): 195 | flat_size = int(onp.prod(list(param.size()))) 196 | params.append( 197 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 198 | prev_ind += flat_size 199 | return params 200 | 201 | loss_grad = torch.cat([grad.view(-1) for grad in grads]).data 202 | vp = Fvp(-loss_grad) 203 | vp = unflat(policy_net, vp) 204 | 205 | # %% 206 | def hvp(J, w, v): 207 | return jax.jvp(jax.grad(J), (w,), (v,))[1] 208 | 209 | def pullback_mvp(f, rho, w, v): 210 | # J 211 | z, R_z = jax.jvp(f, (w,), (v,)) 212 | # H J 213 | R_gz = hvp(lambda z1: rho(z, z1), z, R_z) 214 | _, f_vjp = jax.vjp(f, w) 215 | # (HJ)^T J = J^T H J 216 | return f_vjp(R_gz)[0] 217 | 218 | f = lambda w: p_frwd(w, n_obs) 219 | rho = D_KL_Gauss 220 | neg_jax_grads = jax.tree_map(lambda x: -1 * x, jax_grads) 221 | vp_jax = pullback_mvp(f, rho, p_params, neg_jax_grads) 222 | 223 | # %% 224 | vp_jax['~']['log_std'], vp[0] 225 | 226 | # %% 227 | vp_v = [] 228 | for i in range(1, len(vp), 2): 229 | vp_v.append((vp[i], vp[i+1])) 230 | 231 | jax_v = [] 232 | for k in list(vp_jax.keys())[:-1]: 233 | jax_v.append((vp_jax[k]['w'], vp_jax[k]['b'])) 234 | 235 | # %% 236 | for (tw, tb), (jw, jb) in zip(vp_v, jax_v): 237 | print(np.mean(np.abs(tw.numpy() - jw.T))) 238 | print(np.mean(np.abs(tb.numpy() - jb.T))) 239 | 240 | # %% 241 | def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10): 242 | x = torch.zeros(b.size()) 243 | r = b.clone() 244 | p = b.clone() 245 | rdotr = torch.dot(r, r) 246 | for i in range(nsteps): 247 | _Avp = Avp(p) 248 | alpha = rdotr / torch.dot(p, _Avp) 249 | x += alpha * p 250 | r -= alpha * _Avp 251 | new_rdotr = torch.dot(r, r) 252 | betta = new_rdotr / rdotr 253 | p = r + betta * p 254 | rdotr = new_rdotr 255 | if rdotr < residual_tol: 256 | break 257 | return x 258 | 259 | stepdir = conjugate_gradients(Fvp, -loss_grad, 10) 260 | stepdir.shape 261 | 262 | # %% 263 | def jax_conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10): 264 | x = np.zeros_like(b) 265 | r = np.zeros_like(b) + b 266 | p = np.zeros_like(b) + b 267 | rdotr = np.dot(r, r) 268 | for i in range(nsteps): 269 | _Avp = Avp(p) 270 | alpha = rdotr / np.dot(p, _Avp) 271 | x += alpha * p 272 | r -= alpha * _Avp 273 | new_rdotr = np.dot(r, r) 274 | betta = new_rdotr / rdotr 275 | p = r + betta * p 276 | rdotr = new_rdotr 277 | if rdotr < residual_tol: 278 | break 279 | return x 280 | 281 | def lax_jax_conjugate_gradients(Avp, b, nsteps): 282 | x = np.zeros_like(b) 283 | r = np.zeros_like(b) + b 284 | p = 
np.zeros_like(b) + b 285 | residual_tol = 1e-10 286 | 287 | cond = lambda v: v['step'] < nsteps 288 | def body(v): 289 | p, r, x = v['p'], v['r'], v['x'] 290 | rdotr = np.dot(r, r) 291 | _Avp = Avp(p) 292 | alpha = rdotr / np.dot(p, _Avp) 293 | x += alpha * p 294 | r -= alpha * _Avp 295 | new_rdotr = np.dot(r, r) 296 | betta = new_rdotr / rdotr 297 | p = r + betta * p 298 | rdotr = new_rdotr 299 | new_step = jax.lax.cond(rdotr < residual_tol, lambda s: s + nsteps, lambda s: s +1, v['step']) 300 | return {'step': new_step, 'p': p, 'r': r, 'x': x} 301 | 302 | init = {'step': 0, 'p': p, 'r': r, 'x': x} 303 | x = jax.lax.while_loop(cond, body, init)['x'] 304 | return x 305 | 306 | mvp = lambda v: pullback_mvp(f, rho, p_params, v) 307 | neg_grads = jax.tree_map(lambda x: -1 * x, jax_grads) 308 | flat_grads, unflatten_fcn = jax.flatten_util.ravel_pytree(jax_grads) 309 | flat_mvp = lambda v: jax.flatten_util.ravel_pytree(mvp(unflatten_fcn(v)))[0] 310 | stepdir_jax = lax_jax_conjugate_gradients(flat_mvp, -flat_grads, 10) 311 | stepdir_jax.shape 312 | 313 | # %% 314 | np.abs(stepdir_jax - stepdir.numpy()).mean() # 0.00201443 315 | 316 | # %% 317 | shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True) 318 | lm = torch.sqrt(shs / 1e-2) 319 | fullstep = stepdir / lm[0] 320 | neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True) 321 | expected_improve_rate = neggdotstepdir / lm[0] 322 | 323 | lm, expected_improve_rate 324 | 325 | # %% 326 | shs_j = .5 * (stepdir_jax * flat_mvp(stepdir_jax)).sum() 327 | lm = np.sqrt(shs_j / 1e-2) 328 | fullstep = stepdir_jax / lm 329 | 330 | neggdotstepdir = (-flat_grads * stepdir_jax).sum() 331 | 332 | fullstep = unflatten_fcn(fullstep) 333 | expected_improve_rate = neggdotstepdir / lm 334 | 335 | lm, expected_improve_rate 336 | 337 | # %% 338 | f = lambda w: p_frwd(w, n_obs) 339 | rho = D_KL_Gauss 340 | neg_jax_grads = jax.tree_map(lambda x: -1 * x, jax_grads) 341 | mvp = lambda v: pullback_mvp(f, rho, p_params, v) 342 | p_ngrad, _ = jax.scipy.sparse.linalg.cg(mvp, 343 | neg_jax_grads, maxiter=10, tol=1e-10) 344 | flat_ngrad, _ = jax.flatten_util.ravel_pytree(p_ngrad) 345 | 346 | np.abs(flat_ngrad - stepdir.numpy()).mean() # 475711.62 347 | 348 | # %% 349 | 350 | 351 | 352 | 353 | 354 | -------------------------------------------------------------------------------- /trpo/trpo.py: -------------------------------------------------------------------------------- 1 | # works :) 2 | 3 | #%% 4 | import jax 5 | import jax.numpy as np 6 | import numpy as onp 7 | import distrax 8 | import optax 9 | import gym 10 | from functools import partial 11 | import cloudpickle 12 | import haiku as hk 13 | 14 | from jax.config import config 15 | config.update("jax_enable_x64", True) 16 | config.update("jax_debug_nans", True) # break on nans 17 | 18 | env_name = 'CartPole-v0' 19 | env = gym.make(env_name) 20 | 21 | n_actions = env.action_space.n 22 | obs_dim = env.observation_space.shape[0] 23 | 24 | print(f'[LOGGER] n_actions: {n_actions} obs_dim: {obs_dim}') 25 | 26 | #%% 27 | init_final = hk.initializers.RandomUniform(-3e-3, 3e-3) 28 | 29 | def _policy_fcn(s): 30 | pi = hk.Sequential([ 31 | hk.Linear(64), jax.nn.relu, 32 | hk.Linear(64), jax.nn.relu, 33 | hk.Linear(n_actions, w_init=init_final), jax.nn.softmax 34 | ])(s) 35 | return pi 36 | 37 | def _critic_fcn(s): 38 | v = hk.Sequential([ 39 | hk.Linear(64), jax.nn.relu, 40 | hk.Linear(64), jax.nn.relu, 41 | hk.Linear(1), 42 | ])(s) 43 | return v 44 | 45 | policy_fcn = hk.transform(_policy_fcn) 46 | policy_fcn = 
hk.without_apply_rng(policy_fcn) 47 | p_frwd = jax.jit(policy_fcn.apply) 48 | 49 | critic_fcn = hk.transform(_critic_fcn) 50 | critic_fcn = hk.without_apply_rng(critic_fcn) 51 | v_frwd = jax.jit(critic_fcn.apply) 52 | 53 | #%% 54 | @jax.jit 55 | def policy(params, obs, rng): 56 | pi = p_frwd(params, obs) 57 | dist = distrax.Categorical(probs=pi) 58 | a = dist.sample(seed=rng) 59 | log_prob = dist.log_prob(a) 60 | return a, log_prob 61 | 62 | class Vector_ReplayBuffer: 63 | def __init__(self, buffer_capacity): 64 | self.buffer_capacity = buffer_capacity = int(buffer_capacity) 65 | self.i = 0 66 | # obs, obs2, a, r, done 67 | self.splits = [obs_dim, obs_dim+1, obs_dim+1+1, obs_dim*2+1+1, obs_dim*2+1+1+1] 68 | self.clear() 69 | 70 | def push(self, sample): 71 | assert self.i < self.buffer_capacity # dont let it get full 72 | (obs, a, r, obs2, done, log_prob) = sample 73 | self.buffer[self.i] = onp.array([*obs, onp.array(a), onp.array(r), *obs2, float(done), onp.array(log_prob)]) 74 | self.i += 1 75 | 76 | def contents(self): 77 | return onp.split(self.buffer[:self.i], self.splits, axis=-1) 78 | 79 | def clear(self): 80 | self.i = 0 81 | self.buffer = onp.zeros((self.buffer_capacity, 2 * obs_dim + 1 + 2 + 1)) 82 | 83 | def eval(params, env, rng): 84 | rewards = 0 85 | obs = env.reset() 86 | while True: 87 | rng, subrng = jax.random.split(rng) 88 | a = policy(params, obs, subrng)[0].item() 89 | obs2, r, done, _ = env.step(a) 90 | obs = obs2 91 | rewards += r 92 | if done: break 93 | return rewards 94 | 95 | def discount_cumsum(l, discount): 96 | l = onp.array(l) 97 | for i in range(len(l) - 1)[::-1]: 98 | l[i] = l[i] + discount * l[i+1] 99 | return l 100 | 101 | def compute_advantage_targets(v_params, rollout): 102 | (obs, _, r, obs2, done, _) = rollout 103 | 104 | batch_v_fcn = jax.vmap(partial(v_frwd, v_params)) 105 | v_obs = batch_v_fcn(obs) 106 | v_obs2 = batch_v_fcn(obs2) 107 | 108 | # gae 109 | deltas = (r + (1 - done) * gamma * v_obs2) - v_obs 110 | deltas = jax.lax.stop_gradient(deltas) 111 | adv = discount_cumsum(deltas, discount=gamma * lmbda) 112 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) 113 | 114 | # reward2go 115 | v_target = discount_cumsum(r, discount=gamma) 116 | 117 | return adv, v_target 118 | 119 | class Worker: 120 | def __init__(self, n_steps): 121 | self.n_steps = n_steps 122 | self.buffer = Vector_ReplayBuffer(1e6) 123 | # import pybullet_envs 124 | # self.env = make_env() 125 | self.env = gym.make(env_name) 126 | self.obs = self.env.reset() 127 | 128 | def rollout(self, p_params, v_params, rng): 129 | self.buffer.clear() 130 | 131 | for _ in range(self.n_steps): # rollout 132 | rng, subrng = jax.random.split(rng) 133 | a, log_prob = policy(p_params, self.obs, subrng) 134 | a = a.item() 135 | 136 | obs2, r, done, _ = self.env.step(a) 137 | 138 | self.buffer.push((self.obs, a, r, obs2, done, log_prob)) 139 | self.obs = obs2 140 | if done: 141 | self.obs = self.env.reset() 142 | 143 | # update rollout contents 144 | rollout = self.buffer.contents() 145 | advantages, v_target = compute_advantage_targets(v_params, rollout) 146 | (obs, a, r, _, _, log_prob) = rollout 147 | rollout = (obs, a, log_prob, v_target, advantages) 148 | 149 | return rollout 150 | 151 | def optim_update_fcn(optim): 152 | @jax.jit 153 | def update_step(params, grads, opt_state): 154 | grads, opt_state = optim.update(grads, opt_state) 155 | params = optax.apply_updates(params, grads) 156 | return params, opt_state 157 | return update_step 158 | 159 | #%% 160 | seed = 24910 #onp.random.randint(1e5) 
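# TRPO hyperparameters used below:
#   delta          -- trust-region radius: line_search only accepts a step whose mean
#                     KL(new policy || old policy) over the rollout is <= delta
#   damp_lambda    -- damping added to the Fisher-vector product (tree_mvp_dampen) to keep
#                     the conjugate-gradient solve well conditioned
#   cg_iters       -- max iterations of jax.scipy.sparse.linalg.cg when approximating F^-1 grad
#   n_search_iters -- max number of step-halvings in the backtracking line search
#   gamma, lmbda   -- discount and GAE(lambda) used in compute_advantage_targets
#   n_v_iters, v_lr -- critic fitting iterations and Adam learning rate per rollout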
161 | 162 | epochs = 1000 163 | n_step_rollout = 4000 164 | # v training 165 | v_lr = 1e-3 166 | n_v_iters = 80 167 | # gae 168 | gamma = 0.99 169 | lmbda = 0.95 170 | # trpo 171 | delta = 0.01 172 | damp_lambda = 1e-5 173 | n_search_iters = 10 174 | cg_iters = 10 175 | 176 | rng = jax.random.PRNGKey(seed) 177 | onp.random.seed(seed) 178 | env.seed(seed) 179 | 180 | obs = env.reset() # dummy input 181 | p_params = policy_fcn.init(rng, obs) 182 | v_params = critic_fcn.init(rng, obs) 183 | 184 | worker = Worker(n_step_rollout) 185 | 186 | optimizer = lambda lr: optax.chain( 187 | optax.scale_by_adam(), 188 | optax.scale(-lr), 189 | ) 190 | v_optim = optimizer(v_lr) 191 | v_opt_state = v_optim.init(v_params) 192 | v_update_fcn = optim_update_fcn(v_optim) 193 | 194 | # %% 195 | def policy_loss(p_params, sample): 196 | (obs, a, old_log_prob, _, advantages) = sample 197 | 198 | pi = p_frwd(p_params, obs) 199 | dist = distrax.Categorical(probs=pi) 200 | ratio = np.exp(dist.log_prob(a) - old_log_prob) 201 | loss = -(ratio * advantages).sum() 202 | return loss 203 | 204 | @jax.jit 205 | def batch_policy_loss(p_params, batch): 206 | return jax.vmap(partial(policy_loss, p_params))(batch).mean() 207 | 208 | def critic_loss(v_params, sample): 209 | (obs, _, _, v_target, _) = sample 210 | 211 | v_obs = v_frwd(v_params, obs) 212 | loss = (0.5 * ((v_obs - v_target) ** 2)).sum() 213 | return loss 214 | 215 | def batch_critic_loss(v_params, batch): 216 | return jax.vmap(partial(critic_loss, v_params))(batch).mean() 217 | 218 | @jax.jit 219 | def critic_step(v_params, opt_state, batch): 220 | loss, grads = jax.value_and_grad(batch_critic_loss)(v_params, batch) 221 | v_params, opt_state = v_update_fcn(v_params, grads, opt_state) 222 | return loss, v_params, opt_state 223 | 224 | def hvp(J, w, v): 225 | return jax.jvp(jax.grad(J), (w,), (v,))[1] 226 | 227 | def D_KL_probs(p1, p2): 228 | d_kl = (p1 * (np.log(p1) - np.log(p2))).sum() 229 | return d_kl 230 | 231 | def D_KL_probs_params(param1, param2, obs): 232 | p1, p2 = p_frwd(param1, obs), p_frwd(param2, obs) 233 | return D_KL_probs(p1, p2) 234 | 235 | def pullback_mvp(f, rho, w, v): 236 | z, R_z = jax.jvp(f, (w,), (v,)) 237 | R_gz = hvp(lambda z1: rho(z, z1), z, R_z) 238 | _, f_vjp = jax.vjp(f, w) 239 | return f_vjp(R_gz)[0] 240 | 241 | @jax.jit 242 | def sgd_step(params, grads, alpha): 243 | sgd_update = lambda param, grad: param - alpha * grad 244 | return jax.tree_multimap(sgd_update, params, grads) 245 | 246 | @jax.jit 247 | def sgd_step_tree(params, grads, alphas): 248 | sgd_update = lambda param, grad, alpha: param - alpha * grad 249 | return jax.tree_multimap(sgd_update, params, grads, alphas) 250 | 251 | import operator 252 | tree_scalar_op = lambda op: lambda tree, arg2: jax.tree_map(lambda x: op(x, arg2), tree) 253 | tree_scalar_divide = tree_scalar_op(operator.truediv) 254 | tree_scalar_mult = tree_scalar_op(operator.mul) 255 | 256 | # backtracking line-search 257 | def line_search(alpha_start, init_loss, p_params, p_ngrad, rollout, n_iters, delta): 258 | obs = rollout[0] 259 | for i in np.arange(n_iters): 260 | alpha = tree_scalar_divide(alpha_start, 2 ** i) 261 | 262 | new_p_params = sgd_step_tree(p_params, p_ngrad, alpha) 263 | new_loss = batch_policy_loss(new_p_params, rollout) 264 | 265 | d_kl = jax.vmap(partial(D_KL_probs_params, new_p_params, p_params))(obs).mean() 266 | 267 | if (new_loss < init_loss) and (d_kl <= delta): 268 | writer.add_scalar('info/line_search_n_iters', i, e) 269 | return new_p_params # new weights 270 | 271 | 
writer.add_scalar('info/line_search_n_iters', -1, e) 272 | return p_params # no new weights 273 | 274 | def tree_mvp_dampen(mvp, lmbda=0.1): 275 | dampen_fcn = lambda mvp_, v_: mvp_ + lmbda * v_ 276 | damp_mvp = lambda v: jax.tree_multimap(dampen_fcn, mvp(v), v) 277 | return damp_mvp 278 | 279 | def natural_grad(p_params, sample): 280 | obs = sample[0] 281 | loss, p_grads = jax.value_and_grad(policy_loss)(p_params, sample) 282 | f = lambda w: p_frwd(w, obs) 283 | rho = D_KL_probs 284 | p_ngrad, _ = jax.scipy.sparse.linalg.cg( 285 | tree_mvp_dampen(lambda v: pullback_mvp(f, rho, p_params, v), damp_lambda), 286 | p_grads, maxiter=cg_iters) 287 | 288 | # compute optimal step 289 | # SGD = theta - alpha * ∇J where ∇J = Jacobian of Loss 290 | # theta = N x 1, ∇J = N x M, F = N x N where N = #params and M = #outputs (for scalar loss M = 1! so ∇J = N x 1) 291 | # Note: s^T H s (from paper) = ∇J^T (H^-1 ∇J <-- we know this already; p_ngrad) 292 | # H^-1 ∇J = N x 1 ... then ∇J^T H^-1 ∇J = 1x1 293 | # thus turn both to vecs and dot prod to compute alpha 294 | vec = lambda x: x.flatten()[:, None] 295 | mat_mul = lambda x, y: np.sqrt(2 * delta / (vec(x).T @ vec(y)).flatten()) 296 | alpha = jax.tree_multimap(mat_mul, p_grads, p_ngrad) 297 | 298 | return loss, p_ngrad, alpha 299 | 300 | @jax.jit 301 | def batch_natural_grad(p_params, batch): 302 | out = jax.vmap(partial(natural_grad, p_params))(batch) 303 | out = jax.tree_map(lambda x: x.mean(0), out) 304 | return out 305 | 306 | def sample_rollout(rollout, p): 307 | if p < 1.: rollout_len = int(rollout[0].shape[0] * p) 308 | else: rollout_len = p 309 | idxs = onp.random.choice(rollout_len, size=rollout_len, replace=False) 310 | sampled_rollout = jax.tree_map(lambda x: x[idxs], rollout) 311 | return sampled_rollout 312 | 313 | #%% 314 | from torch.utils.tensorboard import SummaryWriter 315 | writer = SummaryWriter(comment=f'trpo_{env_name}_seed={seed}_nrollout={n_step_rollout}') 316 | 317 | #%% 318 | from tqdm import tqdm 319 | 320 | for e in tqdm(range(epochs)): 321 | # rollout 322 | rng, subkey = jax.random.split(rng, 2) 323 | rollout = worker.rollout(p_params, v_params, subkey) 324 | 325 | # train 326 | sampled_rollout = sample_rollout(rollout, 0.1) # natural grad on 10% of data 327 | loss, p_ngrad, alpha = batch_natural_grad(p_params, sampled_rollout) 328 | for i, g in enumerate(jax.tree_leaves(alpha)): 329 | writer.add_scalar(f'alpha/{i}', g.item(), e) 330 | 331 | # update 332 | p_params = line_search(alpha, loss, p_params, p_ngrad, rollout, n_search_iters, delta) 333 | writer.add_scalar('info/ploss', loss.item(), e) 334 | 335 | v_loss = 0 336 | for _ in range(n_v_iters): 337 | loss, v_params, v_opt_state = critic_step(v_params, v_opt_state, rollout) 338 | v_loss += loss 339 | 340 | v_loss /= n_v_iters 341 | writer.add_scalar('info/vloss', v_loss.item(), e) 342 | 343 | for i, g in enumerate(jax.tree_leaves(p_ngrad)): 344 | name = 'b' if len(g.shape) == 1 else 'w' 345 | writer.add_histogram(f'{name}_{i}_grad', onp.array(g), e) 346 | 347 | rng, subkey = jax.random.split(rng, 2) 348 | r = eval(p_params, env, subkey) 349 | writer.add_scalar('eval/total_reward', r, e) 350 | 351 | # %% 352 | 353 | # %% 354 | # %% 355 | --------------------------------------------------------------------------------
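For reference, the update that `natural_grad` and `line_search` in `trpo.py` implement, restating the in-code comments in the notation of `tmp.md` (here $s$ denotes `p_ngrad`, the conjugate-gradient solution, $\delta$ is `delta`, and $i$ is the backtracking index):

$$
s \approx F^{-1} \nabla J, \qquad
\alpha = \sqrt{\frac{2\delta}{\nabla J^{\top} s}}, \qquad
\theta' = \theta + \frac{\alpha}{2^{i}}\, s
$$

where $i$ is the smallest index for which the surrogate loss improves and the mean $D_{KL}(\pi_{\theta'} \,\|\, \pi_{\theta})$ over the rollout stays within $\delta$. Note that the code computes $\alpha$ per parameter tensor with `jax.tree_multimap` rather than once for the flattened parameter vector.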