├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── algorithms ├── bc.py ├── combo.py ├── cql.py ├── dynamics.py ├── edac.py ├── iql.py ├── mopo.py ├── morel.py ├── rebrac.py ├── sac_n.py ├── td3_bc.py ├── termination_fns.py └── unifloral.py ├── configs ├── algorithms │ ├── bc.yaml │ ├── combo.yaml │ ├── cql.yaml │ ├── edac.yaml │ ├── iql.yaml │ ├── mopo.yaml │ ├── morel.yaml │ ├── rebrac.yaml │ ├── sac_n.yaml │ └── td3_bc.yaml ├── dynamics.yaml └── unifloral │ ├── bc.yaml │ ├── edac.yaml │ ├── iql.yaml │ ├── mobrac.yaml │ ├── rebrac.yaml │ ├── sac_n.yaml │ ├── td3_awr.yaml │ └── td3_bc.yaml ├── evaluation.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | dynamics_models/ 3 | final_returns/ 4 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | 3 | # Enable the message, report, category or checker with the given id(s). You can 4 | # either give multiple identifier separated by comma (,) or put this option 5 | # multiple time. 6 | #enable= 7 | 8 | # Disable the message, report, category or checker with the given id(s). You 9 | # can either give multiple identifier separated by comma (,) or put this option 10 | # multiple time (only on the command line, not in the configuration file where 11 | # it should appear only once). 12 | disable=C0114,C0115,C0116,W0105,W0621,C3001 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Matthew Jackson 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# 🌹 Unifloral: Unified Offline Reinforcement Learning
7 | 8 | Unified implementations and rigorous evaluation for offline reinforcement learning - built by [Matthew Jackson](https://github.com/EmptyJackson), [Uljad Berdica](https://github.com/uljad), and [Jarek Liesen](https://github.com/keraJLi). 9 | 10 | ## 💡 Code Philosophy 11 | 12 | - ⚛️ **Single-file**: We implement algorithms as standalone Python files. 13 | - 🤏 **Minimal**: We only edit what is necessary between algorithms, making comparisons straightforward. 14 | - ⚡️ **GPU-accelerated**: We use JAX and end-to-end compile all training code, enabling lightning-fast training. 15 | 16 | Inspired by [CORL](https://github.com/tinkoff-ai/CORL) and [CleanRL](https://github.com/vwxyzjn/cleanrl) - check them out! 17 | 18 | ## 🤖 Algorithms 19 | 20 | We provide two types of algorithm implementation: 21 | 22 | 1. **Standalone**: Each algorithm is implemented as a [single file](algorithms) with minimal dependencies, making it easy to understand and modify. 23 | 2. **Unified**: Most algorithms are available as configs for our unified implementation [`unifloral.py`](algorithms/unifloral.py). 24 | 25 | After training, final evaluation results are saved to `.npz` files in [`final_returns/`](final_returns) for analysis using our evaluation protocol. 26 | 27 | All scripts support [D4RL](https://github.com/Farama-Foundation/D4RL) and use [Weights & Biases](https://wandb.ai) for logging, with configs provided as WandB sweep files. 28 | 29 | ### Model-free 30 | 31 | | Algorithm | Standalone | Unified | Extras | 32 | | --- | --- | --- | --- | 33 | | BC | [`bc.py`](algorithms/bc.py) | [`unifloral/bc.yaml`](configs/unifloral/bc.yaml) | - | 34 | | SAC-N | [`sac_n.py`](algorithms/sac_n.py) | [`unifloral/sac_n.yaml`](configs/unifloral/sac_n.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.01548) | 35 | | EDAC | [`edac.py`](algorithms/edac.py) | [`unifloral/edac.yaml`](configs/unifloral/edac.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.01548) | 36 | | CQL | [`cql.py`](algorithms/cql.py) | - | [[ArXiv]](https://arxiv.org/abs/2006.04779) | 37 | | IQL | [`iql.py`](algorithms/iql.py) | [`unifloral/iql.yaml`](configs/unifloral/iql.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.06169) | 38 | | TD3-BC | [`td3_bc.py`](algorithms/td3_bc.py) | [`unifloral/td3_bc.yaml`](configs/unifloral/td3_bc.yaml) | [[ArXiv]](https://arxiv.org/abs/2106.06860) | 39 | | ReBRAC | [`rebrac.py`](algorithms/rebrac.py) | [`unifloral/rebrac.yaml`](configs/unifloral/rebrac.yaml) | [[ArXiv]](https://arxiv.org/abs/2305.09836) | 40 | | TD3-AWR | - | [`unifloral/td3_awr.yaml`](configs/unifloral/td3_awr.yaml) | [[ArXiv]](https://arxiv.org/abs/2504.11453) | 41 | 42 | ### Model-based 43 | 44 | We implement a single script for dynamics model training: [`dynamics.py`](algorithms/dynamics.py), with config [`dynamics.yaml`](configs/dynamics.yaml). 
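Model-based agents (listed in the table below) then consume the saved dynamics checkpoint via their `--model_path` argument. The snippet below is a minimal sketch of how `combo.py` loads a trained model and builds its rollout function; the checkpoint filename is hypothetical, and the batch/rollout sizes shown are the `combo.py` defaults.

```python
import d4rl
import gym
import jax.numpy as jnp

from dynamics import (
    Transition,
    load_dynamics_model,
    EnsembleDynamics,  # required for loading dynamics model
    EnsembleDynamicsModel,  # required for loading dynamics model
)

# Build the offline dataset exactly as the training scripts do
raw = d4rl.qlearning_dataset(gym.make("halfcheetah-medium-v2"))
dataset = Transition(
    obs=jnp.array(raw["observations"]),
    action=jnp.array(raw["actions"]),
    reward=jnp.array(raw["rewards"]),
    next_obs=jnp.array(raw["next_observations"]),
    done=jnp.array(raw["terminals"]),
    next_action=jnp.roll(raw["actions"], -1, axis=0),
)

# Hypothetical checkpoint path - pass your own trained model via --model_path
dynamics_model = load_dynamics_model("dynamics_models/halfcheetah-medium-v2.pkl")
dynamics_model.dataset = dataset  # attach the offline dataset (used to seed model rollouts)
rollout_fn = dynamics_model.make_rollout_fn(batch_size=50_000, rollout_length=5)
```

The model-based agent scripts perform these steps internally, so in practice you only need to pass `--model_path` on the command line.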
45 | 46 | | Algorithm | Standalone | Unified | Extras | 47 | | --- | --- | --- | --- | 48 | | MOPO | [`mopo.py`](algorithms/mopo.py) | - | [[ArXiv]](https://arxiv.org/abs/2005.13239) | 49 | | MOReL | [`morel.py`](algorithms/morel.py) | - | [[ArXiv]](https://arxiv.org/abs/2005.05951) | 50 | | COMBO | [`combo.py`](algorithms/combo.py) | - | [[ArXiv]](https://arxiv.org/abs/2102.08363) | 51 | | MoBRAC | - | [`unifloral/mobrac.yaml`](configs/unifloral/mobrac.yaml) | [[ArXiv]](https://arxiv.org/abs/2504.11453) | 52 | 53 | New ones coming soon 👀 54 | 55 | ## 📊 Evaluation 56 | 57 | Our evaluation script ([`evaluation.py`](evaluation.py)) implements the protocol described in our paper, analysing the performance of a UCB bandit over a range of policy evaluations. 58 | 59 | ```python 60 | from evaluation import load_results_dataframe, bootstrap_bandit_trials 61 | import jax.numpy as jnp 62 | 63 | # Load all results from the final_returns directory 64 | df = load_results_dataframe("final_returns") 65 | 66 | # Run bandit trials with bootstrapped confidence intervals 67 | results = bootstrap_bandit_trials( 68 | returns_array=jnp.array(policy_returns), # Shape: (num_policies, num_rollouts) 69 | num_subsample=8, # Number of policies to subsample 70 | num_repeats=1000, # Number of bandit trials 71 | max_pulls=200, # Maximum pulls per trial 72 | ucb_alpha=2.0, # UCB exploration coefficient 73 | n_bootstraps=1000, # Bootstrap samples for confidence intervals 74 | confidence=0.95 # Confidence level 75 | ) 76 | 77 | # Access results 78 | pulls = results["pulls"] # Number of pulls at each step 79 | means = results["estimated_bests_mean"] # Mean score of estimated best policy 80 | ci_low = results["estimated_bests_ci_low"] # Lower confidence bound 81 | ci_high = results["estimated_bests_ci_high"] # Upper confidence bound 82 | ``` 83 | 84 | ## 📝 Cite us! 
85 | ```bibtex 86 | @misc{jackson2025clean, 87 | title={A Clean Slate for Offline Reinforcement Learning}, 88 | author={Matthew Thomas Jackson and Uljad Berdica and Jarek Liesen and Shimon Whiteson and Jakob Nicolaus Foerster}, 89 | year={2025}, 90 | eprint={2504.11453}, 91 | archivePrefix={arXiv}, 92 | primaryClass={cs.LG}, 93 | url={https://arxiv.org/abs/2504.11453}, 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /algorithms/bc.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | import os 5 | import warnings 6 | 7 | import distrax 8 | import d4rl 9 | import flax.linen as nn 10 | from flax.training.train_state import TrainState 11 | import gym 12 | import jax 13 | import jax.numpy as jnp 14 | import numpy as onp 15 | import optax 16 | import tyro 17 | import wandb 18 | 19 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 20 | 21 | 22 | @dataclass 23 | class Args: 24 | # --- Experiment --- 25 | seed: int = 0 26 | dataset: str = "halfcheetah-medium-v2" 27 | algorithm: str = "bc" 28 | num_updates: int = 1_000_000 29 | eval_interval: int = 2500 30 | eval_workers: int = 8 31 | eval_final_episodes: int = 1000 32 | # --- Logging --- 33 | log: bool = False 34 | wandb_project: str = "unifloral" 35 | wandb_team: str = "flair" 36 | wandb_group: str = "debug" 37 | # --- Generic optimization --- 38 | lr: float = 3e-4 39 | batch_size: int = 256 40 | 41 | 42 | r""" 43 | |\ __ 44 | \| /_/ 45 | \| 46 | ___|_____ 47 | \ / 48 | \ / 49 | \___/ Preliminaries 50 | """ 51 | 52 | AgentTrainState = namedtuple("AgentTrainState", "actor") 53 | Transition = namedtuple("Transition", "obs action reward next_obs done") 54 | 55 | 56 | class DeterministicTanhActor(nn.Module): 57 | num_actions: int 58 | obs_mean: jax.Array 59 | obs_std: jax.Array 60 | 61 | @nn.compact 62 | def __call__(self, x): 63 | x = (x - self.obs_mean) / (self.obs_std + 1e-3) 64 | for _ in range(2): 65 | x = nn.Dense(256)(x) 66 | x = nn.relu(x) 67 | action = nn.Dense(self.num_actions)(x) 68 | pi = distrax.Transformed( 69 | distrax.Deterministic(action), 70 | distrax.Tanh(), 71 | ) 72 | return pi 73 | 74 | 75 | def create_train_state(args, rng, network, dummy_input): 76 | return TrainState.create( 77 | apply_fn=network.apply, 78 | params=network.init(rng, *dummy_input), 79 | tx=optax.adam(args.lr, eps=1e-5), 80 | ) 81 | 82 | 83 | def eval_agent(args, rng, env, agent_state): 84 | # --- Reset environment --- 85 | step = 0 86 | returned = onp.zeros(args.eval_workers).astype(bool) 87 | cum_reward = onp.zeros(args.eval_workers) 88 | rng, rng_reset = jax.random.split(rng) 89 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 90 | obs = env.reset() 91 | 92 | # --- Rollout agent --- 93 | @jax.jit 94 | @jax.vmap 95 | def _policy_step(rng, obs): 96 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 97 | action = pi.sample(seed=rng) 98 | return jnp.nan_to_num(action) 99 | 100 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 101 | while step < max_episode_steps and not returned.all(): 102 | # --- Take step in environment --- 103 | step += 1 104 | rng, rng_step = jax.random.split(rng) 105 | rng_step = jax.random.split(rng_step, args.eval_workers) 106 | action = _policy_step(rng_step, jnp.array(obs)) 107 | obs, reward, done, info = env.step(onp.array(action)) 108 | 109 | # --- Track cumulative reward --- 110 | 
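# (workers whose episodes have already returned are masked out by ~returned)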
cum_reward += reward * ~returned 111 | returned |= done 112 | 113 | if step >= max_episode_steps and not returned.all(): 114 | warnings.warn("Maximum steps reached before all episodes terminated") 115 | return cum_reward 116 | 117 | 118 | r""" 119 | __/) 120 | .-(__(=: 121 | |\ | \) 122 | \ || 123 | \|| 124 | \| 125 | ___|_____ 126 | \ / 127 | \ / 128 | \___/ Agent 129 | """ 130 | 131 | 132 | def make_train_step(args, actor_apply_fn, dataset): 133 | """Make JIT-compatible agent train step.""" 134 | 135 | def _train_step(runner_state, _): 136 | rng, agent_state = runner_state 137 | 138 | # --- Sample batch --- 139 | rng, rng_batch = jax.random.split(rng) 140 | batch_indices = jax.random.randint( 141 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 142 | ) 143 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 144 | 145 | # --- Update actor --- 146 | def _actor_loss_function(params): 147 | def _transition_loss(transition): 148 | pi = actor_apply_fn(params, transition.obs) 149 | pi_action = pi.sample(seed=None) 150 | bc_loss = jnp.square(pi_action - transition.action).mean() 151 | return bc_loss 152 | 153 | bc_loss = jax.vmap(_transition_loss)(batch) 154 | return bc_loss.mean() 155 | 156 | loss_fn = jax.value_and_grad(_actor_loss_function) 157 | actor_loss, actor_grad = loss_fn(agent_state.actor.params) 158 | agent_state = agent_state._replace( 159 | actor=agent_state.actor.apply_gradients(grads=actor_grad) 160 | ) 161 | 162 | loss = { 163 | "actor_loss": actor_loss, 164 | "bc_loss": actor_loss, 165 | } 166 | return (rng, agent_state), loss 167 | 168 | return _train_step 169 | 170 | 171 | if __name__ == "__main__": 172 | # --- Parse arguments --- 173 | args = tyro.cli(Args) 174 | rng = jax.random.PRNGKey(args.seed) 175 | 176 | # --- Initialize logger --- 177 | if args.log: 178 | wandb.init( 179 | config=args, 180 | project=args.wandb_project, 181 | entity=args.wandb_team, 182 | group=args.wandb_group, 183 | job_type="train_agent", 184 | ) 185 | 186 | # --- Initialize environment and dataset --- 187 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 188 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 189 | dataset = Transition( 190 | obs=jnp.array(dataset["observations"]), 191 | action=jnp.array(dataset["actions"]), 192 | reward=jnp.array(dataset["rewards"]), 193 | next_obs=jnp.array(dataset["next_observations"]), 194 | done=jnp.array(dataset["terminals"]), 195 | ) 196 | 197 | # --- Initialize agent and value networks --- 198 | num_actions = env.single_action_space.shape[0] 199 | obs_mean = dataset.obs.mean(axis=0) 200 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0) 201 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 202 | dummy_action = jnp.zeros(num_actions) 203 | actor_net = DeterministicTanhActor(num_actions, obs_mean, obs_std) 204 | 205 | rng, rng_actor = jax.random.split(rng) 206 | agent_state = AgentTrainState( 207 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]) 208 | ) 209 | 210 | # --- Make train step --- 211 | _agent_train_step_fn = make_train_step(args, actor_net.apply, dataset) 212 | 213 | num_evals = args.num_updates // args.eval_interval 214 | for eval_idx in range(num_evals): 215 | # --- Execute train loop --- 216 | (rng, agent_state), loss = jax.lax.scan( 217 | _agent_train_step_fn, 218 | (rng, agent_state), 219 | None, 220 | args.eval_interval, 221 | ) 222 | 223 | # --- Evaluate agent --- 224 | rng, rng_eval = jax.random.split(rng) 225 | returns = eval_agent(args, rng_eval, env, 
agent_state) 226 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 227 | 228 | # --- Log metrics --- 229 | step = (eval_idx + 1) * args.eval_interval 230 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 231 | if args.log: 232 | log_dict = { 233 | "return": returns.mean(), 234 | "score": scores.mean(), 235 | "score_std": scores.std(), 236 | "num_updates": step, 237 | **{k: loss[k][-1] for k in loss}, 238 | } 239 | wandb.log(log_dict) 240 | 241 | # --- Evaluate final agent --- 242 | if args.eval_final_episodes > 0: 243 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 244 | print(f"Evaluating final agent for {final_iters} iterations...") 245 | _rng = jax.random.split(rng, final_iters) 246 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 247 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 248 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 249 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 250 | 251 | # --- Write final returns to file --- 252 | os.makedirs("final_returns", exist_ok=True) 253 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 254 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 255 | with open(os.path.join("final_returns", filename), "wb") as f: 256 | onp.savez_compressed(f, **info, args=asdict(args)) 257 | 258 | if args.log: 259 | wandb.save(os.path.join("final_returns", filename)) 260 | 261 | if args.log: 262 | wandb.finish() 263 | -------------------------------------------------------------------------------- /algorithms/combo.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | from dynamics import ( 22 | Transition, 23 | load_dynamics_model, 24 | EnsembleDynamics, # required for loading dynamics model 25 | EnsembleDynamicsModel, # required for loading dynamics model 26 | ) 27 | 28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 29 | 30 | 31 | @dataclass 32 | class Args: 33 | # --- Experiment --- 34 | seed: int = 0 35 | dataset: str = "halfcheetah-medium-v2" 36 | algorithm: str = "combo" 37 | num_updates: int = 3_000_000 38 | eval_interval: int = 2500 39 | eval_workers: int = 8 40 | eval_final_episodes: int = 1000 41 | # --- Logging --- 42 | log: bool = False 43 | wandb_project: str = "unifloral" 44 | wandb_team: str = "flair" 45 | wandb_group: str = "debug" 46 | # --- Generic optimization --- 47 | lr: float = 3e-4 48 | batch_size: int = 256 49 | gamma: float = 0.99 50 | polyak_step_size: float = 0.005 51 | # --- SAC-N --- 52 | num_critics: int = 10 53 | # --- World model --- 54 | model_path: str = "" 55 | rollout_interval: int = 1000 56 | rollout_length: int = 5 57 | rollout_batch_size: int = 50000 58 | model_retain_epochs: int = 5 59 | dataset_sample_ratio: float = 0.5 60 | # --- CQL --- 61 | actor_lr: float = 1e-4 62 | cql_temperature: float = 1.0 63 | cql_min_q_weight: float = 0.5 64 | 65 | 66 | r""" 67 | |\ __ 68 | \| /_/ 69 | \| 70 | ___|_____ 
71 | \ / 72 | \ / 73 | \___/ Preliminaries 74 | """ 75 | 76 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 77 | 78 | 79 | def sym(scale): 80 | def _init(*args, **kwargs): 81 | return uniform(2 * scale)(*args, **kwargs) - scale 82 | 83 | return _init 84 | 85 | 86 | class SoftQNetwork(nn.Module): 87 | @nn.compact 88 | def __call__(self, obs, action): 89 | x = jnp.concatenate([obs, action], axis=-1) 90 | for _ in range(3): 91 | x = nn.Dense(256, bias_init=constant(0.1))(x) 92 | x = nn.relu(x) 93 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 94 | return q.squeeze(-1) 95 | 96 | 97 | class VectorQ(nn.Module): 98 | num_critics: int 99 | 100 | @nn.compact 101 | def __call__(self, obs, action): 102 | vmap_critic = nn.vmap( 103 | SoftQNetwork, 104 | variable_axes={"params": 0}, # Parameters not shared between critics 105 | split_rngs={"params": True, "dropout": True}, # Different initializations 106 | in_axes=None, 107 | out_axes=-1, 108 | axis_size=self.num_critics, 109 | ) 110 | q_values = vmap_critic()(obs, action) 111 | return q_values 112 | 113 | 114 | class TanhGaussianActor(nn.Module): 115 | num_actions: int 116 | log_std_max: float = 2.0 117 | log_std_min: float = -5.0 118 | 119 | @nn.compact 120 | def __call__(self, x): 121 | for _ in range(3): 122 | x = nn.Dense(256, bias_init=constant(0.1))(x) 123 | x = nn.relu(x) 124 | log_std = nn.Dense( 125 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 126 | )(x) 127 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 128 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 129 | pi = distrax.Transformed( 130 | distrax.Normal(mean, std), 131 | distrax.Tanh(), 132 | ) 133 | return pi 134 | 135 | 136 | class EntropyCoef(nn.Module): 137 | ent_coef_init: float = 1.0 138 | 139 | @nn.compact 140 | def __call__(self) -> jnp.ndarray: 141 | log_ent_coef = self.param( 142 | "log_ent_coef", 143 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 144 | ) 145 | return log_ent_coef 146 | 147 | 148 | def create_train_state(args, rng, network, dummy_input, lr=None): 149 | return TrainState.create( 150 | apply_fn=network.apply, 151 | params=network.init(rng, *dummy_input), 152 | tx=optax.adam(lr if lr is not None else args.lr, eps=1e-5), 153 | ) 154 | 155 | 156 | def eval_agent(args, rng, env, agent_state): 157 | # --- Reset environment --- 158 | step = 0 159 | returned = onp.zeros(args.eval_workers).astype(bool) 160 | cum_reward = onp.zeros(args.eval_workers) 161 | rng, rng_reset = jax.random.split(rng) 162 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 163 | obs = env.reset() 164 | 165 | # --- Rollout agent --- 166 | @jax.jit 167 | @jax.vmap 168 | def _policy_step(rng, obs): 169 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 170 | action = pi.sample(seed=rng) 171 | return jnp.nan_to_num(action) 172 | 173 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 174 | while step < max_episode_steps and not returned.all(): 175 | # --- Take step in environment --- 176 | step += 1 177 | rng, rng_step = jax.random.split(rng) 178 | rng_step = jax.random.split(rng_step, args.eval_workers) 179 | action = _policy_step(rng_step, jnp.array(obs)) 180 | obs, reward, done, info = env.step(onp.array(action)) 181 | 182 | # --- Track cumulative reward --- 183 | cum_reward += reward * ~returned 184 | returned |= done 185 | 186 | if step >= max_episode_steps and not returned.all(): 187 | warnings.warn("Maximum steps 
reached before all episodes terminated") 188 | return cum_reward 189 | 190 | 191 | def sample_from_buffer(buffer, batch_size, rng): 192 | """Sample a batch from the buffer.""" 193 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs)) 194 | return jax.tree_map(lambda x: x[idxs], buffer) 195 | 196 | 197 | r""" 198 | __/) 199 | .-(__(=: 200 | |\ | \) 201 | \ || 202 | \|| 203 | \| 204 | ___|_____ 205 | \ / 206 | \ / 207 | \___/ Agent 208 | """ 209 | 210 | 211 | def make_train_step( 212 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn 213 | ): 214 | """Make JIT-compatible agent train step with model-based rollouts.""" 215 | 216 | def _train_step(runner_state, _): 217 | rng, agent_state, rollout_buffer = runner_state 218 | 219 | # --- Update model buffer --- 220 | params = agent_state.actor.params 221 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng) 222 | rng, rng_buffer = jax.random.split(rng) 223 | rollout_buffer = jax.lax.cond( 224 | agent_state.actor.step % args.rollout_interval == 0, 225 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer), 226 | lambda: rollout_buffer, 227 | ) 228 | 229 | # --- Sample batch --- 230 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3) 231 | dataset_size = int(args.batch_size * args.dataset_sample_ratio) 232 | rollout_size = args.batch_size - dataset_size 233 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset) 234 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout) 235 | batch = jax.tree_map( 236 | lambda x, y: jnp.concatenate([x, y]), dataset_batch, rollout_batch 237 | ) 238 | 239 | # --- Update alpha --- 240 | @jax.value_and_grad 241 | def _alpha_loss_fn(params, rng): 242 | def _compute_entropy(rng, transition): 243 | pi = actor_apply_fn(agent_state.actor.params, transition.obs) 244 | _, log_pi = pi.sample_and_log_prob(seed=rng) 245 | return -log_pi.sum() 246 | 247 | log_alpha = alpha_apply_fn(params) 248 | rng = jax.random.split(rng, args.batch_size) 249 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 250 | target_entropy = -batch.action.shape[-1] 251 | return log_alpha * (entropy - target_entropy) 252 | 253 | rng, rng_alpha = jax.random.split(rng) 254 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 255 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 256 | agent_state = agent_state._replace(alpha=updated_alpha) 257 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 258 | 259 | # --- Update actor --- 260 | @partial(jax.value_and_grad, has_aux=True) 261 | def _actor_loss_function(params, rng): 262 | def _compute_loss(rng, transition): 263 | pi = actor_apply_fn(params, transition.obs) 264 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 265 | log_pi = log_pi.sum() 266 | q_values = q_apply_fn( 267 | agent_state.vec_q.params, transition.obs, sampled_action 268 | ) 269 | q_min = jnp.min(q_values) 270 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 271 | 272 | rng = jax.random.split(rng, args.batch_size) 273 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 274 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 275 | 276 | rng, rng_actor = jax.random.split(rng) 277 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 278 | agent_state.actor.params, rng_actor 279 | ) 280 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 281 | agent_state = 
agent_state._replace(actor=updated_actor) 282 | 283 | # --- Update Q target network --- 284 | updated_q_target_params = optax.incremental_update( 285 | agent_state.vec_q.params, 286 | agent_state.vec_q_target.params, 287 | args.polyak_step_size, 288 | ) 289 | updated_q_target = agent_state.vec_q_target.replace( 290 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 291 | ) 292 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 293 | 294 | # --- Compute targets --- 295 | def _sample_next_v(rng, transition): 296 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 297 | # Note: Important to use sample_and_log_prob here for numerical stability 298 | # See https://github.com/deepmind/distrax/issues/7 299 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 300 | # Minimum of the target Q-values 301 | next_q = q_apply_fn( 302 | agent_state.vec_q_target.params, transition.next_obs, next_action 303 | ) 304 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 305 | 306 | rng, rng_next_v = jax.random.split(rng) 307 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 308 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 309 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target 310 | 311 | # --- Sample actions for CQL --- 312 | def _sample_actions(rng, obs): 313 | pi = actor_apply_fn(agent_state.actor.params, obs) 314 | return pi.sample(seed=rng) 315 | 316 | rng, rng_pi, rng_next = jax.random.split(rng, 3) 317 | pi_actions = _sample_actions(rng_pi, batch.obs) 318 | pi_next_actions = _sample_actions(rng_next, batch.next_obs) 319 | rng, rng_random = jax.random.split(rng) 320 | cql_random_actions = jax.random.uniform( 321 | rng_random, shape=batch.action.shape, minval=-1.0, maxval=1.0 322 | ) 323 | 324 | # --- Update critics --- 325 | @jax.value_and_grad 326 | def _q_loss_fn(params): 327 | q_pred = q_apply_fn(params, batch.obs, batch.action) 328 | critic_loss = jnp.square((q_pred - jnp.expand_dims(target, -1))) 329 | critic_loss = critic_loss.sum(-1).mean() 330 | q_pred_combo = q_apply_fn(params, dataset_batch.obs, dataset_batch.action) 331 | q_pred_combo = q_pred_combo.mean() 332 | 333 | rand_q = q_apply_fn(params, batch.obs, cql_random_actions) 334 | pi_q = q_apply_fn(params, batch.obs, pi_actions) 335 | # Note: Source implementation erroneously uses current obs in next_pi_q 336 | next_pi_q = q_apply_fn(params, batch.next_obs, pi_next_actions) 337 | all_qs = jnp.concatenate([rand_q, pi_q, next_pi_q, q_pred], axis=1) 338 | q_ood = jax.scipy.special.logsumexp(all_qs / args.cql_temperature, axis=1) 339 | q_ood = jax.lax.stop_gradient(q_ood * args.cql_temperature) 340 | q_diff = (jnp.expand_dims(q_ood, 1) - q_pred_combo).mean() 341 | min_q_loss = q_diff * args.cql_min_q_weight 342 | 343 | critic_loss += min_q_loss.mean() 344 | return critic_loss 345 | 346 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params) 347 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad) 348 | agent_state = agent_state._replace(vec_q=updated_q) 349 | 350 | num_done = jnp.sum(batch.done) 351 | loss = { 352 | "critic_loss": critic_loss, 353 | "actor_loss": actor_loss, 354 | "alpha_loss": alpha_loss, 355 | "entropy": entropy, 356 | "alpha": alpha, 357 | "q_min": q_min, 358 | "q_std": q_std, 359 | "terminations/num_done": num_done, 360 | "terminations/done_ratio": num_done / batch.done.shape[0], 361 | } 362 | return (rng, agent_state, rollout_buffer), loss 363 | 364 | return _train_step 365 | 
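# Note: COMBO's conservatism term differs from cql.py - in _q_loss_fn above, the logsumexp over
# out-of-distribution Q-values (random, policy, and next-policy actions) is contrasted with Q-values
# on the *dataset* batch (q_pred_combo) rather than on the mixed dataset/rollout training batch.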
366 | 367 | if __name__ == "__main__": 368 | # --- Parse arguments --- 369 | args = tyro.cli(Args) 370 | rng = jax.random.PRNGKey(args.seed) 371 | 372 | # --- Initialize logger --- 373 | if args.log: 374 | wandb.init( 375 | config=args, 376 | project=args.wandb_project, 377 | entity=args.wandb_team, 378 | group=args.wandb_group, 379 | job_type="train_agent", 380 | ) 381 | 382 | # --- Initialize environment and dataset --- 383 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 384 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 385 | dataset = Transition( 386 | obs=jnp.array(dataset["observations"]), 387 | action=jnp.array(dataset["actions"]), 388 | reward=jnp.array(dataset["rewards"]), 389 | next_obs=jnp.array(dataset["next_observations"]), 390 | done=jnp.array(dataset["terminals"]), 391 | next_action=jnp.roll(dataset["actions"], -1, axis=0), 392 | ) 393 | 394 | # --- Initialize agent and value networks --- 395 | num_actions = env.single_action_space.shape[0] 396 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 397 | dummy_action = jnp.zeros(num_actions) 398 | actor_net = TanhGaussianActor(num_actions) 399 | q_net = VectorQ(args.num_critics) 400 | alpha_net = EntropyCoef() 401 | 402 | # Target networks share seeds to match initialization 403 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 404 | actor_lr = args.actor_lr if args.actor_lr is not None else args.lr 405 | agent_state = AgentTrainState( 406 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs], actor_lr), 407 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 408 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 409 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 410 | ) 411 | 412 | # --- Initialize buffer and rollout function --- 413 | assert args.model_path, "Model path must be provided for model-based methods" 414 | dynamics_model = load_dynamics_model(args.model_path) 415 | dynamics_model.dataset = dataset 416 | max_buffer_size = args.rollout_batch_size * args.rollout_length 417 | max_buffer_size *= args.model_retain_epochs 418 | rollout_buffer = jax.tree_map( 419 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])), 420 | dataset, 421 | ) 422 | rollout_fn = dynamics_model.make_rollout_fn( 423 | batch_size=args.rollout_batch_size, rollout_length=args.rollout_length 424 | ) 425 | 426 | # --- Make train step --- 427 | _agent_train_step_fn = make_train_step( 428 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn 429 | ) 430 | 431 | num_evals = args.num_updates // args.eval_interval 432 | for eval_idx in range(num_evals): 433 | # --- Execute train loop --- 434 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan( 435 | _agent_train_step_fn, 436 | (rng, agent_state, rollout_buffer), 437 | None, 438 | args.eval_interval, 439 | ) 440 | 441 | # --- Evaluate agent --- 442 | rng, rng_eval = jax.random.split(rng) 443 | returns = eval_agent(args, rng_eval, env, agent_state) 444 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 445 | 446 | # --- Log metrics --- 447 | step = (eval_idx + 1) * args.eval_interval 448 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 449 | if args.log: 450 | log_dict = { 451 | "return": returns.mean(), 452 | "score": scores.mean(), 453 | "score_std": scores.std(), 454 | "num_updates": step, 455 | **{k: loss[k][-1] for k in loss}, 456 | } 457 | wandb.log(log_dict) 458 | 459 | # --- Evaluate final agent --- 460 | if 
args.eval_final_episodes > 0: 461 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 462 | print(f"Evaluating final agent for {final_iters} iterations...") 463 | _rng = jax.random.split(rng, final_iters) 464 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 465 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 466 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 467 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 468 | 469 | # --- Write final returns to file --- 470 | os.makedirs("final_returns", exist_ok=True) 471 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 472 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 473 | with open(os.path.join("final_returns", filename), "wb") as f: 474 | onp.savez_compressed(f, **info, args=asdict(args)) 475 | 476 | if args.log: 477 | wandb.save(os.path.join("final_returns", filename)) 478 | 479 | if args.log: 480 | wandb.finish() 481 | -------------------------------------------------------------------------------- /algorithms/cql.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 22 | 23 | 24 | @dataclass 25 | class Args: 26 | # --- Experiment --- 27 | seed: int = 0 28 | dataset: str = "halfcheetah-medium-v2" 29 | algorithm: str = "cql" 30 | num_updates: int = 1_000_000 31 | eval_interval: int = 2500 32 | eval_workers: int = 8 33 | eval_final_episodes: int = 1000 34 | # --- Logging --- 35 | log: bool = False 36 | wandb_project: str = "unifloral" 37 | wandb_team: str = "flair" 38 | wandb_group: str = "debug" 39 | # --- Generic optimization --- 40 | lr: float = 3e-4 41 | batch_size: int = 256 42 | gamma: float = 0.99 43 | polyak_step_size: float = 0.005 44 | # --- SAC-N --- 45 | num_critics: int = 10 46 | # --- CQL--- 47 | actor_lr: float = 3e-5 48 | cql_temperature: float = 1.0 49 | cql_min_q_weight: float = 10.0 50 | 51 | 52 | r""" 53 | |\ __ 54 | \| /_/ 55 | \| 56 | ___|_____ 57 | \ / 58 | \ / 59 | \___/ Preliminaries 60 | """ 61 | 62 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 63 | Transition = namedtuple("Transition", "obs action reward next_obs done") 64 | 65 | 66 | def sym(scale): 67 | def _init(*args, **kwargs): 68 | return uniform(2 * scale)(*args, **kwargs) - scale 69 | 70 | return _init 71 | 72 | 73 | class SoftQNetwork(nn.Module): 74 | @nn.compact 75 | def __call__(self, obs, action): 76 | x = jnp.concatenate([obs, action], axis=-1) 77 | for _ in range(3): 78 | x = nn.Dense(256, bias_init=constant(0.1))(x) 79 | x = nn.relu(x) 80 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 81 | return q.squeeze(-1) 82 | 83 | 84 | class VectorQ(nn.Module): 85 | num_critics: int 86 | 87 | @nn.compact 88 | def __call__(self, obs, action): 89 | vmap_critic = nn.vmap( 90 | SoftQNetwork, 91 | variable_axes={"params": 0}, # Parameters not shared between critics 92 | 
split_rngs={"params": True, "dropout": True}, # Different initializations 93 | in_axes=None, 94 | out_axes=-1, 95 | axis_size=self.num_critics, 96 | ) 97 | q_values = vmap_critic()(obs, action) 98 | return q_values 99 | 100 | 101 | class TanhGaussianActor(nn.Module): 102 | num_actions: int 103 | log_std_max: float = 2.0 104 | log_std_min: float = -5.0 105 | 106 | @nn.compact 107 | def __call__(self, x): 108 | for _ in range(3): 109 | x = nn.Dense(256, bias_init=constant(0.1))(x) 110 | x = nn.relu(x) 111 | log_std = nn.Dense( 112 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 113 | )(x) 114 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 115 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 116 | pi = distrax.Transformed( 117 | distrax.Normal(mean, std), 118 | distrax.Tanh(), 119 | ) 120 | return pi 121 | 122 | 123 | class EntropyCoef(nn.Module): 124 | ent_coef_init: float = 1.0 125 | 126 | @nn.compact 127 | def __call__(self) -> jnp.ndarray: 128 | log_ent_coef = self.param( 129 | "log_ent_coef", 130 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 131 | ) 132 | return log_ent_coef 133 | 134 | 135 | def create_train_state(args, rng, network, dummy_input, lr=None): 136 | return TrainState.create( 137 | apply_fn=network.apply, 138 | params=network.init(rng, *dummy_input), 139 | tx=optax.adam(lr if lr is not None else args.lr, eps=1e-5), 140 | ) 141 | 142 | 143 | def eval_agent(args, rng, env, agent_state): 144 | # --- Reset environment --- 145 | step = 0 146 | returned = onp.zeros(args.eval_workers).astype(bool) 147 | cum_reward = onp.zeros(args.eval_workers) 148 | rng, rng_reset = jax.random.split(rng) 149 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 150 | obs = env.reset() 151 | 152 | # --- Rollout agent --- 153 | @jax.jit 154 | @jax.vmap 155 | def _policy_step(rng, obs): 156 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 157 | action = pi.sample(seed=rng) 158 | return jnp.nan_to_num(action) 159 | 160 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 161 | while step < max_episode_steps and not returned.all(): 162 | # --- Take step in environment --- 163 | step += 1 164 | rng, rng_step = jax.random.split(rng) 165 | rng_step = jax.random.split(rng_step, args.eval_workers) 166 | action = _policy_step(rng_step, jnp.array(obs)) 167 | obs, reward, done, info = env.step(onp.array(action)) 168 | 169 | # --- Track cumulative reward --- 170 | cum_reward += reward * ~returned 171 | returned |= done 172 | 173 | if step >= max_episode_steps and not returned.all(): 174 | warnings.warn("Maximum steps reached before all episodes terminated") 175 | return cum_reward 176 | 177 | 178 | r""" 179 | __/) 180 | .-(__(=: 181 | |\ | \) 182 | \ || 183 | \|| 184 | \| 185 | ___|_____ 186 | \ / 187 | \ / 188 | \___/ Agent 189 | """ 190 | 191 | 192 | def make_train_step(args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset): 193 | """Make JIT-compatible agent train step.""" 194 | 195 | def _train_step(runner_state, _): 196 | rng, agent_state = runner_state 197 | 198 | # --- Sample batch --- 199 | rng, rng_batch = jax.random.split(rng) 200 | batch_indices = jax.random.randint( 201 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 202 | ) 203 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 204 | 205 | # --- Update alpha --- 206 | @jax.value_and_grad 207 | def _alpha_loss_fn(params, rng): 208 | def _compute_entropy(rng, transition): 209 | pi = 
actor_apply_fn(agent_state.actor.params, transition.obs) 210 | _, log_pi = pi.sample_and_log_prob(seed=rng) 211 | return -log_pi.sum() 212 | 213 | log_alpha = alpha_apply_fn(params) 214 | rng = jax.random.split(rng, args.batch_size) 215 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 216 | target_entropy = -batch.action.shape[-1] 217 | return log_alpha * (entropy - target_entropy) 218 | 219 | rng, rng_alpha = jax.random.split(rng) 220 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 221 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 222 | agent_state = agent_state._replace(alpha=updated_alpha) 223 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 224 | 225 | # --- Update actor --- 226 | @partial(jax.value_and_grad, has_aux=True) 227 | def _actor_loss_function(params, rng): 228 | def _compute_loss(rng, transition): 229 | pi = actor_apply_fn(params, transition.obs) 230 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 231 | log_pi = log_pi.sum() 232 | q_values = q_apply_fn( 233 | agent_state.vec_q.params, transition.obs, sampled_action 234 | ) 235 | q_min = jnp.min(q_values) 236 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 237 | 238 | rng = jax.random.split(rng, args.batch_size) 239 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 240 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 241 | 242 | rng, rng_actor = jax.random.split(rng) 243 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 244 | agent_state.actor.params, rng_actor 245 | ) 246 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 247 | agent_state = agent_state._replace(actor=updated_actor) 248 | 249 | # --- Update Q target network --- 250 | updated_q_target_params = optax.incremental_update( 251 | agent_state.vec_q.params, 252 | agent_state.vec_q_target.params, 253 | args.polyak_step_size, 254 | ) 255 | updated_q_target = agent_state.vec_q_target.replace( 256 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 257 | ) 258 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 259 | 260 | # --- Compute targets --- 261 | def _sample_next_v(rng, transition): 262 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 263 | # Note: Important to use sample_and_log_prob here for numerical stability 264 | # See https://github.com/deepmind/distrax/issues/7 265 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 266 | # Minimum of the target Q-values 267 | next_q = q_apply_fn( 268 | agent_state.vec_q_target.params, transition.next_obs, next_action 269 | ) 270 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 271 | 272 | rng, rng_next_v = jax.random.split(rng) 273 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 274 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 275 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target 276 | 277 | # --- Sample actions for CQL --- 278 | def _sample_actions(rng, obs): 279 | pi = actor_apply_fn(agent_state.actor.params, obs) 280 | return pi.sample(seed=rng) 281 | 282 | rng, rng_pi, rng_next = jax.random.split(rng, 3) 283 | pi_actions = _sample_actions(rng_pi, batch.obs) 284 | pi_next_actions = _sample_actions(rng_next, batch.next_obs) 285 | rng, rng_random = jax.random.split(rng) 286 | cql_random_actions = jax.random.uniform( 287 | rng_random, shape=batch.action.shape, minval=-1.0, maxval=1.0 288 | ) 289 | 290 | # 
--- Update critics --- 291 | @jax.value_and_grad 292 | def _q_loss_fn(params): 293 | q_pred = q_apply_fn(params, batch.obs, batch.action) 294 | critic_loss = jnp.square((q_pred - jnp.expand_dims(target, -1))) 295 | critic_loss = critic_loss.sum(-1).mean() 296 | 297 | rand_q = q_apply_fn(params, batch.obs, cql_random_actions) 298 | pi_q = q_apply_fn(params, batch.obs, pi_actions) 299 | # Note: Source implementation erroneously uses current obs in next_pi_q 300 | next_pi_q = q_apply_fn(params, batch.next_obs, pi_next_actions) 301 | all_qs = jnp.concatenate([rand_q, pi_q, next_pi_q, q_pred], axis=1) 302 | q_ood = jax.scipy.special.logsumexp(all_qs / args.cql_temperature, axis=1) 303 | q_ood = jax.lax.stop_gradient(q_ood * args.cql_temperature) 304 | q_diff = (jnp.expand_dims(q_ood, 1) - q_pred).mean() 305 | min_q_loss = q_diff * args.cql_min_q_weight 306 | 307 | critic_loss += min_q_loss.mean() 308 | return critic_loss 309 | 310 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params) 311 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad) 312 | agent_state = agent_state._replace(vec_q=updated_q) 313 | 314 | loss = { 315 | "critic_loss": critic_loss, 316 | "actor_loss": actor_loss, 317 | "alpha_loss": alpha_loss, 318 | "entropy": entropy, 319 | "alpha": alpha, 320 | "q_min": q_min, 321 | "q_std": q_std, 322 | } 323 | return (rng, agent_state), loss 324 | 325 | return _train_step 326 | 327 | 328 | if __name__ == "__main__": 329 | # --- Parse arguments --- 330 | args = tyro.cli(Args) 331 | rng = jax.random.PRNGKey(args.seed) 332 | 333 | # --- Initialize logger --- 334 | if args.log: 335 | wandb.init( 336 | config=args, 337 | project=args.wandb_project, 338 | entity=args.wandb_team, 339 | group=args.wandb_group, 340 | job_type="train_agent", 341 | ) 342 | 343 | # --- Initialize environment and dataset --- 344 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 345 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 346 | dataset = Transition( 347 | obs=jnp.array(dataset["observations"]), 348 | action=jnp.array(dataset["actions"]), 349 | reward=jnp.array(dataset["rewards"]), 350 | next_obs=jnp.array(dataset["next_observations"]), 351 | done=jnp.array(dataset["terminals"]), 352 | ) 353 | 354 | # --- Initialize agent and value networks --- 355 | num_actions = env.single_action_space.shape[0] 356 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 357 | dummy_action = jnp.zeros(num_actions) 358 | actor_net = TanhGaussianActor(num_actions) 359 | q_net = VectorQ(args.num_critics) 360 | alpha_net = EntropyCoef() 361 | 362 | # Target networks share seeds to match initialization 363 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 364 | actor_lr = args.actor_lr if args.actor_lr is not None else args.lr 365 | agent_state = AgentTrainState( 366 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs], actor_lr), 367 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 368 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 369 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 370 | ) 371 | 372 | # --- Make train step --- 373 | _agent_train_step_fn = make_train_step( 374 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset 375 | ) 376 | 377 | num_evals = args.num_updates // args.eval_interval 378 | for eval_idx in range(num_evals): 379 | # --- Execute train loop --- 380 | (rng, agent_state), loss = jax.lax.scan( 381 | _agent_train_step_fn, 382 | (rng, 
agent_state), 383 | None, 384 | args.eval_interval, 385 | ) 386 | 387 | # --- Evaluate agent --- 388 | rng, rng_eval = jax.random.split(rng) 389 | returns = eval_agent(args, rng_eval, env, agent_state) 390 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 391 | 392 | # --- Log metrics --- 393 | step = (eval_idx + 1) * args.eval_interval 394 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 395 | if args.log: 396 | log_dict = { 397 | "return": returns.mean(), 398 | "score": scores.mean(), 399 | "score_std": scores.std(), 400 | "num_updates": step, 401 | **{k: loss[k][-1] for k in loss}, 402 | } 403 | wandb.log(log_dict) 404 | 405 | # --- Evaluate final agent --- 406 | if args.eval_final_episodes > 0: 407 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 408 | print(f"Evaluating final agent for {final_iters} iterations...") 409 | _rng = jax.random.split(rng, final_iters) 410 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 411 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 412 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 413 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 414 | 415 | # --- Write final returns to file --- 416 | os.makedirs("final_returns", exist_ok=True) 417 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 418 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 419 | with open(os.path.join("final_returns", filename), "wb") as f: 420 | onp.savez_compressed(f, **info, args=asdict(args)) 421 | 422 | if args.log: 423 | wandb.save(os.path.join("final_returns", filename)) 424 | 425 | if args.log: 426 | wandb.finish() 427 | -------------------------------------------------------------------------------- /algorithms/edac.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 22 | 23 | 24 | @dataclass 25 | class Args: 26 | # --- Experiment --- 27 | seed: int = 0 28 | dataset: str = "halfcheetah-medium-v2" 29 | algorithm: str = "edac" 30 | num_updates: int = 3_000_000 31 | eval_interval: int = 2500 32 | eval_workers: int = 8 33 | eval_final_episodes: int = 1000 34 | # --- Logging --- 35 | log: bool = False 36 | wandb_project: str = "unifloral" 37 | wandb_team: str = "flair" 38 | wandb_group: str = "debug" 39 | # --- Generic optimization --- 40 | lr: float = 3e-4 41 | batch_size: int = 256 42 | gamma: float = 0.99 43 | polyak_step_size: float = 0.005 44 | # --- SAC-N --- 45 | num_critics: int = 10 46 | # --- EDAC --- 47 | eta: float = 1.0 48 | 49 | 50 | r""" 51 | |\ __ 52 | \| /_/ 53 | \| 54 | ___|_____ 55 | \ / 56 | \ / 57 | \___/ Preliminaries 58 | """ 59 | 60 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 61 | Transition = namedtuple("Transition", "obs action reward next_obs done") 62 | 63 | 64 | def sym(scale): 65 | def _init(*args, **kwargs): 66 | return uniform(2 * scale)(*args, 
**kwargs) - scale 67 | 68 | return _init 69 | 70 | 71 | class SoftQNetwork(nn.Module): 72 | @nn.compact 73 | def __call__(self, obs, action): 74 | x = jnp.concatenate([obs, action], axis=-1) 75 | for _ in range(3): 76 | x = nn.Dense(256, bias_init=constant(0.1))(x) 77 | x = nn.relu(x) 78 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 79 | return q.squeeze(-1) 80 | 81 | 82 | class VectorQ(nn.Module): 83 | num_critics: int 84 | 85 | @nn.compact 86 | def __call__(self, obs, action): 87 | vmap_critic = nn.vmap( 88 | SoftQNetwork, 89 | variable_axes={"params": 0}, # Parameters not shared between critics 90 | split_rngs={"params": True, "dropout": True}, # Different initializations 91 | in_axes=None, 92 | out_axes=-1, 93 | axis_size=self.num_critics, 94 | ) 95 | q_values = vmap_critic()(obs, action) 96 | return q_values 97 | 98 | 99 | class TanhGaussianActor(nn.Module): 100 | num_actions: int 101 | log_std_max: float = 2.0 102 | log_std_min: float = -5.0 103 | 104 | @nn.compact 105 | def __call__(self, x): 106 | for _ in range(3): 107 | x = nn.Dense(256, bias_init=constant(0.1))(x) 108 | x = nn.relu(x) 109 | log_std = nn.Dense( 110 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 111 | )(x) 112 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 113 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 114 | pi = distrax.Transformed( 115 | distrax.Normal(mean, std), 116 | distrax.Tanh(), 117 | ) 118 | return pi 119 | 120 | 121 | class EntropyCoef(nn.Module): 122 | ent_coef_init: float = 1.0 123 | 124 | @nn.compact 125 | def __call__(self) -> jnp.ndarray: 126 | log_ent_coef = self.param( 127 | "log_ent_coef", 128 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 129 | ) 130 | return log_ent_coef 131 | 132 | 133 | def create_train_state(args, rng, network, dummy_input): 134 | return TrainState.create( 135 | apply_fn=network.apply, 136 | params=network.init(rng, *dummy_input), 137 | tx=optax.adam(args.lr, eps=1e-5), 138 | ) 139 | 140 | 141 | def eval_agent(args, rng, env, agent_state): 142 | # --- Reset environment --- 143 | step = 0 144 | returned = onp.zeros(args.eval_workers).astype(bool) 145 | cum_reward = onp.zeros(args.eval_workers) 146 | rng, rng_reset = jax.random.split(rng) 147 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 148 | obs = env.reset() 149 | 150 | # --- Rollout agent --- 151 | @jax.jit 152 | @jax.vmap 153 | def _policy_step(rng, obs): 154 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 155 | action = pi.sample(seed=rng) 156 | return jnp.nan_to_num(action) 157 | 158 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 159 | while step < max_episode_steps and not returned.all(): 160 | # --- Take step in environment --- 161 | step += 1 162 | rng, rng_step = jax.random.split(rng) 163 | rng_step = jax.random.split(rng_step, args.eval_workers) 164 | action = _policy_step(rng_step, jnp.array(obs)) 165 | obs, reward, done, info = env.step(onp.array(action)) 166 | 167 | # --- Track cumulative reward --- 168 | cum_reward += reward * ~returned 169 | returned |= done 170 | 171 | if step >= max_episode_steps and not returned.all(): 172 | warnings.warn("Maximum steps reached before all episodes terminated") 173 | return cum_reward 174 | 175 | 176 | r""" 177 | __/) 178 | .-(__(=: 179 | |\ | \) 180 | \ || 181 | \|| 182 | \| 183 | ___|_____ 184 | \ / 185 | \ / 186 | \___/ Agent 187 | """ 188 | 189 | 190 | def make_train_step(args, actor_apply_fn, q_apply_fn, 
alpha_apply_fn, dataset): 191 | """Make JIT-compatible agent train step.""" 192 | 193 | def _train_step(runner_state, _): 194 | rng, agent_state = runner_state 195 | 196 | # --- Sample batch --- 197 | rng, rng_batch = jax.random.split(rng) 198 | batch_indices = jax.random.randint( 199 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 200 | ) 201 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 202 | 203 | # --- Update alpha --- 204 | @jax.value_and_grad 205 | def _alpha_loss_fn(params, rng): 206 | def _compute_entropy(rng, transition): 207 | pi = actor_apply_fn(agent_state.actor.params, transition.obs) 208 | _, log_pi = pi.sample_and_log_prob(seed=rng) 209 | return -log_pi.sum() 210 | 211 | log_alpha = alpha_apply_fn(params) 212 | rng = jax.random.split(rng, args.batch_size) 213 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 214 | target_entropy = -batch.action.shape[-1] 215 | return log_alpha * (entropy - target_entropy) 216 | 217 | rng, rng_alpha = jax.random.split(rng) 218 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 219 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 220 | agent_state = agent_state._replace(alpha=updated_alpha) 221 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 222 | 223 | # --- Update actor --- 224 | @partial(jax.value_and_grad, has_aux=True) 225 | def _actor_loss_function(params, rng): 226 | def _compute_loss(rng, transition): 227 | pi = actor_apply_fn(params, transition.obs) 228 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 229 | log_pi = log_pi.sum() 230 | q_values = q_apply_fn( 231 | agent_state.vec_q.params, transition.obs, sampled_action 232 | ) 233 | q_min = jnp.min(q_values) 234 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 235 | 236 | rng = jax.random.split(rng, args.batch_size) 237 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 238 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 239 | 240 | rng, rng_actor = jax.random.split(rng) 241 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 242 | agent_state.actor.params, rng_actor 243 | ) 244 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 245 | agent_state = agent_state._replace(actor=updated_actor) 246 | 247 | # --- Update Q target network --- 248 | updated_q_target_params = optax.incremental_update( 249 | agent_state.vec_q.params, 250 | agent_state.vec_q_target.params, 251 | args.polyak_step_size, 252 | ) 253 | updated_q_target = agent_state.vec_q_target.replace( 254 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 255 | ) 256 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 257 | 258 | # --- Compute targets --- 259 | def _sample_next_v(rng, transition): 260 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 261 | # Note: Important to use sample_and_log_prob here for numerical stability 262 | # See https://github.com/deepmind/distrax/issues/7 263 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 264 | # Minimum of the target Q-values 265 | next_q = q_apply_fn( 266 | agent_state.vec_q_target.params, transition.next_obs, next_action 267 | ) 268 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 269 | 270 | rng, rng_next_v = jax.random.split(rng) 271 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 272 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 273 | target = batch.reward + 
args.gamma * (1 - batch.done) * next_v_target 274 | 275 | # --- Update critics --- 276 | @partial(jax.value_and_grad, has_aux=True) 277 | def _q_loss_fn(params): 278 | q_pred = q_apply_fn(params, batch.obs, batch.action) 279 | value_loss = jnp.square((q_pred - jnp.expand_dims(target, -1))) 280 | value_loss = value_loss.sum(-1).mean() 281 | 282 | def _diversity_loss_fn(obs, action): 283 | action_jac = jax.jacrev(q_apply_fn, argnums=2)(params, obs, action) 284 | action_jac /= jnp.linalg.norm(action_jac, axis=-1, keepdims=True) + 1e-6 285 | div_loss = action_jac @ action_jac.T 286 | div_loss *= 1.0 - jnp.eye(args.num_critics) 287 | return div_loss.sum() 288 | 289 | diversity_loss = jax.vmap(_diversity_loss_fn)(batch.obs, batch.action) 290 | diversity_loss = diversity_loss.mean() / (args.num_critics - 1) 291 | critic_loss = value_loss + args.eta * diversity_loss 292 | return critic_loss, (value_loss, diversity_loss) 293 | 294 | (critic_loss, (value_loss, diversity_loss)), critic_grad = _q_loss_fn( 295 | agent_state.vec_q.params 296 | ) 297 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad) 298 | agent_state = agent_state._replace(vec_q=updated_q) 299 | 300 | loss = { 301 | "critic_loss": critic_loss, 302 | "value_loss": value_loss, 303 | "diversity_loss": diversity_loss, 304 | "actor_loss": actor_loss, 305 | "alpha_loss": alpha_loss, 306 | "entropy": entropy, 307 | "alpha": alpha, 308 | "q_min": q_min, 309 | "q_std": q_std, 310 | } 311 | return (rng, agent_state), loss 312 | 313 | return _train_step 314 | 315 | 316 | if __name__ == "__main__": 317 | # --- Parse arguments --- 318 | args = tyro.cli(Args) 319 | rng = jax.random.PRNGKey(args.seed) 320 | 321 | # --- Initialize logger --- 322 | if args.log: 323 | wandb.init( 324 | config=args, 325 | project=args.wandb_project, 326 | entity=args.wandb_team, 327 | group=args.wandb_group, 328 | job_type="train_agent", 329 | ) 330 | 331 | # --- Initialize environment and dataset --- 332 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 333 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 334 | dataset = Transition( 335 | obs=jnp.array(dataset["observations"]), 336 | action=jnp.array(dataset["actions"]), 337 | reward=jnp.array(dataset["rewards"]), 338 | next_obs=jnp.array(dataset["next_observations"]), 339 | done=jnp.array(dataset["terminals"]), 340 | ) 341 | 342 | # --- Initialize agent and value networks --- 343 | num_actions = env.single_action_space.shape[0] 344 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 345 | dummy_action = jnp.zeros(num_actions) 346 | actor_net = TanhGaussianActor(num_actions) 347 | q_net = VectorQ(args.num_critics) 348 | alpha_net = EntropyCoef() 349 | 350 | # Target networks share seeds to match initialization 351 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 352 | agent_state = AgentTrainState( 353 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 354 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 355 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 356 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 357 | ) 358 | 359 | # --- Make train step --- 360 | _agent_train_step_fn = make_train_step( 361 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset 362 | ) 363 | 364 | num_evals = args.num_updates // args.eval_interval 365 | for eval_idx in range(num_evals): 366 | # --- Execute train loop --- 367 | (rng, agent_state), loss = jax.lax.scan( 368 | 
_agent_train_step_fn, 369 | (rng, agent_state), 370 | None, 371 | args.eval_interval, 372 | ) 373 | 374 | # --- Evaluate agent --- 375 | rng, rng_eval = jax.random.split(rng) 376 | returns = eval_agent(args, rng_eval, env, agent_state) 377 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 378 | 379 | # --- Log metrics --- 380 | step = (eval_idx + 1) * args.eval_interval 381 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 382 | if args.log: 383 | log_dict = { 384 | "return": returns.mean(), 385 | "score": scores.mean(), 386 | "score_std": scores.std(), 387 | "num_updates": step, 388 | **{k: loss[k][-1] for k in loss}, 389 | } 390 | wandb.log(log_dict) 391 | 392 | # --- Evaluate final agent --- 393 | if args.eval_final_episodes > 0: 394 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 395 | print(f"Evaluating final agent for {final_iters} iterations...") 396 | _rng = jax.random.split(rng, final_iters) 397 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 398 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 399 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 400 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 401 | 402 | # --- Write final returns to file --- 403 | os.makedirs("final_returns", exist_ok=True) 404 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 405 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 406 | with open(os.path.join("final_returns", filename), "wb") as f: 407 | onp.savez_compressed(f, **info, args=asdict(args)) 408 | 409 | if args.log: 410 | wandb.save(os.path.join("final_returns", filename)) 411 | 412 | if args.log: 413 | wandb.finish() 414 | -------------------------------------------------------------------------------- /algorithms/iql.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.training.train_state import TrainState 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as onp 16 | import optax 17 | import tyro 18 | import wandb 19 | 20 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 21 | 22 | 23 | @dataclass 24 | class Args: 25 | # --- Experiment --- 26 | seed: int = 0 27 | dataset: str = "halfcheetah-medium-v2" 28 | algorithm: str = "iql" 29 | num_updates: int = 1_000_000 30 | eval_interval: int = 2500 31 | eval_workers: int = 8 32 | eval_final_episodes: int = 1000 33 | # --- Logging --- 34 | log: bool = False 35 | wandb_project: str = "unifloral" 36 | wandb_team: str = "flair" 37 | wandb_group: str = "debug" 38 | # --- Generic optimization --- 39 | lr: float = 3e-4 40 | batch_size: int = 256 41 | gamma: float = 0.99 42 | polyak_step_size: float = 0.005 43 | # --- IQL --- 44 | beta: float = 3.0 45 | iql_tau: float = 0.7 46 | exp_adv_clip: float = 100.0 47 | 48 | 49 | r""" 50 | |\ __ 51 | \| /_/ 52 | \| 53 | ___|_____ 54 | \ / 55 | \ / 56 | \___/ Preliminaries 57 | """ 58 | 59 | AgentTrainState = namedtuple("AgentTrainState", "actor dual_q dual_q_target value") 60 | Transition = namedtuple("Transition", "obs action reward next_obs done") 61 | 62 | 63 | class SoftQNetwork(nn.Module): 64 | obs_mean: jax.Array 65 | obs_std: jax.Array 66 | 67 | 
@nn.compact 68 | def __call__(self, obs, action): 69 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3) 70 | x = jnp.concatenate([obs, action], axis=-1) 71 | for _ in range(2): 72 | x = nn.Dense(256)(x) 73 | x = nn.relu(x) 74 | q = nn.Dense(1)(x) 75 | return q.squeeze(-1) 76 | 77 | 78 | class DualQNetwork(nn.Module): 79 | obs_mean: jax.Array 80 | obs_std: jax.Array 81 | 82 | @nn.compact 83 | def __call__(self, obs, action): 84 | vmap_critic = nn.vmap( 85 | SoftQNetwork, 86 | variable_axes={"params": 0}, # Parameters not shared between critics 87 | split_rngs={"params": True, "dropout": True}, # Different initializations 88 | in_axes=None, 89 | out_axes=-1, 90 | axis_size=2, # Two Q networks 91 | ) 92 | q_values = vmap_critic(self.obs_mean, self.obs_std)(obs, action) 93 | return q_values 94 | 95 | 96 | class StateValueFunction(nn.Module): 97 | obs_mean: jax.Array 98 | obs_std: jax.Array 99 | 100 | @nn.compact 101 | def __call__(self, x): 102 | x = (x - self.obs_mean) / (self.obs_std + 1e-3) 103 | for _ in range(2): 104 | x = nn.Dense(256)(x) 105 | x = nn.relu(x) 106 | v = nn.Dense(1)(x) 107 | return v.squeeze(-1) 108 | 109 | 110 | class TanhGaussianActor(nn.Module): 111 | num_actions: int 112 | obs_mean: jax.Array 113 | obs_std: jax.Array 114 | log_std_max: float = 2.0 115 | log_std_min: float = -20.0 116 | 117 | @nn.compact 118 | def __call__(self, x, eval=False): 119 | x = (x - self.obs_mean) / (self.obs_std + 1e-3) 120 | for _ in range(2): 121 | x = nn.Dense(256)(x) 122 | x = nn.relu(x) 123 | x = nn.Dense(self.num_actions)(x) 124 | x = nn.tanh(x) 125 | if eval: 126 | return distrax.Deterministic(x) 127 | logstd = self.param( 128 | "logstd", 129 | init_fn=lambda key: jnp.zeros(self.num_actions, dtype=jnp.float32), 130 | ) 131 | std = jnp.exp(jnp.clip(logstd, self.log_std_min, self.log_std_max)) 132 | return distrax.Normal(x, std) 133 | 134 | 135 | def create_train_state(args, rng, network, dummy_input): 136 | lr_schedule = optax.cosine_decay_schedule(args.lr, args.num_updates) 137 | return TrainState.create( 138 | apply_fn=network.apply, 139 | params=network.init(rng, *dummy_input), 140 | tx=optax.adam(lr_schedule, eps=1e-5), 141 | ) 142 | 143 | 144 | def eval_agent(args, rng, env, agent_state): 145 | # --- Reset environment --- 146 | step = 0 147 | returned = onp.zeros(args.eval_workers).astype(bool) 148 | cum_reward = onp.zeros(args.eval_workers) 149 | rng, rng_reset = jax.random.split(rng) 150 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 151 | obs = env.reset() 152 | 153 | # --- Rollout agent --- 154 | @jax.jit 155 | @jax.vmap 156 | def _policy_step(rng, obs): 157 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs, eval=True) 158 | action = pi.sample(seed=rng) 159 | return jnp.nan_to_num(action) 160 | 161 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 162 | while step < max_episode_steps and not returned.all(): 163 | # --- Take step in environment --- 164 | step += 1 165 | rng, rng_step = jax.random.split(rng) 166 | rng_step = jax.random.split(rng_step, args.eval_workers) 167 | action = _policy_step(rng_step, jnp.array(obs)) 168 | obs, reward, done, info = env.step(onp.array(action)) 169 | 170 | # --- Track cumulative reward --- 171 | cum_reward += reward * ~returned 172 | returned |= done 173 | 174 | if step >= max_episode_steps and not returned.all(): 175 | warnings.warn("Maximum steps reached before all episodes terminated") 176 | return cum_reward 177 | 178 | 179 | r""" 180 | __/) 181 | .-(__(=: 182 | |\ | \) 183 | \ || 184 | \|| 185 | 
\| 186 | ___|_____ 187 | \ / 188 | \ / 189 | \___/ Agent 190 | """ 191 | 192 | 193 | def make_train_step(args, actor_apply_fn, q_apply_fn, value_apply_fn, dataset): 194 | """Make JIT-compatible agent train step.""" 195 | 196 | def _train_step(runner_state, _): 197 | rng, agent_state = runner_state 198 | 199 | # --- Sample batch --- 200 | rng, rng_batch = jax.random.split(rng) 201 | batch_indices = jax.random.randint( 202 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 203 | ) 204 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 205 | 206 | # --- Update Q target network --- 207 | updated_q_target_params = optax.incremental_update( 208 | agent_state.dual_q.params, 209 | agent_state.dual_q_target.params, 210 | args.polyak_step_size, 211 | ) 212 | updated_q_target = agent_state.dual_q_target.replace( 213 | step=agent_state.dual_q_target.step + 1, params=updated_q_target_params 214 | ) 215 | agent_state = agent_state._replace(dual_q_target=updated_q_target) 216 | 217 | # --- Compute targets --- 218 | v_target = q_apply_fn(agent_state.dual_q_target.params, batch.obs, batch.action) 219 | v_target = v_target.min(-1) 220 | next_v_target = value_apply_fn(agent_state.value.params, batch.next_obs) 221 | q_targets = batch.reward + args.gamma * (1 - batch.done) * next_v_target 222 | 223 | # --- Update Q and value functions --- 224 | def _q_loss_fn(params): 225 | # Compute loss for both critics 226 | q_pred = q_apply_fn(params, batch.obs, batch.action) 227 | q_loss = jnp.square(q_pred - jnp.expand_dims(q_targets, axis=-1)).mean() 228 | return q_loss 229 | 230 | @partial(jax.value_and_grad, has_aux=True) 231 | def _value_loss_fn(params): 232 | adv = v_target - value_apply_fn(params, batch.obs) 233 | # Asymmetric L2 loss 234 | value_loss = jnp.abs(args.iql_tau - (adv < 0.0).astype(float)) * (adv**2) 235 | return jnp.mean(value_loss), adv 236 | 237 | q_loss, q_grad = jax.value_and_grad(_q_loss_fn)(agent_state.dual_q.params) 238 | (value_loss, adv), value_grad = _value_loss_fn(agent_state.value.params) 239 | agent_state = agent_state._replace( 240 | dual_q=agent_state.dual_q.apply_gradients(grads=q_grad), 241 | value=agent_state.value.apply_gradients(grads=value_grad), 242 | ) 243 | 244 | # --- Update actor --- 245 | exp_adv = jnp.exp(adv * args.beta).clip(max=args.exp_adv_clip) 246 | 247 | @jax.value_and_grad 248 | def _actor_loss_function(params): 249 | def _compute_loss(transition, exp_adv): 250 | pi = actor_apply_fn(params, transition.obs) 251 | bc_loss = -pi.log_prob(transition.action) 252 | return exp_adv * bc_loss.sum() 253 | 254 | actor_loss = jax.vmap(_compute_loss)(batch, exp_adv) 255 | return actor_loss.mean() 256 | 257 | actor_loss, actor_grad = _actor_loss_function(agent_state.actor.params) 258 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 259 | agent_state = agent_state._replace(actor=updated_actor) 260 | 261 | loss = { 262 | "value_loss": value_loss, 263 | "q_loss": q_loss, 264 | "actor_loss": actor_loss, 265 | } 266 | return (rng, agent_state), loss 267 | 268 | return _train_step 269 | 270 | 271 | if __name__ == "__main__": 272 | # --- Parse arguments --- 273 | args = tyro.cli(Args) 274 | rng = jax.random.PRNGKey(args.seed) 275 | 276 | # --- Initialize logger --- 277 | if args.log: 278 | wandb.init( 279 | config=args, 280 | project=args.wandb_project, 281 | entity=args.wandb_team, 282 | group=args.wandb_group, 283 | job_type="train_agent", 284 | ) 285 | 286 | # --- Initialize environment and dataset --- 287 | env = 
gym.vector.make(args.dataset, num_envs=args.eval_workers) 288 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 289 | dataset = Transition( 290 | obs=jnp.array(dataset["observations"]), 291 | action=jnp.array(dataset["actions"]), 292 | reward=jnp.array(dataset["rewards"]), 293 | next_obs=jnp.array(dataset["next_observations"]), 294 | done=jnp.array(dataset["terminals"]), 295 | ) 296 | 297 | # --- Initialize agent and value networks --- 298 | num_actions = env.single_action_space.shape[0] 299 | obs_mean = dataset.obs.mean(axis=0) 300 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0) 301 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 302 | dummy_action = jnp.zeros(num_actions) 303 | actor_net = TanhGaussianActor(num_actions, obs_mean, obs_std) 304 | q_net = DualQNetwork(obs_mean, obs_std) 305 | value_net = StateValueFunction(obs_mean, obs_std) 306 | 307 | # Target networks share seeds to match initialization 308 | rng, rng_actor, rng_q, rng_value = jax.random.split(rng, 4) 309 | agent_state = AgentTrainState( 310 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 311 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 312 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 313 | value=create_train_state(args, rng_value, value_net, [dummy_obs]), 314 | ) 315 | 316 | # --- Make train step --- 317 | _agent_train_step_fn = make_train_step( 318 | args, actor_net.apply, q_net.apply, value_net.apply, dataset 319 | ) 320 | 321 | num_evals = args.num_updates // args.eval_interval 322 | for eval_idx in range(num_evals): 323 | # --- Execute train loop --- 324 | (rng, agent_state), loss = jax.lax.scan( 325 | _agent_train_step_fn, 326 | (rng, agent_state), 327 | None, 328 | args.eval_interval, 329 | ) 330 | 331 | # --- Evaluate agent --- 332 | rng, rng_eval = jax.random.split(rng) 333 | returns = eval_agent(args, rng_eval, env, agent_state) 334 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 335 | 336 | # --- Log metrics --- 337 | step = (eval_idx + 1) * args.eval_interval 338 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 339 | if args.log: 340 | log_dict = { 341 | "return": returns.mean(), 342 | "score": scores.mean(), 343 | "score_std": scores.std(), 344 | "num_updates": step, 345 | **{k: loss[k][-1] for k in loss}, 346 | } 347 | wandb.log(log_dict) 348 | 349 | # --- Evaluate final agent --- 350 | if args.eval_final_episodes > 0: 351 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 352 | print(f"Evaluating final agent for {final_iters} iterations...") 353 | _rng = jax.random.split(rng, final_iters) 354 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 355 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 356 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 357 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 358 | 359 | # --- Write final returns to file --- 360 | os.makedirs("final_returns", exist_ok=True) 361 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 362 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 363 | with open(os.path.join("final_returns", filename), "wb") as f: 364 | onp.savez_compressed(f, **info, args=asdict(args)) 365 | 366 | if args.log: 367 | wandb.save(os.path.join("final_returns", filename)) 368 | 369 | if args.log: 370 | wandb.finish() 371 | 
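A minimal illustrative sketch (not part of iql.py, names and values are hypothetical) of the two ingredients above: the asymmetric expectile weighting used in `_value_loss_fn` and the clipped exponential advantage weights used in the actor update, with `iql_tau = 0.7`, `beta = 3.0`, and `exp_adv_clip = 100.0` as in `Args`:

import jax.numpy as jnp

iql_tau = 0.7
adv = jnp.array([-2.0, -0.5, 0.5, 2.0])  # hypothetical advantages: v_target - V(s)
# Negative advantages get weight |0.7 - 1| = 0.3, non-negative get 0.7,
# so minimizing this loss pushes V(s) toward an upper expectile of v_target.
weight = jnp.abs(iql_tau - (adv < 0.0).astype(float))
expectile_loss = (weight * adv**2).mean()
# Actor weights: larger advantages get exponentially more weight, clipped for stability.
exp_adv = jnp.exp(adv * 3.0).clip(max=100.0)
print(expectile_loss, exp_adv)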
-------------------------------------------------------------------------------- /algorithms/mopo.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | from dynamics import ( 22 | Transition, 23 | load_dynamics_model, 24 | EnsembleDynamics, # required for loading dynamics model 25 | EnsembleDynamicsModel, # required for loading dynamics model 26 | ) 27 | 28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 29 | 30 | 31 | @dataclass 32 | class Args: 33 | # --- Experiment --- 34 | seed: int = 0 35 | dataset: str = "halfcheetah-medium-v2" 36 | algorithm: str = "mopo" 37 | num_updates: int = 3_000_000 38 | eval_interval: int = 2500 39 | eval_workers: int = 8 40 | eval_final_episodes: int = 1000 41 | # --- Logging --- 42 | log: bool = False 43 | wandb_project: str = "unifloral" 44 | wandb_team: str = "flair" 45 | wandb_group: str = "debug" 46 | # --- Generic optimization --- 47 | lr: float = 1e-4 48 | batch_size: int = 256 49 | gamma: float = 0.99 50 | polyak_step_size: float = 0.005 51 | # --- SAC-N --- 52 | num_critics: int = 10 53 | # --- World model --- 54 | model_path: str = "" 55 | rollout_interval: int = 1000 56 | rollout_length: int = 5 57 | rollout_batch_size: int = 50000 58 | model_retain_epochs: int = 5 59 | dataset_sample_ratio: float = 0.05 60 | # --- MOPO --- 61 | step_penalty_coef: float = 0.5 62 | 63 | 64 | r""" 65 | |\ __ 66 | \| /_/ 67 | \| 68 | ___|_____ 69 | \ / 70 | \ / 71 | \___/ Preliminaries 72 | """ 73 | 74 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 75 | 76 | 77 | def sym(scale): 78 | def _init(*args, **kwargs): 79 | return uniform(2 * scale)(*args, **kwargs) - scale 80 | 81 | return _init 82 | 83 | 84 | class SoftQNetwork(nn.Module): 85 | @nn.compact 86 | def __call__(self, obs, action): 87 | x = jnp.concatenate([obs, action], axis=-1) 88 | for _ in range(3): 89 | x = nn.Dense(256, bias_init=constant(0.1))(x) 90 | x = nn.relu(x) 91 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 92 | return q.squeeze(-1) 93 | 94 | 95 | class VectorQ(nn.Module): 96 | num_critics: int 97 | 98 | @nn.compact 99 | def __call__(self, obs, action): 100 | vmap_critic = nn.vmap( 101 | SoftQNetwork, 102 | variable_axes={"params": 0}, # Parameters not shared between critics 103 | split_rngs={"params": True, "dropout": True}, # Different initializations 104 | in_axes=None, 105 | out_axes=-1, 106 | axis_size=self.num_critics, 107 | ) 108 | q_values = vmap_critic()(obs, action) 109 | return q_values 110 | 111 | 112 | class TanhGaussianActor(nn.Module): 113 | num_actions: int 114 | log_std_max: float = 2.0 115 | log_std_min: float = -5.0 116 | 117 | @nn.compact 118 | def __call__(self, x): 119 | for _ in range(3): 120 | x = nn.Dense(256, bias_init=constant(0.1))(x) 121 | x = nn.relu(x) 122 | log_std = nn.Dense( 123 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 124 | )(x) 125 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 126 | mean = 
nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 127 | pi = distrax.Transformed( 128 | distrax.Normal(mean, std), 129 | distrax.Tanh(), 130 | ) 131 | return pi 132 | 133 | 134 | class EntropyCoef(nn.Module): 135 | ent_coef_init: float = 1.0 136 | 137 | @nn.compact 138 | def __call__(self) -> jnp.ndarray: 139 | log_ent_coef = self.param( 140 | "log_ent_coef", 141 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 142 | ) 143 | return log_ent_coef 144 | 145 | 146 | def create_train_state(args, rng, network, dummy_input): 147 | return TrainState.create( 148 | apply_fn=network.apply, 149 | params=network.init(rng, *dummy_input), 150 | tx=optax.adam(args.lr, eps=1e-5), 151 | ) 152 | 153 | 154 | def eval_agent(args, rng, env, agent_state): 155 | # --- Reset environment --- 156 | step = 0 157 | returned = onp.zeros(args.eval_workers).astype(bool) 158 | cum_reward = onp.zeros(args.eval_workers) 159 | rng, rng_reset = jax.random.split(rng) 160 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 161 | obs = env.reset() 162 | 163 | # --- Rollout agent --- 164 | @jax.jit 165 | @jax.vmap 166 | def _policy_step(rng, obs): 167 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 168 | action = pi.sample(seed=rng) 169 | return jnp.nan_to_num(action) 170 | 171 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 172 | while step < max_episode_steps and not returned.all(): 173 | # --- Take step in environment --- 174 | step += 1 175 | rng, rng_step = jax.random.split(rng) 176 | rng_step = jax.random.split(rng_step, args.eval_workers) 177 | action = _policy_step(rng_step, jnp.array(obs)) 178 | obs, reward, done, info = env.step(onp.array(action)) 179 | 180 | # --- Track cumulative reward --- 181 | cum_reward += reward * ~returned 182 | returned |= done 183 | 184 | if step >= max_episode_steps and not returned.all(): 185 | warnings.warn("Maximum steps reached before all episodes terminated") 186 | return cum_reward 187 | 188 | 189 | def sample_from_buffer(buffer, batch_size, rng): 190 | """Sample a batch from the buffer.""" 191 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs)) 192 | return jax.tree_map(lambda x: x[idxs], buffer) 193 | 194 | 195 | r""" 196 | __/) 197 | .-(__(=: 198 | |\ | \) 199 | \ || 200 | \|| 201 | \| 202 | ___|_____ 203 | \ / 204 | \ / 205 | \___/ Agent 206 | """ 207 | 208 | 209 | def make_train_step( 210 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn 211 | ): 212 | """Make JIT-compatible agent train step with model-based rollouts.""" 213 | 214 | def _train_step(runner_state, _): 215 | rng, agent_state, rollout_buffer = runner_state 216 | 217 | # --- Update model buffer --- 218 | params = agent_state.actor.params 219 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng) 220 | rng, rng_buffer = jax.random.split(rng) 221 | rollout_buffer = jax.lax.cond( 222 | agent_state.actor.step % args.rollout_interval == 0, 223 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer), 224 | lambda: rollout_buffer, 225 | ) 226 | 227 | # --- Sample batch --- 228 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3) 229 | dataset_size = int(args.batch_size * args.dataset_sample_ratio) 230 | rollout_size = args.batch_size - dataset_size 231 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset) 232 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout) 233 | batch = jax.tree_map( 234 | lambda x, y: jnp.concatenate([x, y]), 
dataset_batch, rollout_batch 235 | ) 236 | 237 | # --- Update alpha --- 238 | @jax.value_and_grad 239 | def _alpha_loss_fn(params, rng): 240 | def _compute_entropy(rng, transition): 241 | pi = actor_apply_fn(agent_state.actor.params, transition.obs) 242 | _, log_pi = pi.sample_and_log_prob(seed=rng) 243 | return -log_pi.sum() 244 | 245 | log_alpha = alpha_apply_fn(params) 246 | rng = jax.random.split(rng, args.batch_size) 247 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 248 | target_entropy = -batch.action.shape[-1] 249 | return log_alpha * (entropy - target_entropy) 250 | 251 | rng, rng_alpha = jax.random.split(rng) 252 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 253 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 254 | agent_state = agent_state._replace(alpha=updated_alpha) 255 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 256 | 257 | # --- Update actor --- 258 | @partial(jax.value_and_grad, has_aux=True) 259 | def _actor_loss_function(params, rng): 260 | def _compute_loss(rng, transition): 261 | pi = actor_apply_fn(params, transition.obs) 262 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 263 | log_pi = log_pi.sum() 264 | q_values = q_apply_fn( 265 | agent_state.vec_q.params, transition.obs, sampled_action 266 | ) 267 | q_min = jnp.min(q_values) 268 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 269 | 270 | rng = jax.random.split(rng, args.batch_size) 271 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 272 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 273 | 274 | rng, rng_actor = jax.random.split(rng) 275 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 276 | agent_state.actor.params, rng_actor 277 | ) 278 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 279 | agent_state = agent_state._replace(actor=updated_actor) 280 | 281 | # --- Update Q target network --- 282 | updated_q_target_params = optax.incremental_update( 283 | agent_state.vec_q.params, 284 | agent_state.vec_q_target.params, 285 | args.polyak_step_size, 286 | ) 287 | updated_q_target = agent_state.vec_q_target.replace( 288 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 289 | ) 290 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 291 | 292 | # --- Compute targets --- 293 | def _sample_next_v(rng, transition): 294 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 295 | # Note: Important to use sample_and_log_prob here for numerical stability 296 | # See https://github.com/deepmind/distrax/issues/7 297 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 298 | # Minimum of the target Q-values 299 | next_q = q_apply_fn( 300 | agent_state.vec_q_target.params, transition.next_obs, next_action 301 | ) 302 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 303 | 304 | rng, rng_next_v = jax.random.split(rng) 305 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 306 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 307 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target 308 | 309 | # --- Update critics --- 310 | @jax.value_and_grad 311 | def _q_loss_fn(params): 312 | q_pred = q_apply_fn(params, batch.obs, batch.action) 313 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean() 314 | 315 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params) 316 | updated_q = 
agent_state.vec_q.apply_gradients(grads=critic_grad) 317 | agent_state = agent_state._replace(vec_q=updated_q) 318 | 319 | num_done = jnp.sum(batch.done) 320 | loss = { 321 | "critic_loss": critic_loss, 322 | "actor_loss": actor_loss, 323 | "alpha_loss": alpha_loss, 324 | "entropy": entropy, 325 | "alpha": alpha, 326 | "q_min": q_min, 327 | "q_std": q_std, 328 | "terminations/num_done": num_done, 329 | "terminations/done_ratio": num_done / batch.done.shape[0], 330 | } 331 | return (rng, agent_state, rollout_buffer), loss 332 | 333 | return _train_step 334 | 335 | 336 | if __name__ == "__main__": 337 | # --- Parse arguments --- 338 | args = tyro.cli(Args) 339 | rng = jax.random.PRNGKey(args.seed) 340 | 341 | # --- Initialize logger --- 342 | if args.log: 343 | wandb.init( 344 | config=args, 345 | project=args.wandb_project, 346 | entity=args.wandb_team, 347 | group=args.wandb_group, 348 | job_type="train_agent", 349 | ) 350 | 351 | # --- Initialize environment and dataset --- 352 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 353 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 354 | dataset = Transition( 355 | obs=jnp.array(dataset["observations"]), 356 | action=jnp.array(dataset["actions"]), 357 | reward=jnp.array(dataset["rewards"]), 358 | next_obs=jnp.array(dataset["next_observations"]), 359 | done=jnp.array(dataset["terminals"]), 360 | next_action=jnp.roll(dataset["actions"], -1, axis=0), 361 | ) 362 | 363 | # --- Initialize agent and value networks --- 364 | num_actions = env.single_action_space.shape[0] 365 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 366 | dummy_action = jnp.zeros(num_actions) 367 | actor_net = TanhGaussianActor(num_actions) 368 | q_net = VectorQ(args.num_critics) 369 | alpha_net = EntropyCoef() 370 | 371 | # Target networks share seeds to match initialization 372 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 373 | agent_state = AgentTrainState( 374 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 375 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 376 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 377 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 378 | ) 379 | 380 | # --- Initialize buffer and rollout function --- 381 | assert args.model_path, "Model path must be provided for model-based methods" 382 | dynamics_model = load_dynamics_model(args.model_path) 383 | dynamics_model.dataset = dataset 384 | max_buffer_size = args.rollout_batch_size * args.rollout_length 385 | max_buffer_size *= args.model_retain_epochs 386 | rollout_buffer = jax.tree_map( 387 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])), 388 | dataset, 389 | ) 390 | rollout_fn = dynamics_model.make_rollout_fn( 391 | batch_size=args.rollout_batch_size, 392 | rollout_length=args.rollout_length, 393 | step_penalty_coef=args.step_penalty_coef, 394 | ) 395 | 396 | # --- Make train step --- 397 | _agent_train_step_fn = make_train_step( 398 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn 399 | ) 400 | 401 | num_evals = args.num_updates // args.eval_interval 402 | for eval_idx in range(num_evals): 403 | # --- Execute train loop --- 404 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan( 405 | _agent_train_step_fn, 406 | (rng, agent_state, rollout_buffer), 407 | None, 408 | args.eval_interval, 409 | ) 410 | 411 | # --- Evaluate agent --- 412 | rng, rng_eval = jax.random.split(rng) 413 | returns = eval_agent(args, 
rng_eval, env, agent_state) 414 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 415 | 416 | # --- Log metrics --- 417 | step = (eval_idx + 1) * args.eval_interval 418 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 419 | if args.log: 420 | log_dict = { 421 | "return": returns.mean(), 422 | "score": scores.mean(), 423 | "score_std": scores.std(), 424 | "num_updates": step, 425 | **{k: loss[k][-1] for k in loss}, 426 | } 427 | wandb.log(log_dict) 428 | 429 | # --- Evaluate final agent --- 430 | if args.eval_final_episodes > 0: 431 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 432 | print(f"Evaluating final agent for {final_iters} iterations...") 433 | _rng = jax.random.split(rng, final_iters) 434 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 435 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 436 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 437 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 438 | 439 | # --- Write final returns to file --- 440 | os.makedirs("final_returns", exist_ok=True) 441 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 442 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 443 | with open(os.path.join("final_returns", filename), "wb") as f: 444 | onp.savez_compressed(f, **info, args=asdict(args)) 445 | 446 | if args.log: 447 | wandb.save(os.path.join("final_returns", filename)) 448 | 449 | if args.log: 450 | wandb.finish() 451 | -------------------------------------------------------------------------------- /algorithms/morel.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | from dynamics import ( 22 | Transition, 23 | load_dynamics_model, 24 | EnsembleDynamics, # required for loading dynamics model 25 | EnsembleDynamicsModel, # required for loading dynamics model 26 | ) 27 | 28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 29 | 30 | 31 | @dataclass 32 | class Args: 33 | # --- Experiment --- 34 | seed: int = 0 35 | dataset: str = "halfcheetah-medium-v2" 36 | algorithm: str = "morel" 37 | num_updates: int = 3_000_000 38 | eval_interval: int = 2500 39 | eval_workers: int = 8 40 | eval_final_episodes: int = 1000 41 | # --- Logging --- 42 | log: bool = False 43 | wandb_project: str = "unifloral" 44 | wandb_team: str = "flair" 45 | wandb_group: str = "debug" 46 | # --- Generic optimization --- 47 | lr: float = 1e-4 48 | batch_size: int = 256 49 | gamma: float = 0.99 50 | polyak_step_size: float = 0.005 51 | # --- SAC-N --- 52 | num_critics: int = 10 53 | # --- World model --- 54 | model_path: str = "" 55 | rollout_interval: int = 1000 56 | rollout_length: int = 5 57 | rollout_batch_size: int = 50000 58 | model_retain_epochs: int = 5 59 | dataset_sample_ratio: float = 0.01 60 | # --- MOReL --- 61 | threshold_coef: float = 1.0 62 | term_penalty_offset: float = -200 63 | 64 | 65 | r""" 66 | |\ __ 67 | \| /_/ 68 | \| 69 | ___|_____ 70 | \ 
/ 71 | \ / 72 | \___/ Preliminaries 73 | """ 74 | 75 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 76 | 77 | 78 | def sym(scale): 79 | def _init(*args, **kwargs): 80 | return uniform(2 * scale)(*args, **kwargs) - scale 81 | 82 | return _init 83 | 84 | 85 | class SoftQNetwork(nn.Module): 86 | @nn.compact 87 | def __call__(self, obs, action): 88 | x = jnp.concatenate([obs, action], axis=-1) 89 | for _ in range(3): 90 | x = nn.Dense(256, bias_init=constant(0.1))(x) 91 | x = nn.relu(x) 92 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 93 | return q.squeeze(-1) 94 | 95 | 96 | class VectorQ(nn.Module): 97 | num_critics: int 98 | 99 | @nn.compact 100 | def __call__(self, obs, action): 101 | vmap_critic = nn.vmap( 102 | SoftQNetwork, 103 | variable_axes={"params": 0}, # Parameters not shared between critics 104 | split_rngs={"params": True, "dropout": True}, # Different initializations 105 | in_axes=None, 106 | out_axes=-1, 107 | axis_size=self.num_critics, 108 | ) 109 | q_values = vmap_critic()(obs, action) 110 | return q_values 111 | 112 | 113 | class TanhGaussianActor(nn.Module): 114 | num_actions: int 115 | log_std_max: float = 2.0 116 | log_std_min: float = -5.0 117 | 118 | @nn.compact 119 | def __call__(self, x): 120 | for _ in range(3): 121 | x = nn.Dense(256, bias_init=constant(0.1))(x) 122 | x = nn.relu(x) 123 | log_std = nn.Dense( 124 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 125 | )(x) 126 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 127 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 128 | pi = distrax.Transformed( 129 | distrax.Normal(mean, std), 130 | distrax.Tanh(), 131 | ) 132 | return pi 133 | 134 | 135 | class EntropyCoef(nn.Module): 136 | ent_coef_init: float = 1.0 137 | 138 | @nn.compact 139 | def __call__(self) -> jnp.ndarray: 140 | log_ent_coef = self.param( 141 | "log_ent_coef", 142 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 143 | ) 144 | return log_ent_coef 145 | 146 | 147 | def create_train_state(args, rng, network, dummy_input): 148 | return TrainState.create( 149 | apply_fn=network.apply, 150 | params=network.init(rng, *dummy_input), 151 | tx=optax.adam(args.lr, eps=1e-5), 152 | ) 153 | 154 | 155 | def eval_agent(args, rng, env, agent_state): 156 | # --- Reset environment --- 157 | step = 0 158 | returned = onp.zeros(args.eval_workers).astype(bool) 159 | cum_reward = onp.zeros(args.eval_workers) 160 | rng, rng_reset = jax.random.split(rng) 161 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 162 | obs = env.reset() 163 | 164 | # --- Rollout agent --- 165 | @jax.jit 166 | @jax.vmap 167 | def _policy_step(rng, obs): 168 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 169 | action = pi.sample(seed=rng) 170 | return jnp.nan_to_num(action) 171 | 172 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 173 | while step < max_episode_steps and not returned.all(): 174 | # --- Take step in environment --- 175 | step += 1 176 | rng, rng_step = jax.random.split(rng) 177 | rng_step = jax.random.split(rng_step, args.eval_workers) 178 | action = _policy_step(rng_step, jnp.array(obs)) 179 | obs, reward, done, info = env.step(onp.array(action)) 180 | 181 | # --- Track cumulative reward --- 182 | cum_reward += reward * ~returned 183 | returned |= done 184 | 185 | if step >= max_episode_steps and not returned.all(): 186 | warnings.warn("Maximum steps reached before all episodes terminated") 187 
| return cum_reward 188 | 189 | 190 | def sample_from_buffer(buffer, batch_size, rng): 191 | """Sample a batch from the buffer.""" 192 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs)) 193 | return jax.tree_map(lambda x: x[idxs], buffer) 194 | 195 | 196 | r""" 197 | __/) 198 | .-(__(=: 199 | |\ | \) 200 | \ || 201 | \|| 202 | \| 203 | ___|_____ 204 | \ / 205 | \ / 206 | \___/ Agent 207 | """ 208 | 209 | 210 | def make_train_step( 211 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn 212 | ): 213 | """Make JIT-compatible agent train step with model-based rollouts.""" 214 | 215 | def _train_step(runner_state, _): 216 | rng, agent_state, rollout_buffer = runner_state 217 | 218 | # --- Update model buffer --- 219 | params = agent_state.actor.params 220 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng) 221 | rng, rng_buffer = jax.random.split(rng) 222 | rollout_buffer = jax.lax.cond( 223 | agent_state.actor.step % args.rollout_interval == 0, 224 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer), 225 | lambda: rollout_buffer, 226 | ) 227 | 228 | # --- Sample batch --- 229 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3) 230 | dataset_size = int(args.batch_size * args.dataset_sample_ratio) 231 | rollout_size = args.batch_size - dataset_size 232 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset) 233 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout) 234 | batch = jax.tree_map( 235 | lambda x, y: jnp.concatenate([x, y]), dataset_batch, rollout_batch 236 | ) 237 | 238 | # --- Update alpha --- 239 | @jax.value_and_grad 240 | def _alpha_loss_fn(params, rng): 241 | def _compute_entropy(rng, transition): 242 | pi = actor_apply_fn(agent_state.actor.params, transition.obs) 243 | _, log_pi = pi.sample_and_log_prob(seed=rng) 244 | return -log_pi.sum() 245 | 246 | log_alpha = alpha_apply_fn(params) 247 | rng = jax.random.split(rng, args.batch_size) 248 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 249 | target_entropy = -batch.action.shape[-1] 250 | return log_alpha * (entropy - target_entropy) 251 | 252 | rng, rng_alpha = jax.random.split(rng) 253 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 254 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 255 | agent_state = agent_state._replace(alpha=updated_alpha) 256 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 257 | 258 | # --- Update actor --- 259 | @partial(jax.value_and_grad, has_aux=True) 260 | def _actor_loss_function(params, rng): 261 | def _compute_loss(rng, transition): 262 | pi = actor_apply_fn(params, transition.obs) 263 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 264 | log_pi = log_pi.sum() 265 | q_values = q_apply_fn( 266 | agent_state.vec_q.params, transition.obs, sampled_action 267 | ) 268 | q_min = jnp.min(q_values) 269 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 270 | 271 | rng = jax.random.split(rng, args.batch_size) 272 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 273 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 274 | 275 | rng, rng_actor = jax.random.split(rng) 276 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 277 | agent_state.actor.params, rng_actor 278 | ) 279 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 280 | agent_state = agent_state._replace(actor=updated_actor) 281 | 282 | # --- Update Q 
target network --- 283 | updated_q_target_params = optax.incremental_update( 284 | agent_state.vec_q.params, 285 | agent_state.vec_q_target.params, 286 | args.polyak_step_size, 287 | ) 288 | updated_q_target = agent_state.vec_q_target.replace( 289 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 290 | ) 291 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 292 | 293 | # --- Compute targets --- 294 | def _sample_next_v(rng, transition): 295 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 296 | # Note: Important to use sample_and_log_prob here for numerical stability 297 | # See https://github.com/deepmind/distrax/issues/7 298 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 299 | # Minimum of the target Q-values 300 | next_q = q_apply_fn( 301 | agent_state.vec_q_target.params, transition.next_obs, next_action 302 | ) 303 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 304 | 305 | rng, rng_next_v = jax.random.split(rng) 306 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 307 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 308 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target 309 | 310 | # --- Update critics --- 311 | @jax.value_and_grad 312 | def _q_loss_fn(params): 313 | q_pred = q_apply_fn(params, batch.obs, batch.action) 314 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean() 315 | 316 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params) 317 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad) 318 | agent_state = agent_state._replace(vec_q=updated_q) 319 | 320 | num_done = jnp.sum(batch.done) 321 | loss = { 322 | "critic_loss": critic_loss, 323 | "actor_loss": actor_loss, 324 | "alpha_loss": alpha_loss, 325 | "entropy": entropy, 326 | "alpha": alpha, 327 | "q_min": q_min, 328 | "q_std": q_std, 329 | "terminations/num_done": num_done, 330 | "terminations/done_ratio": num_done / batch.done.shape[0], 331 | } 332 | return (rng, agent_state, rollout_buffer), loss 333 | 334 | return _train_step 335 | 336 | 337 | if __name__ == "__main__": 338 | # --- Parse arguments --- 339 | args = tyro.cli(Args) 340 | rng = jax.random.PRNGKey(args.seed) 341 | 342 | # --- Initialize logger --- 343 | if args.log: 344 | wandb.init( 345 | config=args, 346 | project=args.wandb_project, 347 | entity=args.wandb_team, 348 | group=args.wandb_group, 349 | job_type="train_agent", 350 | ) 351 | 352 | # --- Initialize environment and dataset --- 353 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 354 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 355 | dataset = Transition( 356 | obs=jnp.array(dataset["observations"]), 357 | action=jnp.array(dataset["actions"]), 358 | reward=jnp.array(dataset["rewards"]), 359 | next_obs=jnp.array(dataset["next_observations"]), 360 | done=jnp.array(dataset["terminals"]), 361 | next_action=jnp.roll(dataset["actions"], -1, axis=0), 362 | ) 363 | 364 | # --- Initialize agent and value networks --- 365 | num_actions = env.single_action_space.shape[0] 366 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 367 | dummy_action = jnp.zeros(num_actions) 368 | actor_net = TanhGaussianActor(num_actions) 369 | q_net = VectorQ(args.num_critics) 370 | alpha_net = EntropyCoef() 371 | 372 | # Target networks share seeds to match initialization 373 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 374 | agent_state = AgentTrainState( 375 | actor=create_train_state(args, 
rng_actor, actor_net, [dummy_obs]), 376 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 377 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 378 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 379 | ) 380 | 381 | # --- Initialize buffer and rollout function --- 382 | assert args.model_path, "Model path must be provided for model-based methods" 383 | dynamics_model = load_dynamics_model(args.model_path) 384 | assert ( 385 | dynamics_model.discrepancy is not None and dynamics_model.min_r is not None 386 | ), "MOReL requires a dynamics model with precomputed statistics (train with --precompute_term_stats)" 387 | dynamics_model.dataset = dataset 388 | max_buffer_size = args.rollout_batch_size * args.rollout_length 389 | max_buffer_size *= args.model_retain_epochs 390 | rollout_buffer = jax.tree_map( 391 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])), 392 | dataset, 393 | ) 394 | rollout_fn = dynamics_model.make_rollout_fn( 395 | batch_size=args.rollout_batch_size, 396 | rollout_length=args.rollout_length, 397 | term_penalty_offset=args.term_penalty_offset, 398 | threshold_coef=args.threshold_coef, 399 | ) 400 | 401 | # --- Make train step --- 402 | _agent_train_step_fn = make_train_step( 403 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn 404 | ) 405 | 406 | num_evals = args.num_updates // args.eval_interval 407 | for eval_idx in range(num_evals): 408 | # --- Execute train loop --- 409 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan( 410 | _agent_train_step_fn, 411 | (rng, agent_state, rollout_buffer), 412 | None, 413 | args.eval_interval, 414 | ) 415 | 416 | # --- Evaluate agent --- 417 | rng, rng_eval = jax.random.split(rng) 418 | returns = eval_agent(args, rng_eval, env, agent_state) 419 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 420 | 421 | # --- Log metrics --- 422 | step = (eval_idx + 1) * args.eval_interval 423 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 424 | if args.log: 425 | log_dict = { 426 | "return": returns.mean(), 427 | "score": scores.mean(), 428 | "score_std": scores.std(), 429 | "num_updates": step, 430 | **{k: loss[k][-1] for k in loss}, 431 | } 432 | wandb.log(log_dict) 433 | 434 | # --- Evaluate final agent --- 435 | if args.eval_final_episodes > 0: 436 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 437 | print(f"Evaluating final agent for {final_iters} iterations...") 438 | _rng = jax.random.split(rng, final_iters) 439 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 440 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 441 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 442 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 443 | 444 | # --- Write final returns to file --- 445 | os.makedirs("final_returns", exist_ok=True) 446 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 447 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 448 | with open(os.path.join("final_returns", filename), "wb") as f: 449 | onp.savez_compressed(f, **info, args=asdict(args)) 450 | 451 | if args.log: 452 | wandb.save(os.path.join("final_returns", filename)) 453 | 454 | if args.log: 455 | wandb.finish() 456 | -------------------------------------------------------------------------------- /algorithms/rebrac.py: -------------------------------------------------------------------------------- 1 | from 
collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | import os 5 | import warnings 6 | 7 | import distrax 8 | import d4rl 9 | import flax.linen as nn 10 | from flax.linen.initializers import constant, uniform 11 | from flax.training.train_state import TrainState 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as onp 16 | import optax 17 | import tyro 18 | import wandb 19 | 20 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 21 | 22 | 23 | @dataclass 24 | class Args: 25 | # --- Experiment --- 26 | seed: int = 0 27 | dataset: str = "halfcheetah-medium-v2" 28 | algorithm: str = "rebrac" 29 | num_updates: int = 1_000_000 30 | eval_interval: int = 2500 31 | eval_workers: int = 8 32 | eval_final_episodes: int = 1000 33 | # --- Logging --- 34 | log: bool = False 35 | wandb_project: str = "unifloral" 36 | wandb_team: str = "flair" 37 | wandb_group: str = "debug" 38 | # --- Generic optimization --- 39 | lr: float = 1e-3 40 | batch_size: int = 1024 41 | gamma: float = 0.99 42 | polyak_step_size: float = 0.005 43 | # --- TD3+BC --- 44 | noise_clip: float = 0.5 45 | policy_noise: float = 0.2 46 | num_critic_updates_per_step: int = 2 47 | # --- REBRAC --- 48 | critic_bc_coef: float = 0.01 49 | actor_bc_coef: float = 0.001 50 | actor_ln: bool = False 51 | critic_ln: bool = True 52 | norm_obs: bool = False 53 | 54 | 55 | r""" 56 | |\ __ 57 | \| /_/ 58 | \| 59 | ___|_____ 60 | \ / 61 | \ / 62 | \___/ Preliminaries 63 | """ 64 | 65 | AgentTrainState = namedtuple( 66 | "AgentTrainState", "actor actor_target dual_q dual_q_target" 67 | ) 68 | Transition = namedtuple("Transition", "obs action reward next_obs next_action done") 69 | 70 | 71 | def sym(scale): 72 | def _init(*args, **kwargs): 73 | return uniform(2 * scale)(*args, **kwargs) - scale 74 | 75 | return _init 76 | 77 | 78 | class SoftQNetwork(nn.Module): 79 | obs_mean: jax.Array 80 | obs_std: jax.Array 81 | use_ln: bool 82 | norm_obs: bool 83 | 84 | @nn.compact 85 | def __call__(self, obs, action): 86 | if self.norm_obs: 87 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3) 88 | x = jnp.concatenate([obs, action], axis=-1) 89 | for _ in range(3): 90 | x = nn.Dense(256, bias_init=constant(0.1))(x) 91 | x = nn.relu(x) 92 | x = nn.LayerNorm()(x) if self.use_ln else x 93 | q = nn.Dense(1, bias_init=sym(3e-3), kernel_init=sym(3e-3))(x) 94 | return q.squeeze(-1) 95 | 96 | 97 | class DualQNetwork(nn.Module): 98 | obs_mean: jax.Array 99 | obs_std: jax.Array 100 | use_ln: bool 101 | norm_obs: bool 102 | 103 | @nn.compact 104 | def __call__(self, obs, action): 105 | vmap_critic = nn.vmap( 106 | SoftQNetwork, 107 | variable_axes={"params": 0}, # Parameters not shared between critics 108 | split_rngs={"params": True, "dropout": True}, # Different initializations 109 | in_axes=None, 110 | out_axes=-1, 111 | axis_size=2, # Two Q networks 112 | ) 113 | q_fn = vmap_critic(self.obs_mean, self.obs_std, self.use_ln, self.norm_obs) 114 | q_values = q_fn(obs, action) 115 | return q_values 116 | 117 | 118 | class DeterministicTanhActor(nn.Module): 119 | num_actions: int 120 | obs_mean: jax.Array 121 | obs_std: jax.Array 122 | use_ln: bool 123 | norm_obs: bool 124 | 125 | @nn.compact 126 | def __call__(self, x): 127 | if self.norm_obs: 128 | x = (x - self.obs_mean) / (self.obs_std + 1e-3) 129 | for _ in range(3): 130 | x = nn.Dense(256, bias_init=constant(0.1))(x) 131 | x = nn.relu(x) 132 | x = nn.LayerNorm()(x) if self.use_ln else x 133 | init_fn = sym(1e-3) 134 | 
action = nn.Dense(self.num_actions, bias_init=init_fn, kernel_init=init_fn)(x) 135 | pi = distrax.Transformed( 136 | distrax.Deterministic(action), 137 | distrax.Tanh(), 138 | ) 139 | return pi 140 | 141 | 142 | def create_train_state(args, rng, network, dummy_input): 143 | return TrainState.create( 144 | apply_fn=network.apply, 145 | params=network.init(rng, *dummy_input), 146 | tx=optax.adam(args.lr, eps=1e-5), 147 | ) 148 | 149 | 150 | def eval_agent(args, rng, env, agent_state): 151 | # --- Reset environment --- 152 | step = 0 153 | returned = onp.zeros(args.eval_workers).astype(bool) 154 | cum_reward = onp.zeros(args.eval_workers) 155 | rng, rng_reset = jax.random.split(rng) 156 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 157 | obs = env.reset() 158 | 159 | # --- Rollout agent --- 160 | @jax.jit 161 | @jax.vmap 162 | def _policy_step(rng, obs): 163 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 164 | action = pi.sample(seed=rng) 165 | return jnp.nan_to_num(action) 166 | 167 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 168 | while step < max_episode_steps and not returned.all(): 169 | # --- Take step in environment --- 170 | step += 1 171 | rng, rng_step = jax.random.split(rng) 172 | rng_step = jax.random.split(rng_step, args.eval_workers) 173 | action = _policy_step(rng_step, jnp.array(obs)) 174 | obs, reward, done, info = env.step(onp.array(action)) 175 | 176 | # --- Track cumulative reward --- 177 | cum_reward += reward * ~returned 178 | returned |= done 179 | 180 | if step >= max_episode_steps and not returned.all(): 181 | warnings.warn("Maximum steps reached before all episodes terminated") 182 | return cum_reward 183 | 184 | 185 | r""" 186 | __/) 187 | .-(__(=: 188 | |\ | \) 189 | \ || 190 | \|| 191 | \| 192 | ___|_____ 193 | \ / 194 | \ / 195 | \___/ Agent 196 | """ 197 | 198 | 199 | def make_train_step(args, actor_apply_fn, q_apply_fn, dataset): 200 | """Make JIT-compatible agent train step.""" 201 | 202 | def _train_step(runner_state, _): 203 | rng, agent_state = runner_state 204 | 205 | # --- Sample batch --- 206 | rng, rng_batch = jax.random.split(rng) 207 | batch_indices = jax.random.randint( 208 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 209 | ) 210 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 211 | 212 | # --- Update critics --- 213 | def _update_critics(runner_state, _): 214 | rng, agent_state = runner_state 215 | 216 | def _compute_target(rng, transition): 217 | next_obs = transition.next_obs 218 | 219 | # --- Sample noised action --- 220 | next_pi = actor_apply_fn(agent_state.actor_target.params, next_obs) 221 | rng, rng_action, rng_noise = jax.random.split(rng, 3) 222 | action = next_pi.sample(seed=rng_action) 223 | noise = jax.random.normal(rng_noise, shape=action.shape) 224 | noise *= args.policy_noise 225 | noise = jnp.clip(noise, -args.noise_clip, args.noise_clip) 226 | action = jnp.clip(action + noise, -1, 1) 227 | bc_loss = jnp.square(action - transition.next_action).sum() 228 | 229 | # --- Compute targets --- 230 | target_q = q_apply_fn( 231 | agent_state.dual_q_target.params, next_obs, action 232 | ) 233 | next_q_value = jnp.min(target_q) - args.critic_bc_coef * bc_loss 234 | next_q_value = (1.0 - transition.done) * next_q_value 235 | return transition.reward + args.gamma * next_q_value, bc_loss 236 | 237 | rng, rng_targets = jax.random.split(rng) 238 | rng_targets = jax.random.split(rng_targets, args.batch_size) 239 | target_fn = jax.vmap(_compute_target) 240 | targets, 
bc_loss = target_fn(rng_targets, batch) 241 | 242 | # --- Compute critic loss --- 243 | @jax.value_and_grad 244 | def _q_loss_fn(params): 245 | q_pred = q_apply_fn(params, batch.obs, batch.action) 246 | q_loss = jnp.square(q_pred - jnp.expand_dims(targets, axis=-1)).sum(-1) 247 | return q_loss.mean() 248 | 249 | q_loss, q_grad = _q_loss_fn(agent_state.dual_q.params) 250 | updated_q_state = agent_state.dual_q.apply_gradients(grads=q_grad) 251 | agent_state = agent_state._replace(dual_q=updated_q_state) 252 | return (rng, agent_state), (q_loss, bc_loss) 253 | 254 | # --- Iterate critic update --- 255 | (rng, agent_state), (q_loss, critic_bc_loss) = jax.lax.scan( 256 | _update_critics, 257 | (rng, agent_state), 258 | None, 259 | length=args.num_critic_updates_per_step, 260 | ) 261 | 262 | # --- Update actor --- 263 | def _actor_loss_function(params): 264 | def _transition_loss(transition): 265 | pi = actor_apply_fn(params, transition.obs) 266 | pi_action = pi.sample(seed=None) 267 | q = q_apply_fn(agent_state.dual_q.params, transition.obs, pi_action) 268 | bc_loss = jnp.square(pi_action - transition.action).sum() 269 | return q.min(), bc_loss 270 | 271 | q, bc_loss = jax.vmap(_transition_loss)(batch) 272 | lambda_ = 1.0 / (jnp.abs(q).mean() + 1e-7) 273 | lambda_ = jax.lax.stop_gradient(lambda_) 274 | actor_loss = -lambda_ * q.mean() + args.actor_bc_coef * bc_loss.mean() 275 | return actor_loss.mean(), (q.mean(), lambda_.mean(), bc_loss.mean()) 276 | 277 | loss_fn = jax.value_and_grad(_actor_loss_function, has_aux=True) 278 | (actor_loss, (q_mean, lambda_, bc_loss)), actor_grad = loss_fn( 279 | agent_state.actor.params 280 | ) 281 | agent_state = agent_state._replace( 282 | actor=agent_state.actor.apply_gradients(grads=actor_grad) 283 | ) 284 | 285 | # --- Update target networks --- 286 | def _update_target(state, target_state): 287 | new_target_params = optax.incremental_update( 288 | state.params, target_state.params, args.polyak_step_size 289 | ) 290 | return target_state.replace( 291 | step=target_state.step + 1, params=new_target_params 292 | ) 293 | 294 | agent_state = agent_state._replace( 295 | actor_target=_update_target(agent_state.actor, agent_state.actor_target), 296 | dual_q_target=_update_target(agent_state.dual_q, agent_state.dual_q_target), 297 | ) 298 | 299 | loss = { 300 | "actor_loss": actor_loss, 301 | "q_loss": q_loss.mean(), 302 | "q_mean": q_mean, 303 | "lambda": lambda_, 304 | "bc_loss": bc_loss, 305 | "critic_bc_loss": critic_bc_loss.mean(), 306 | } 307 | return (rng, agent_state), loss 308 | 309 | return _train_step 310 | 311 | 312 | if __name__ == "__main__": 313 | # --- Parse arguments --- 314 | args = tyro.cli(Args) 315 | rng = jax.random.PRNGKey(args.seed) 316 | 317 | # --- Initialize logger --- 318 | if args.log: 319 | wandb.init( 320 | config=args, 321 | project=args.wandb_project, 322 | entity=args.wandb_team, 323 | group=args.wandb_group, 324 | job_type="train_agent", 325 | ) 326 | 327 | # --- Initialize environment and dataset --- 328 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 329 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 330 | dataset = Transition( 331 | obs=jnp.array(dataset["observations"]), 332 | action=jnp.array(dataset["actions"]), 333 | reward=jnp.array(dataset["rewards"]), 334 | next_obs=jnp.array(dataset["next_observations"]), 335 | next_action=jnp.roll(dataset["actions"], -1, axis=0), 336 | done=jnp.array(dataset["terminals"]), 337 | ) 338 | 339 | # --- Initialize agent and value networks --- 340 | num_actions 
= env.single_action_space.shape[0] 341 | obs_mean = dataset.obs.mean(axis=0) 342 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0) 343 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 344 | dummy_action = jnp.zeros(num_actions) 345 | actor_cls = DeterministicTanhActor 346 | actor_net = actor_cls(num_actions, obs_mean, obs_std, args.actor_ln, args.norm_obs) 347 | q_net = DualQNetwork(obs_mean, obs_std, args.critic_ln, args.norm_obs) 348 | 349 | # Target networks share seeds to match initialization 350 | rng, rng_actor, rng_q = jax.random.split(rng, 3) 351 | agent_state = AgentTrainState( 352 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 353 | actor_target=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 354 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 355 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 356 | ) 357 | 358 | # --- Make train step --- 359 | _agent_train_step_fn = make_train_step(args, actor_net.apply, q_net.apply, dataset) 360 | 361 | num_evals = args.num_updates // args.eval_interval 362 | for eval_idx in range(num_evals): 363 | # --- Execute train loop --- 364 | (rng, agent_state), loss = jax.lax.scan( 365 | _agent_train_step_fn, 366 | (rng, agent_state), 367 | None, 368 | args.eval_interval, 369 | ) 370 | 371 | # --- Evaluate agent --- 372 | rng, rng_eval = jax.random.split(rng) 373 | returns = eval_agent(args, rng_eval, env, agent_state) 374 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 375 | 376 | # --- Log metrics --- 377 | step = (eval_idx + 1) * args.eval_interval 378 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 379 | if args.log: 380 | log_dict = { 381 | "return": returns.mean(), 382 | "score": scores.mean(), 383 | "score_std": scores.std(), 384 | "num_updates": step, 385 | **{k: loss[k][-1] for k in loss}, 386 | } 387 | wandb.log(log_dict) 388 | 389 | # --- Evaluate final agent --- 390 | if args.eval_final_episodes > 0: 391 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 392 | print(f"Evaluating final agent for {final_iters} iterations...") 393 | _rng = jax.random.split(rng, final_iters) 394 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 395 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 396 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 397 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 398 | 399 | # --- Write final returns to file --- 400 | os.makedirs("final_returns", exist_ok=True) 401 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 402 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 403 | with open(os.path.join("final_returns", filename), "wb") as f: 404 | onp.savez_compressed(f, **info, args=asdict(args)) 405 | 406 | if args.log: 407 | wandb.save(os.path.join("final_returns", filename)) 408 | 409 | if args.log: 410 | wandb.finish() 411 | -------------------------------------------------------------------------------- /algorithms/sac_n.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import partial 5 | import os 6 | import warnings 7 | 8 | import distrax 9 | import d4rl 10 | import flax.linen as nn 11 | from flax.linen.initializers import constant, uniform 12 | from 
flax.training.train_state import TrainState 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as onp 17 | import optax 18 | import tyro 19 | import wandb 20 | 21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 22 | 23 | 24 | @dataclass 25 | class Args: 26 | # --- Experiment --- 27 | seed: int = 0 28 | dataset: str = "halfcheetah-medium-v2" 29 | algorithm: str = "sac_n" 30 | num_updates: int = 3_000_000 31 | eval_interval: int = 2500 32 | eval_workers: int = 8 33 | eval_final_episodes: int = 1000 34 | # --- Logging --- 35 | log: bool = False 36 | wandb_project: str = "unifloral" 37 | wandb_team: str = "flair" 38 | wandb_group: str = "debug" 39 | # --- Generic optimization --- 40 | lr: float = 3e-4 41 | batch_size: int = 256 42 | gamma: float = 0.99 43 | polyak_step_size: float = 0.005 44 | # --- SAC-N --- 45 | num_critics: int = 10 46 | 47 | 48 | r""" 49 | |\ __ 50 | \| /_/ 51 | \| 52 | ___|_____ 53 | \ / 54 | \ / 55 | \___/ Preliminaries 56 | """ 57 | 58 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha") 59 | Transition = namedtuple("Transition", "obs action reward next_obs done") 60 | 61 | 62 | def sym(scale): 63 | def _init(*args, **kwargs): 64 | return uniform(2 * scale)(*args, **kwargs) - scale 65 | 66 | return _init 67 | 68 | 69 | class SoftQNetwork(nn.Module): 70 | @nn.compact 71 | def __call__(self, obs, action): 72 | x = jnp.concatenate([obs, action], axis=-1) 73 | for _ in range(3): 74 | x = nn.Dense(256, bias_init=constant(0.1))(x) 75 | x = nn.relu(x) 76 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x) 77 | return q.squeeze(-1) 78 | 79 | 80 | class VectorQ(nn.Module): 81 | num_critics: int 82 | 83 | @nn.compact 84 | def __call__(self, obs, action): 85 | vmap_critic = nn.vmap( 86 | SoftQNetwork, 87 | variable_axes={"params": 0}, # Parameters not shared between critics 88 | split_rngs={"params": True, "dropout": True}, # Different initializations 89 | in_axes=None, 90 | out_axes=-1, 91 | axis_size=self.num_critics, 92 | ) 93 | q_values = vmap_critic()(obs, action) 94 | return q_values 95 | 96 | 97 | class TanhGaussianActor(nn.Module): 98 | num_actions: int 99 | log_std_max: float = 2.0 100 | log_std_min: float = -5.0 101 | 102 | @nn.compact 103 | def __call__(self, x): 104 | for _ in range(3): 105 | x = nn.Dense(256, bias_init=constant(0.1))(x) 106 | x = nn.relu(x) 107 | log_std = nn.Dense( 108 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3) 109 | )(x) 110 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max)) 111 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x) 112 | pi = distrax.Transformed( 113 | distrax.Normal(mean, std), 114 | distrax.Tanh(), 115 | ) 116 | return pi 117 | 118 | 119 | class EntropyCoef(nn.Module): 120 | ent_coef_init: float = 1.0 121 | 122 | @nn.compact 123 | def __call__(self) -> jnp.ndarray: 124 | log_ent_coef = self.param( 125 | "log_ent_coef", 126 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)), 127 | ) 128 | return log_ent_coef 129 | 130 | 131 | def create_train_state(args, rng, network, dummy_input): 132 | return TrainState.create( 133 | apply_fn=network.apply, 134 | params=network.init(rng, *dummy_input), 135 | tx=optax.adam(args.lr, eps=1e-5), 136 | ) 137 | 138 | 139 | def eval_agent(args, rng, env, agent_state): 140 | # --- Reset environment --- 141 | step = 0 142 | returned = onp.zeros(args.eval_workers).astype(bool) 143 | cum_reward = onp.zeros(args.eval_workers) 144 | rng, 
rng_reset = jax.random.split(rng) 145 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 146 | obs = env.reset() 147 | 148 | # --- Rollout agent --- 149 | @jax.jit 150 | @jax.vmap 151 | def _policy_step(rng, obs): 152 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 153 | action = pi.sample(seed=rng) 154 | return jnp.nan_to_num(action) 155 | 156 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 157 | while step < max_episode_steps and not returned.all(): 158 | # --- Take step in environment --- 159 | step += 1 160 | rng, rng_step = jax.random.split(rng) 161 | rng_step = jax.random.split(rng_step, args.eval_workers) 162 | action = _policy_step(rng_step, jnp.array(obs)) 163 | obs, reward, done, info = env.step(onp.array(action)) 164 | 165 | # --- Track cumulative reward --- 166 | cum_reward += reward * ~returned 167 | returned |= done 168 | 169 | if step >= max_episode_steps and not returned.all(): 170 | warnings.warn("Maximum steps reached before all episodes terminated") 171 | return cum_reward 172 | 173 | 174 | r""" 175 | __/) 176 | .-(__(=: 177 | |\ | \) 178 | \ || 179 | \|| 180 | \| 181 | ___|_____ 182 | \ / 183 | \ / 184 | \___/ Agent 185 | """ 186 | 187 | 188 | def make_train_step(args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset): 189 | """Make JIT-compatible agent train step.""" 190 | 191 | def _train_step(runner_state, _): 192 | rng, agent_state = runner_state 193 | 194 | # --- Sample batch --- 195 | rng, rng_batch = jax.random.split(rng) 196 | batch_indices = jax.random.randint( 197 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 198 | ) 199 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 200 | 201 | # --- Update alpha --- 202 | @jax.value_and_grad 203 | def _alpha_loss_fn(params, rng): 204 | def _compute_entropy(rng, transition): 205 | pi = actor_apply_fn(agent_state.actor.params, transition.obs) 206 | _, log_pi = pi.sample_and_log_prob(seed=rng) 207 | return -log_pi.sum() 208 | 209 | log_alpha = alpha_apply_fn(params) 210 | rng = jax.random.split(rng, args.batch_size) 211 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean() 212 | target_entropy = -batch.action.shape[-1] 213 | return log_alpha * (entropy - target_entropy) 214 | 215 | rng, rng_alpha = jax.random.split(rng) 216 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha) 217 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad) 218 | agent_state = agent_state._replace(alpha=updated_alpha) 219 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params)) 220 | 221 | # --- Update actor --- 222 | @partial(jax.value_and_grad, has_aux=True) 223 | def _actor_loss_function(params, rng): 224 | def _compute_loss(rng, transition): 225 | pi = actor_apply_fn(params, transition.obs) 226 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng) 227 | log_pi = log_pi.sum() 228 | q_values = q_apply_fn( 229 | agent_state.vec_q.params, transition.obs, sampled_action 230 | ) 231 | q_min = jnp.min(q_values) 232 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std() 233 | 234 | rng = jax.random.split(rng, args.batch_size) 235 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch) 236 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean()) 237 | 238 | rng, rng_actor = jax.random.split(rng) 239 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function( 240 | agent_state.actor.params, rng_actor 241 | ) 242 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad) 243 
| agent_state = agent_state._replace(actor=updated_actor) 244 | 245 | # --- Update Q target network --- 246 | updated_q_target_params = optax.incremental_update( 247 | agent_state.vec_q.params, 248 | agent_state.vec_q_target.params, 249 | args.polyak_step_size, 250 | ) 251 | updated_q_target = agent_state.vec_q_target.replace( 252 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params 253 | ) 254 | agent_state = agent_state._replace(vec_q_target=updated_q_target) 255 | 256 | # --- Compute targets --- 257 | def _sample_next_v(rng, transition): 258 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs) 259 | # Note: Important to use sample_and_log_prob here for numerical stability 260 | # See https://github.com/deepmind/distrax/issues/7 261 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng) 262 | # Minimum of the target Q-values 263 | next_q = q_apply_fn( 264 | agent_state.vec_q_target.params, transition.next_obs, next_action 265 | ) 266 | return next_q.min(-1) - alpha * log_next_pi.sum(-1) 267 | 268 | rng, rng_next_v = jax.random.split(rng) 269 | rng_next_v = jax.random.split(rng_next_v, args.batch_size) 270 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch) 271 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target 272 | 273 | # --- Update critics --- 274 | @jax.value_and_grad 275 | def _q_loss_fn(params): 276 | q_pred = q_apply_fn(params, batch.obs, batch.action) 277 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean() 278 | 279 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params) 280 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad) 281 | agent_state = agent_state._replace(vec_q=updated_q) 282 | 283 | loss = { 284 | "critic_loss": critic_loss, 285 | "actor_loss": actor_loss, 286 | "alpha_loss": alpha_loss, 287 | "entropy": entropy, 288 | "alpha": alpha, 289 | "q_min": q_min, 290 | "q_std": q_std, 291 | } 292 | return (rng, agent_state), loss 293 | 294 | return _train_step 295 | 296 | 297 | if __name__ == "__main__": 298 | # --- Parse arguments --- 299 | args = tyro.cli(Args) 300 | rng = jax.random.PRNGKey(args.seed) 301 | 302 | # --- Initialize logger --- 303 | if args.log: 304 | wandb.init( 305 | config=args, 306 | project=args.wandb_project, 307 | entity=args.wandb_team, 308 | group=args.wandb_group, 309 | job_type="train_agent", 310 | ) 311 | 312 | # --- Initialize environment and dataset --- 313 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 314 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 315 | dataset = Transition( 316 | obs=jnp.array(dataset["observations"]), 317 | action=jnp.array(dataset["actions"]), 318 | reward=jnp.array(dataset["rewards"]), 319 | next_obs=jnp.array(dataset["next_observations"]), 320 | done=jnp.array(dataset["terminals"]), 321 | ) 322 | 323 | # --- Initialize agent and value networks --- 324 | num_actions = env.single_action_space.shape[0] 325 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 326 | dummy_action = jnp.zeros(num_actions) 327 | actor_net = TanhGaussianActor(num_actions) 328 | q_net = VectorQ(args.num_critics) 329 | alpha_net = EntropyCoef() 330 | 331 | # Target networks share seeds to match initialization 332 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4) 333 | agent_state = AgentTrainState( 334 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 335 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 336 | 
vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 337 | alpha=create_train_state(args, rng_alpha, alpha_net, []), 338 | ) 339 | 340 | # --- Make train step --- 341 | _agent_train_step_fn = make_train_step( 342 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset 343 | ) 344 | 345 | num_evals = args.num_updates // args.eval_interval 346 | for eval_idx in range(num_evals): 347 | # --- Execute train loop --- 348 | (rng, agent_state), loss = jax.lax.scan( 349 | _agent_train_step_fn, 350 | (rng, agent_state), 351 | None, 352 | args.eval_interval, 353 | ) 354 | 355 | # --- Evaluate agent --- 356 | rng, rng_eval = jax.random.split(rng) 357 | returns = eval_agent(args, rng_eval, env, agent_state) 358 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 359 | 360 | # --- Log metrics --- 361 | step = (eval_idx + 1) * args.eval_interval 362 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 363 | if args.log: 364 | log_dict = { 365 | "return": returns.mean(), 366 | "score": scores.mean(), 367 | "score_std": scores.std(), 368 | "num_updates": step, 369 | **{k: loss[k][-1] for k in loss}, 370 | } 371 | wandb.log(log_dict) 372 | 373 | # --- Evaluate final agent --- 374 | if args.eval_final_episodes > 0: 375 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 376 | print(f"Evaluating final agent for {final_iters} iterations...") 377 | _rng = jax.random.split(rng, final_iters) 378 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 379 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0 380 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 381 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 382 | 383 | # --- Write final returns to file --- 384 | os.makedirs("final_returns", exist_ok=True) 385 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 386 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 387 | with open(os.path.join("final_returns", filename), "wb") as f: 388 | onp.savez_compressed(f, **info, args=asdict(args)) 389 | 390 | if args.log: 391 | wandb.save(os.path.join("final_returns", filename)) 392 | 393 | if args.log: 394 | wandb.finish() 395 | -------------------------------------------------------------------------------- /algorithms/td3_bc.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | import os 5 | import warnings 6 | 7 | import distrax 8 | import d4rl 9 | import flax.linen as nn 10 | from flax.training.train_state import TrainState 11 | import gym 12 | import jax 13 | import jax.numpy as jnp 14 | import numpy as onp 15 | import optax 16 | import tyro 17 | import wandb 18 | 19 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 20 | 21 | 22 | @dataclass 23 | class Args: 24 | # --- Experiment --- 25 | seed: int = 0 26 | dataset: str = "halfcheetah-medium-v2" 27 | algorithm: str = "td3_bc" 28 | num_updates: int = 1_000_000 29 | eval_interval: int = 2500 30 | eval_workers: int = 8 31 | eval_final_episodes: int = 1000 32 | # --- Logging --- 33 | log: bool = False 34 | wandb_project: str = "unifloral" 35 | wandb_team: str = "flair" 36 | wandb_group: str = "debug" 37 | # --- Generic optimization --- 38 | lr: float = 3e-4 39 | batch_size: int = 256 40 | gamma: float = 0.99 41 | polyak_step_size: float = 0.005 42 | # --- TD3+BC --- 43 | 
td3_alpha: float = 2.5 44 | noise_clip: float = 0.5 45 | policy_noise: float = 0.2 46 | num_critic_updates_per_step: int = 2 47 | 48 | 49 | r""" 50 | |\ __ 51 | \| /_/ 52 | \| 53 | ___|_____ 54 | \ / 55 | \ / 56 | \___/ Preliminaries 57 | """ 58 | 59 | AgentTrainState = namedtuple( 60 | "AgentTrainState", "actor actor_target dual_q dual_q_target" 61 | ) 62 | Transition = namedtuple("Transition", "obs action reward next_obs done") 63 | 64 | 65 | class SoftQNetwork(nn.Module): 66 | obs_mean: jax.Array 67 | obs_std: jax.Array 68 | 69 | @nn.compact 70 | def __call__(self, obs, action): 71 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3) 72 | x = jnp.concatenate([obs, action], axis=-1) 73 | for _ in range(2): 74 | x = nn.Dense(256)(x) 75 | x = nn.relu(x) 76 | q = nn.Dense(1)(x) 77 | return q.squeeze(-1) 78 | 79 | 80 | class DualQNetwork(nn.Module): 81 | obs_mean: jax.Array 82 | obs_std: jax.Array 83 | 84 | @nn.compact 85 | def __call__(self, obs, action): 86 | vmap_critic = nn.vmap( 87 | SoftQNetwork, 88 | variable_axes={"params": 0}, # Parameters not shared between critics 89 | split_rngs={"params": True, "dropout": True}, # Different initializations 90 | in_axes=None, 91 | out_axes=-1, 92 | axis_size=2, # Two Q networks 93 | ) 94 | q_values = vmap_critic(self.obs_mean, self.obs_std)(obs, action) 95 | return q_values 96 | 97 | 98 | class DeterministicTanhActor(nn.Module): 99 | num_actions: int 100 | obs_mean: jax.Array 101 | obs_std: jax.Array 102 | 103 | @nn.compact 104 | def __call__(self, x): 105 | x = (x - self.obs_mean) / (self.obs_std + 1e-3) 106 | for _ in range(2): 107 | x = nn.Dense(256)(x) 108 | x = nn.relu(x) 109 | action = nn.Dense(self.num_actions)(x) 110 | pi = distrax.Transformed( 111 | distrax.Deterministic(action), 112 | distrax.Tanh(), 113 | ) 114 | return pi 115 | 116 | 117 | def create_train_state(args, rng, network, dummy_input): 118 | return TrainState.create( 119 | apply_fn=network.apply, 120 | params=network.init(rng, *dummy_input), 121 | tx=optax.adam(args.lr, eps=1e-5), 122 | ) 123 | 124 | 125 | def eval_agent(args, rng, env, agent_state): 126 | # --- Reset environment --- 127 | step = 0 128 | returned = onp.zeros(args.eval_workers).astype(bool) 129 | cum_reward = onp.zeros(args.eval_workers) 130 | rng, rng_reset = jax.random.split(rng) 131 | rng_reset = jax.random.split(rng_reset, args.eval_workers) 132 | obs = env.reset() 133 | 134 | # --- Rollout agent --- 135 | @jax.jit 136 | @jax.vmap 137 | def _policy_step(rng, obs): 138 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs) 139 | action = pi.sample(seed=rng) 140 | return jnp.nan_to_num(action) 141 | 142 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps 143 | while step < max_episode_steps and not returned.all(): 144 | # --- Take step in environment --- 145 | step += 1 146 | rng, rng_step = jax.random.split(rng) 147 | rng_step = jax.random.split(rng_step, args.eval_workers) 148 | action = _policy_step(rng_step, jnp.array(obs)) 149 | obs, reward, done, info = env.step(onp.array(action)) 150 | 151 | # --- Track cumulative reward --- 152 | cum_reward += reward * ~returned 153 | returned |= done 154 | 155 | if step >= max_episode_steps and not returned.all(): 156 | warnings.warn("Maximum steps reached before all episodes terminated") 157 | return cum_reward 158 | 159 | 160 | r""" 161 | __/) 162 | .-(__(=: 163 | |\ | \) 164 | \ || 165 | \|| 166 | \| 167 | ___|_____ 168 | \ / 169 | \ / 170 | \___/ Agent 171 | """ 172 | 173 | 174 | def make_train_step(args, actor_apply_fn, q_apply_fn, 
dataset): 175 | """Make JIT-compatible agent train step.""" 176 | 177 | def _train_step(runner_state, _): 178 | rng, agent_state = runner_state 179 | 180 | # --- Sample batch --- 181 | rng, rng_batch = jax.random.split(rng) 182 | batch_indices = jax.random.randint( 183 | rng_batch, (args.batch_size,), 0, len(dataset.obs) 184 | ) 185 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 186 | 187 | # --- Update critics --- 188 | def _update_critics(runner_state, _): 189 | rng, agent_state = runner_state 190 | 191 | def _compute_target(rng, transition): 192 | next_obs = transition.next_obs 193 | 194 | # --- Sample noised action --- 195 | next_pi = actor_apply_fn(agent_state.actor_target.params, next_obs) 196 | rng, rng_action, rng_noise = jax.random.split(rng, 3) 197 | action = next_pi.sample(seed=rng_action) 198 | noise = jax.random.normal(rng_noise, shape=action.shape) 199 | noise *= args.policy_noise 200 | noise = jnp.clip(noise, -args.noise_clip, args.noise_clip) 201 | action = jnp.clip(action + noise, -1, 1) 202 | 203 | # --- Compute targets --- 204 | target_q = q_apply_fn( 205 | agent_state.dual_q_target.params, next_obs, action 206 | ) 207 | next_q_value = (1.0 - transition.done) * jnp.min(target_q) 208 | return transition.reward + args.gamma * next_q_value 209 | 210 | rng, rng_targets = jax.random.split(rng) 211 | rng_targets = jax.random.split(rng_targets, args.batch_size) 212 | targets = jax.vmap(_compute_target)(rng_targets, batch) 213 | 214 | # --- Compute critic loss --- 215 | @jax.value_and_grad 216 | def _q_loss_fn(params): 217 | q_pred = q_apply_fn(params, batch.obs, batch.action) 218 | return jnp.square(q_pred - jnp.expand_dims(targets, axis=-1)).mean() 219 | 220 | q_loss, q_grad = _q_loss_fn(agent_state.dual_q.params) 221 | updated_q_state = agent_state.dual_q.apply_gradients(grads=q_grad) 222 | agent_state = agent_state._replace(dual_q=updated_q_state) 223 | return (rng, agent_state), q_loss 224 | 225 | # --- Iterate critic update --- 226 | (rng, agent_state), q_loss = jax.lax.scan( 227 | _update_critics, 228 | (rng, agent_state), 229 | None, 230 | length=args.num_critic_updates_per_step, 231 | ) 232 | 233 | # --- Update actor --- 234 | def _actor_loss_function(params): 235 | def _transition_loss(transition): 236 | pi = actor_apply_fn(params, transition.obs) 237 | pi_action = pi.sample(seed=None) 238 | q = q_apply_fn(agent_state.dual_q.params, transition.obs, pi_action) 239 | bc_loss = jnp.square(pi_action - transition.action).mean() 240 | return q[0], bc_loss 241 | 242 | q, bc_loss = jax.vmap(_transition_loss)(batch) 243 | lambda_ = args.td3_alpha / (jnp.abs(q).mean() + 1e-7) 244 | lambda_ = jax.lax.stop_gradient(lambda_) 245 | actor_loss = -lambda_ * q.mean() + bc_loss.mean() 246 | return actor_loss.mean(), (q.mean(), lambda_.mean(), bc_loss.mean()) 247 | 248 | loss_fn = jax.value_and_grad(_actor_loss_function, has_aux=True) 249 | (actor_loss, (q_mean, lambda_, bc_loss)), actor_grad = loss_fn( 250 | agent_state.actor.params 251 | ) 252 | agent_state = agent_state._replace( 253 | actor=agent_state.actor.apply_gradients(grads=actor_grad) 254 | ) 255 | 256 | # --- Update target networks --- 257 | def _update_target(state, target_state): 258 | new_target_params = optax.incremental_update( 259 | state.params, target_state.params, args.polyak_step_size 260 | ) 261 | return target_state.replace( 262 | step=target_state.step + 1, params=new_target_params 263 | ) 264 | 265 | agent_state = agent_state._replace( 266 | actor_target=_update_target(agent_state.actor, 
agent_state.actor_target), 267 | dual_q_target=_update_target(agent_state.dual_q, agent_state.dual_q_target), 268 | ) 269 | 270 | loss = { 271 | "actor_loss": actor_loss, 272 | "q_loss": q_loss.mean(), 273 | "q_mean": q_mean, 274 | "lambda": lambda_, 275 | "bc_loss": bc_loss, 276 | } 277 | return (rng, agent_state), loss 278 | 279 | return _train_step 280 | 281 | 282 | if __name__ == "__main__": 283 | # --- Parse arguments --- 284 | args = tyro.cli(Args) 285 | rng = jax.random.PRNGKey(args.seed) 286 | 287 | # --- Initialize logger --- 288 | if args.log: 289 | wandb.init( 290 | config=args, 291 | project=args.wandb_project, 292 | entity=args.wandb_team, 293 | group=args.wandb_group, 294 | job_type="train_agent", 295 | ) 296 | 297 | # --- Initialize environment and dataset --- 298 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers) 299 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset)) 300 | dataset = Transition( 301 | obs=jnp.array(dataset["observations"]), 302 | action=jnp.array(dataset["actions"]), 303 | reward=jnp.array(dataset["rewards"]), 304 | next_obs=jnp.array(dataset["next_observations"]), 305 | done=jnp.array(dataset["terminals"]), 306 | ) 307 | 308 | # --- Initialize agent and value networks --- 309 | num_actions = env.single_action_space.shape[0] 310 | obs_mean = dataset.obs.mean(axis=0) 311 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0) 312 | dummy_obs = jnp.zeros(env.single_observation_space.shape) 313 | dummy_action = jnp.zeros(num_actions) 314 | actor_net = DeterministicTanhActor(num_actions, obs_mean, obs_std) 315 | q_net = DualQNetwork(obs_mean, obs_std) 316 | 317 | # Target networks share seeds to match initialization 318 | rng, rng_actor, rng_q = jax.random.split(rng, 3) 319 | agent_state = AgentTrainState( 320 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 321 | actor_target=create_train_state(args, rng_actor, actor_net, [dummy_obs]), 322 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 323 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]), 324 | ) 325 | 326 | # --- Make train step --- 327 | _agent_train_step_fn = make_train_step(args, actor_net.apply, q_net.apply, dataset) 328 | 329 | num_evals = args.num_updates // args.eval_interval 330 | for eval_idx in range(num_evals): 331 | # --- Execute train loop --- 332 | (rng, agent_state), loss = jax.lax.scan( 333 | _agent_train_step_fn, 334 | (rng, agent_state), 335 | None, 336 | args.eval_interval, 337 | ) 338 | 339 | # --- Evaluate agent --- 340 | rng, rng_eval = jax.random.split(rng) 341 | returns = eval_agent(args, rng_eval, env, agent_state) 342 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0 343 | 344 | # --- Log metrics --- 345 | step = (eval_idx + 1) * args.eval_interval 346 | print("Step:", step, f"\t Score: {scores.mean():.2f}") 347 | if args.log: 348 | log_dict = { 349 | "return": returns.mean(), 350 | "score": scores.mean(), 351 | "score_std": scores.std(), 352 | "num_updates": step, 353 | **{k: loss[k][-1] for k in loss}, 354 | } 355 | wandb.log(log_dict) 356 | 357 | # --- Evaluate final agent --- 358 | if args.eval_final_episodes > 0: 359 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers)) 360 | print(f"Evaluating final agent for {final_iters} iterations...") 361 | _rng = jax.random.split(rng, final_iters) 362 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng]) 363 | scores = d4rl.get_normalized_score(args.dataset, rets) * 
100.0 364 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()} 365 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores") 366 | 367 | # --- Write final returns to file --- 368 | os.makedirs("final_returns", exist_ok=True) 369 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 370 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz" 371 | with open(os.path.join("final_returns", filename), "wb") as f: 372 | onp.savez_compressed(f, **info, args=asdict(args)) 373 | 374 | if args.log: 375 | wandb.save(os.path.join("final_returns", filename)) 376 | 377 | if args.log: 378 | wandb.finish() 379 | -------------------------------------------------------------------------------- /algorithms/termination_fns.py: -------------------------------------------------------------------------------- 1 | """ 2 | The world models we use in model-based RL don't predict termination, so we need to 3 | define termination functions for each task. 4 | This code is adapted from 5 | https://github.com/yihaosun1124/OfflineRL-Kit/blob/6e578d13568fa934096baa2ca96e38e1fa44a233/offlinerlkit/utils/termination_fns.py#L123 6 | Thanks to the authors! 7 | """ 8 | 9 | import jax.numpy as jnp 10 | 11 | 12 | def termination_fn_halfcheetah(obs, act, next_obs): 13 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 14 | 15 | not_done = jnp.logical_and(jnp.all(next_obs > -100), jnp.all(next_obs < 100)) 16 | done = ~not_done 17 | return done 18 | 19 | 20 | def termination_fn_hopper(obs, act, next_obs): 21 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 22 | 23 | height = next_obs[0] 24 | angle = next_obs[1] 25 | not_done = ( 26 | jnp.isfinite(next_obs).all() 27 | * jnp.abs(next_obs[1:] < 100).all() 28 | * (height > 0.7) 29 | * (jnp.abs(angle) < 0.2) 30 | ) 31 | 32 | done = ~not_done 33 | return done 34 | 35 | 36 | def termination_fn_halfcheetahveljump(obs, act, next_obs): 37 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 38 | 39 | done = jnp.array(False) 40 | return done 41 | 42 | 43 | def termination_fn_antangle(obs, act, next_obs): 44 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 45 | 46 | x = next_obs[0] 47 | not_done = jnp.isfinite(next_obs).all() * (x >= 0.2) * (x <= 1.0) 48 | 49 | done = ~not_done 50 | return done 51 | 52 | 53 | def termination_fn_ant(obs, act, next_obs): 54 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 55 | 56 | x = next_obs[0] 57 | not_done = jnp.isfinite(next_obs).all() * (x >= 0.2) * (x <= 1.0) 58 | 59 | done = ~not_done 60 | return done 61 | 62 | 63 | def termination_fn_walker2d(obs, act, next_obs): 64 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 65 | 66 | height = next_obs[0] 67 | angle = next_obs[1] 68 | not_done = ( 69 | jnp.logical_and(jnp.all(next_obs > -100), jnp.all(next_obs < 100)) 70 | * (height > 0.8) 71 | * (height < 2.0) 72 | * (angle > -1.0) 73 | * (angle < 1.0) 74 | ) 75 | done = ~not_done 76 | return done 77 | 78 | 79 | def termination_fn_point2denv(obs, act, next_obs): 80 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 81 | 82 | done = jnp.array(False) 83 | return done 84 | 85 | 86 | def termination_fn_point2dwallenv(obs, act, next_obs): 87 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 88 | 89 | done = jnp.array(False) 90 | return done 91 | 92 | 93 | def termination_fn_pendulum(obs, act, next_obs): 94 | assert len(obs.shape) == len(next_obs.shape) == 
len(act.shape) == 1 95 | 96 | done = jnp.array(False) 97 | return done 98 | 99 | 100 | def termination_fn_humanoid(obs, act, next_obs): 101 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 102 | 103 | z = next_obs[0] 104 | done = (z < 1.0) + (z > 2.0) 105 | 106 | return done 107 | 108 | 109 | def termination_fn_pen(obs, act, next_obs): 110 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 111 | 112 | obj_pos = next_obs[24:27] 113 | done = obj_pos[2] < 0.075 114 | 115 | return done 116 | 117 | 118 | def termination_fn_door(obs, act, next_obs): 119 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 120 | 121 | done = jnp.array(False) 122 | 123 | return done 124 | 125 | 126 | def termination_fn_relocate(obs, act, next_obs): 127 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 128 | 129 | done = jnp.array(False) 130 | 131 | return done 132 | 133 | 134 | def maze2d_open_termination_fn(obs, act, next_obs): 135 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 136 | 137 | agent_location = jnp.array([obs[0], obs[1]]) 138 | goal_location = jnp.array([2, 3]) 139 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5 140 | return done 141 | 142 | 143 | def maze2d_umaze_termination_fn(obs, act, next_obs): 144 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 145 | 146 | agent_location = jnp.array([obs[0], obs[1]]) 147 | goal_location = jnp.array([1, 1]) 148 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5 149 | return done 150 | 151 | 152 | def maze2d_medium_termination_fn(obs, act, next_obs): 153 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 154 | 155 | agent_location = jnp.array([obs[0], obs[1]]) 156 | goal_location = jnp.array([6, 6]) 157 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5 158 | return done 159 | 160 | 161 | def maze2d_large_termination_fn(obs, act, next_obs): 162 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 163 | 164 | agent_location = jnp.array([obs[0], obs[1]]) 165 | goal_location = jnp.array([7, 9]) 166 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5 167 | return done 168 | 169 | 170 | def termination_fn_kitchen(obs, act, next_obs): 171 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1 172 | 173 | # Implementing a termination function is tricky, since it's unclear how it works in 174 | # the original code (see d4rl/kitchen/kitchen_envs.py) 175 | # Not terminating the episode works as well; it is even an argument defined in gym.
176 | done = jnp.array(False) 177 | return done 178 | 179 | 180 | def get_termination_fn(task): 181 | if "halfcheetahvel" in task: 182 | return termination_fn_halfcheetahveljump 183 | elif "halfcheetah" in task: 184 | return termination_fn_halfcheetah 185 | elif "hopper" in task: 186 | return termination_fn_hopper 187 | elif "antangle" in task: 188 | return termination_fn_antangle 189 | elif "ant" in task: 190 | return termination_fn_ant 191 | elif "walker2d" in task: 192 | return termination_fn_walker2d 193 | elif "point2denv" in task: 194 | return termination_fn_point2denv 195 | elif "point2dwallenv" in task: 196 | return termination_fn_point2dwallenv 197 | elif "pendulum" in task: 198 | return termination_fn_pendulum 199 | elif "humanoid" in task: 200 | return termination_fn_humanoid 201 | elif "maze2d-open" in task: 202 | return maze2d_open_termination_fn 203 | elif "maze2d-umaze" in task: 204 | return maze2d_umaze_termination_fn 205 | elif "maze2d-medium" in task: 206 | return maze2d_medium_termination_fn 207 | elif "maze2d-large" in task: 208 | return maze2d_large_termination_fn 209 | elif "pen" in task: 210 | return termination_fn_pen 211 | elif "door" in task: 212 | return termination_fn_door 213 | elif "relocate" in task: 214 | return termination_fn_relocate 215 | elif "kitchen" in task: 216 | return termination_fn_kitchen 217 | else: 218 | raise ValueError(f"Unknown task: {task}") 219 | -------------------------------------------------------------------------------- /configs/algorithms/bc.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: BC 8 | program: algorithms/bc.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: bc 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | batch_size: 43 | value: 256 44 | -------------------------------------------------------------------------------- /configs/algorithms/combo.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: COMBO 8 | program: algorithms/combo.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | value: 0 # Fixed seed since we're sweeping over models 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: combo 20 | num_updates: 21 | value: 3000000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | values: 42 | - 1e-4 43 | - 3e-4 44 | batch_size: 45 | value: 256 46 | gamma: 47 | value: 0.99 48 | polyak_step_size: 49 | value: 0.005 50 | 51 | # --- Model-based specific --- 52 | model_path: 53 |
values: # This will be replaced with actual model paths for each environment 54 | - "PLACEHOLDER_MODEL_PATH" 55 | model_retain_epochs: 56 | value: 5 57 | num_critics: 58 | value: 10 59 | rollout_batch_size: 60 | value: 50000 61 | rollout_interval: 62 | value: 1000 63 | rollout_length: 64 | values: [1, 5, 25] 65 | dataset_sample_ratio: 66 | values: [0.5, 0.8] 67 | 68 | # --- CQL --- 69 | actor_lr: 70 | values: [1e-5, 3e-5, 1e-4] 71 | cql_temperature: 72 | value: 1.0 73 | cql_min_q_weight: 74 | values: [0.5, 1.0, 5.0] 75 | -------------------------------------------------------------------------------- /configs/algorithms/cql.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: CQL 8 | program: algorithms/cql.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: cql 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | values: [1e-4, 3e-4] 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- SAC-N --- 50 | num_critics: 51 | value: 10 52 | 53 | # --- CQL --- 54 | actor_lr: 55 | values: [3e-5, 1e-4, 3e-4] 56 | cql_temperature: 57 | value: 1.0 58 | cql_min_q_weight: 59 | value: 10.0 60 | -------------------------------------------------------------------------------- /configs/algorithms/edac.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: EDAC 8 | program: algorithms/edac.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: edac 20 | num_updates: 21 | value: 3000000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 0.0003 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- SAC-N --- 50 | num_critics: 51 | values: [10, 20, 50] # 100 has high GPU memory usage 52 | 53 | # --- EDAC --- 54 | eta: 55 | values: [0.0, 1.0, 5.0, 10.0, 100.0, 200.0, 500.0, 1000.0] 56 | -------------------------------------------------------------------------------- /configs/algorithms/iql.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: IQL 8 | program: algorithms/iql.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - 
halfcheetah-medium-v2 18 | algorithm: 19 | value: iql 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 0.0003 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- IQL specific --- 50 | beta: 51 | values: [0.5, 3.0, 10.0] 52 | iql_tau: 53 | values: [0.5, 0.7, 0.9] 54 | exp_adv_clip: 55 | value: 100.0 56 | -------------------------------------------------------------------------------- /configs/algorithms/mopo.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: MOPO 8 | program: algorithms/mopo.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | value: 0 # Fixed seed since we're sweeping over models 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: mopo 20 | num_updates: 21 | value: 3000000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-4 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- Model-based specific --- 50 | model_path: 51 | values: # This should be replaced with actual model paths for each environment 52 | - "PLACEHOLDER_MODEL_PATH" 53 | model_retain_epochs: 54 | value: 5 55 | num_critics: 56 | value: 10 57 | rollout_batch_size: 58 | value: 50000 59 | rollout_interval: 60 | value: 1000 61 | rollout_length: 62 | values: [1, 5] 63 | dataset_sample_ratio: 64 | value: 0.05 65 | 66 | # --- MOPO specific --- 67 | step_penalty_coef: 68 | values: [1.0, 5.0] 69 | -------------------------------------------------------------------------------- /configs/algorithms/morel.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: MOReL 8 | program: algorithms/morel.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | value: 0 # Fixed seed since we're sweeping over models 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: morel 20 | num_updates: 21 | value: 3000000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-4 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- Model-based specific --- 50 | model_path: 51 | values: # This should be replaced with actual model paths for each environment 52 | - "PLACEHOLDER_MODEL_PATH" 53 | 
model_retain_epochs: 54 | value: 5 55 | num_critics: 56 | value: 10 57 | rollout_batch_size: 58 | value: 50000 59 | rollout_interval: 60 | value: 1000 61 | rollout_length: 62 | value: 5 63 | dataset_sample_ratio: 64 | value: 0.01 65 | 66 | # --- MOREL specific --- 67 | threshold_coef: 68 | values: [0, 5, 10, 15, 20, 25] 69 | term_penalty_offset: 70 | values: [-30, -50, -100, -200] 71 | -------------------------------------------------------------------------------- /configs/algorithms/rebrac.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: ReBRAC 8 | program: algorithms/rebrac.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: rebrac 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-3 42 | batch_size: 43 | value: 1024 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- TD3+BC --- 50 | noise_clip: 51 | value: 0.5 52 | policy_noise: 53 | value: 0.2 54 | num_critic_updates_per_step: 55 | value: 2 56 | 57 | # --- REBRAC --- 58 | critic_bc_coef: 59 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1] 60 | actor_bc_coef: 61 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0] 62 | actor_ln: 63 | value: false 64 | critic_ln: 65 | value: true 66 | norm_obs: 67 | value: false 68 | -------------------------------------------------------------------------------- /configs/algorithms/sac_n.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: SAC-N 8 | program: algorithms/sac_n.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: sac_n 20 | num_updates: 21 | value: 3000000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 0.0003 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- SAC-N --- 50 | num_critics: 51 | values: [5, 10, 20, 50, 100, 200] # 500 and 1000 have high GPU memory usage 52 | -------------------------------------------------------------------------------- /configs/algorithms/td3_bc.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: TD3+BC 8 | program: algorithms/td3_bc.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | 
values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: td3_bc 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | batch_size: 43 | value: 256 44 | gamma: 45 | value: 0.99 46 | polyak_step_size: 47 | value: 0.005 48 | 49 | # --- TD3+BC --- 50 | td3_alpha: 51 | values: [1.0, 2.0, 2.5, 3.0, 4.0] 52 | noise_clip: 53 | value: 0.5 54 | policy_noise: 55 | value: 0.2 56 | num_critic_updates_per_step: 57 | value: 2 58 | -------------------------------------------------------------------------------- /configs/dynamics.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: Dynamics-Model 8 | program: algorithms/dynamics.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: dynamics 20 | eval_interval: 21 | value: 10000 22 | log: 23 | value: true 24 | wandb_project: 25 | value: unifloral 26 | wandb_team: 27 | value: flair 28 | wandb_group: 29 | value: debug 30 | 31 | # --- Generic optimization --- 32 | lr: 33 | value: 0.001 34 | batch_size: 35 | value: 256 36 | 37 | # --- Dynamics --- 38 | n_layers: 39 | value: 4 40 | layer_size: 41 | value: 200 42 | num_ensemble: 43 | value: 7 44 | num_elites: 45 | value: 5 46 | num_epochs: 47 | value: 400 48 | logvar_diff_coef: 49 | value: 0.01 50 | weight_decay: 51 | value: 2.5e-5 52 | validation_split: 53 | value: 0.2 54 | precompute_term_stats: 55 | value: true 56 | -------------------------------------------------------------------------------- /configs/unifloral/bc.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: BC (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | actor_lr: 43 | value: 3e-4 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 256 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: true 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 2 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: false 62 | deterministic: 63 | value: true 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function 
architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 2 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: false 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | value: 1.0 88 | actor_q_coef: 89 | value: 0.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: false 94 | aggregate_q: 95 | value: first 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 1 108 | critic_bc_coef: 109 | value: 0.0 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.0 114 | noise_clip: 115 | value: 0.0 116 | use_target_actor: 117 | value: false 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/edac.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: EDAC (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 3_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | actor_lr: 43 | value: 3e-4 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 256 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: false 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 3 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: false 62 | deterministic: 63 | value: false 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: false 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | values: [10, 20, 50] 78 | critic_num_layers: 79 | value: 3 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: false 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | value: 0.0 88 | actor_q_coef: 89 | value: 1.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: false 94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 1 108 | critic_bc_coef: 109 | value: 0.0 110 | diversity_coef: 111 | values: [0.0, 1.0, 5.0, 
10.0, 100.0, 200.0, 500.0, 1000.0] 112 | policy_noise: 113 | value: 0.0 114 | noise_clip: 115 | value: 0.0 116 | use_target_actor: 117 | value: false 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: true 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 1.0 132 | critic_entropy_coef: 133 | value: 1.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/iql.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: IQL (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | actor_lr: 43 | value: 3e-4 44 | lr_schedule: 45 | value: cosine 46 | batch_size: 47 | value: 256 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: true 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 2 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: false 62 | deterministic: 63 | value: false 64 | deterministic_eval: 65 | value: true 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: true 70 | log_std_min: 71 | value: -20.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 2 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: false 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | value: 1.0 # Weight of AWR loss 88 | actor_q_coef: 89 | value: 0.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: false 94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: true 100 | awr_temperature: 101 | values: [0.5, 3.0, 10.0] # IQL beta 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 1 108 | critic_bc_coef: 109 | value: 0.0 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.0 114 | noise_clip: 115 | value: 0.0 116 | use_target_actor: 117 | value: false 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | values: [0.5, 0.7, 0.9] # IQL tau 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/mobrac.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program}
4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: MoBRAC (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-3 42 | actor_lr: 43 | value: 1e-3 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 1024 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: false 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 3 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: true 62 | deterministic: 63 | value: true 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 3 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: true 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0] 88 | actor_q_coef: 89 | value: 1.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: true 94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 2 108 | critic_bc_coef: 109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1] 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.2 114 | noise_clip: 115 | value: 0.5 116 | use_target_actor: 117 | value: true 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | 135 | # --- World model --- 136 | model_path: 137 | value: "PLACEHOLDER_MODEL_PATH" 138 | dataset_sample_ratio: 139 | value: 0.05 140 | rollout_interval: 141 | value: 1000 142 | rollout_length: 143 | values: [1, 5] 144 | rollout_batch_size: 145 | value: 50000 146 | model_retain_epochs: 147 | value: 5 148 | step_penalty_coef: 149 | values: [0.0, 0.25, 0.5, 1.0, 5.0] 150 | -------------------------------------------------------------------------------- /configs/unifloral/rebrac.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: ReBRAC (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 |
algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-3 42 | actor_lr: 43 | value: 1e-3 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 1024 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: false 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 3 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: true 62 | deterministic: 63 | value: true 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 3 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: true 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0] 88 | actor_q_coef: 89 | value: 1.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: true 94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 2 108 | critic_bc_coef: 109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1] 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.2 114 | noise_clip: 115 | value: 0.5 116 | use_target_actor: 117 | value: true 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/sac_n.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: SAC-N (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 3_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | actor_lr: 43 | value: 3e-4 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 256 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: false 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 
| value: 3 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: false 62 | deterministic: 63 | value: false 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: false 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | values: [5, 10, 20, 50, 100, 200] 78 | critic_num_layers: 79 | value: 3 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: false 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | value: 0.0 88 | actor_q_coef: 89 | value: 1.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: false 94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 1 108 | critic_bc_coef: 109 | value: 0.0 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.0 114 | noise_clip: 115 | value: 0.0 116 | use_target_actor: 117 | value: false 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: true 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 1.0 132 | critic_entropy_coef: 133 | value: 1.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/td3_awr.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: TD3-AWR 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 1e-3 42 | actor_lr: 43 | value: 1e-3 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 1024 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: false 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 3 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: true 62 | deterministic: 63 | value: true 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 3 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: true 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0] 88 | actor_q_coef: 89 | value: 1.0 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: true 
94 | aggregate_q: 95 | value: min 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: true 100 | awr_temperature: 101 | values: [0.5, 3.0, 10.0] 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 2 108 | critic_bc_coef: 109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1] 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.2 114 | noise_clip: 115 | value: 0.5 116 | use_target_actor: 117 | value: true 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | values: [0.5, 0.7, 0.9] 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | -------------------------------------------------------------------------------- /configs/unifloral/td3_bc.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - python3.9 3 | - ${program} 4 | - ${args_no_boolean_flags} 5 | entity: flair 6 | method: random 7 | name: TD3+BC (Unifloral) 8 | program: algorithms/unifloral.py 9 | project: unifloral 10 | 11 | parameters: 12 | # --- Experiment --- 13 | seed: 14 | values: [0, 1, 2, 3, 4] 15 | dataset: 16 | values: 17 | - halfcheetah-medium-v2 18 | algorithm: 19 | value: unified 20 | num_updates: 21 | value: 1_000_000 22 | eval_interval: 23 | value: 2500 24 | eval_workers: 25 | value: 8 26 | eval_final_episodes: 27 | value: 1000 28 | 29 | # --- Logging --- 30 | log: 31 | value: true 32 | wandb_project: 33 | value: unifloral 34 | wandb_team: 35 | value: flair 36 | wandb_group: 37 | value: debug 38 | 39 | # --- Generic optimization --- 40 | lr: 41 | value: 3e-4 42 | actor_lr: 43 | value: 3e-4 44 | lr_schedule: 45 | value: constant 46 | batch_size: 47 | value: 256 48 | gamma: 49 | value: 0.99 50 | polyak_step_size: 51 | value: 0.005 52 | norm_obs: 53 | value: true 54 | 55 | # --- Actor architecture --- 56 | actor_num_layers: 57 | value: 2 58 | actor_layer_width: 59 | value: 256 60 | actor_ln: 61 | value: false 62 | deterministic: 63 | value: true 64 | deterministic_eval: 65 | value: false 66 | use_tanh_mean: 67 | value: true 68 | use_log_std_param: 69 | value: false 70 | log_std_min: 71 | value: -5.0 72 | log_std_max: 73 | value: 2.0 74 | 75 | # --- Critic + value function architecture --- 76 | num_critics: 77 | value: 2 78 | critic_num_layers: 79 | value: 2 80 | critic_layer_width: 81 | value: 256 82 | critic_ln: 83 | value: false 84 | 85 | # --- Actor loss components --- 86 | actor_bc_coef: 87 | value: 1.0 88 | actor_q_coef: 89 | values: [1.0, 2.0, 2.5, 3.0, 4.0] 90 | use_q_target_in_actor: 91 | value: false 92 | normalize_q_loss: 93 | value: true 94 | aggregate_q: 95 | value: first 96 | 97 | # --- AWR (Advantage Weighted Regression) actor --- 98 | use_awr: 99 | value: false 100 | awr_temperature: 101 | value: 1.0 102 | awr_exp_adv_clip: 103 | value: 100.0 104 | 105 | # --- Critic loss components --- 106 | num_critic_updates_per_step: 107 | value: 2 108 | critic_bc_coef: 109 | value: 0.0 110 | diversity_coef: 111 | value: 0.0 112 | policy_noise: 113 | value: 0.2 114 | noise_clip: 115 | value: 0.5 116 | use_target_actor: 117 | value: true 118 | 119 | # --- Value function --- 120 | use_value_target: 121 | value: false 122 | value_expectile: 123 | value: 0.8 124 | 125 | # --- Entropy loss --- 126 | use_entropy_loss: 127 | 
value: false 128 | ent_coef_init: 129 | value: 1.0 130 | actor_entropy_coef: 131 | value: 0.0 132 | critic_entropy_coef: 133 | value: 0.0 134 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation utilities for the Unifloral project. 2 | 3 | This module provides tools for: 4 | 1. Loading and parsing experiment results 5 | 2. Running bandit-based policy selection 6 | 3. Computing confidence intervals via bootstrapping 7 | """ 8 | 9 | from collections import namedtuple 10 | from datetime import datetime 11 | import os 12 | import re 13 | from typing import Dict, Tuple 14 | import warnings 15 | 16 | from functools import partial 17 | import glob 18 | import jax 19 | from jax import numpy as jnp 20 | import numpy as np 21 | import pandas as pd 22 | 23 | 24 | r""" 25 | |\ __ 26 | \| /_/ 27 | \| 28 | ___|_____ 29 | \ / 30 | \ / 31 | \___/ Data loading 32 | """ 33 | 34 | 35 | def parse_and_load_npz(filename: str) -> Dict: 36 | """Load data from a result file and parse metadata from filename. 37 | 38 | Args: 39 | filename: Path to the .npz result file 40 | 41 | Returns: 42 | Dictionary containing loaded arrays and metadata 43 | """ 44 | # Parse filename to extract algorithm, dataset, and timestamp 45 | pattern = r"(.+)_(.+)_(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})" 46 | match = re.match(pattern, os.path.basename(filename)) 47 | if not match: 48 | raise ValueError(f"Could not parse filename: {filename}") 49 | 50 | algorithm, dataset, dt_str = match.groups() 51 | dt = datetime.strptime(dt_str, "%Y-%m-%d_%H-%M-%S") 52 | 53 | data = np.load(filename, allow_pickle=True) 54 | data = {k: v for k, v in data.items()} 55 | data["algorithm"] = algorithm 56 | data["dataset"] = dataset 57 | data["datetime"] = dt 58 | data.update(data.pop("args", np.array({})).item()) # Flatten args 59 | return data 60 | 61 | 62 | def load_results_dataframe(results_dir: str = "final_returns") -> pd.DataFrame: 63 | """Load all result files from a directory into a pandas DataFrame. 64 | 65 | Args: 66 | results_dir: Directory containing .npz result files 67 | 68 | Returns: 69 | DataFrame containing results from all successfully loaded files 70 | """ 71 | npz_files = glob.glob(os.path.join(results_dir, "*.npz")) 72 | data_list = [] 73 | 74 | for f in npz_files: 75 | try: 76 | data = parse_and_load_npz(f) 77 | data_list.append(data) 78 | except Exception as e: 79 | print(f"Error loading {f}: {e}") 80 | continue 81 | 82 | df = pd.DataFrame(data_list).drop(columns=["Index"], errors="ignore") 83 | if "final_scores" in df.columns: 84 | df["final_scores"] = df["final_scores"].apply(lambda x: x.reshape(-1)) 85 | if "final_returns" in df.columns: 86 | df["final_returns"] = df["final_returns"].apply(lambda x: x.reshape(-1)) 87 | 88 | df = df.sort_values(by=["algorithm", "dataset", "datetime"]) 89 | return df.reset_index(drop=True) 90 | 91 | 92 | r""" 93 | __/) 94 | .-(__(=: 95 | |\ | \) 96 | \ || 97 | \|| 98 | \| 99 | ___|_____ 100 | \ / 101 | \ / 102 | \___/ Bandit Evaluation and Bootstrapping 103 | """ 104 | 105 | BanditState = namedtuple("BanditState", "rng counts rewards total_pulls") 106 | 107 | 108 | def ucb( 109 | means: jnp.ndarray, counts: jnp.ndarray, total_counts: int, alpha: float 110 | ) -> jnp.ndarray: 111 | """Compute UCB exploration bonus. 
112 | 113 | Args: 114 | means: Array of empirical means for each arm 115 | counts: Array of pull counts for each arm 116 | total_counts: Total number of pulls across all arms 117 | alpha: Exploration coefficient 118 | 119 | Returns: 120 | Array of UCB values for each arm 121 | """ 122 | exploration = jnp.sqrt(alpha * jnp.log(total_counts) / (counts + 1e-9)) 123 | return means + exploration 124 | 125 | 126 | def argmax_with_random_tiebreaking(rng: jnp.ndarray, values: jnp.ndarray) -> int: 127 | """Select maximum value with random tiebreaking. 128 | 129 | Args: 130 | rng: JAX PRNGKey 131 | values: Array of values to select from 132 | 133 | Returns: 134 | Index of selected maximum value 135 | """ 136 | mask = values == jnp.max(values) 137 | p = mask / (mask.sum() + 1e-9) 138 | return jax.random.choice(rng, jnp.arange(len(values)), p=p) 139 | 140 | 141 | @partial(jax.jit, static_argnums=(2,)) 142 | def run_bandit( 143 | returns_array: jnp.ndarray, 144 | rng: jnp.ndarray, 145 | max_pulls: int, 146 | alpha: float, 147 | policy_idx: jnp.ndarray, 148 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 149 | """Run a single bandit algorithm and report results after each pull. 150 | 151 | Args: 152 | returns_array: Array of returns for each policy and rollout 153 | rng: JAX PRNGKey 154 | max_pulls: Maximum number of pulls to execute 155 | alpha: UCB exploration coefficient 156 | policy_idx: Indices of policies to consider 157 | 158 | Returns: 159 | Tuple of (pulls, estimated_bests) 160 | """ 161 | returns_array = returns_array[policy_idx] 162 | num_policies, num_rollouts = returns_array.shape 163 | 164 | init_state = BanditState( 165 | rng=rng, 166 | counts=jnp.zeros(num_policies, dtype=jnp.int32), 167 | rewards=jnp.zeros(num_policies), 168 | total_pulls=1, 169 | ) 170 | 171 | def bandit_step(state: BanditState, _): 172 | """Run one bandit step and track performance.""" 173 | rng, rng_lever, rng_reward = jax.random.split(state.rng, 3) 174 | 175 | # Select arm using UCB 176 | means = state.rewards / jnp.maximum(state.counts, 1) 177 | ucb_values = ucb(means, state.counts, state.total_pulls, alpha) 178 | arm = argmax_with_random_tiebreaking(rng_lever, ucb_values) 179 | 180 | # Sample a reward for the chosen arm 181 | idx = jax.random.randint(rng_reward, shape=(), minval=0, maxval=num_rollouts) 182 | reward = returns_array[arm, idx] 183 | new_state = BanditState( 184 | rng=rng, 185 | counts=state.counts.at[arm].add(1), 186 | rewards=state.rewards.at[arm].add(reward), 187 | total_pulls=state.total_pulls + 1, 188 | ) 189 | 190 | # Calculate best arm based on current state 191 | updated_means = new_state.rewards / jnp.maximum(new_state.counts, 1) 192 | best_arm = jnp.argmax(updated_means) 193 | estimated_best = returns_array[best_arm].mean() 194 | 195 | return new_state, (state.total_pulls, estimated_best) 196 | 197 | _, (pulls, estimated_bests) = jax.lax.scan( 198 | bandit_step, init_state, length=max_pulls 199 | ) 200 | return pulls, estimated_bests 201 | 202 | 203 | def run_bandit_trials( 204 | returns_array: jnp.ndarray, 205 | seed: int = 17, 206 | num_subsample: int = 20, 207 | num_repeats: int = 1000, 208 | max_pulls: int = 200, 209 | ucb_alpha: float = 2.0, 210 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 211 | """Run multiple bandit trials and collect results at each step. 
212 | 213 | Args: 214 | returns_array: Array of returns for each policy and rollout 215 | seed: Random seed 216 | num_subsample: Number of policies to subsample on each trial 217 | num_repeats: Number of trials to run 218 | max_pulls: Maximum number of pulls per trial 219 | ucb_alpha: UCB exploration coefficient 220 | 221 | Returns: 222 | Tuple of (pulls, estimated_bests) 223 | """ 224 | rng = jax.random.PRNGKey(seed) 225 | num_policies = returns_array.shape[0] 226 | 227 | if num_subsample > num_policies: 228 | warnings.warn("Not enough policies to subsample, using all policies") 229 | num_subsample = min(num_subsample, num_policies) 230 | 231 | rng, rng_trials, rng_sample = jax.random.split(rng, 3) 232 | rng_trials = jax.random.split(rng_trials, num_repeats) 233 | 234 | def sample_policies(rng: jnp.ndarray) -> jnp.ndarray: 235 | """Sample a subset of policy indices.""" 236 | if num_subsample > num_policies: 237 | return jnp.arange(num_policies) 238 | return jax.random.choice( 239 | rng, jnp.arange(num_policies), shape=(num_subsample,), replace=False 240 | ) 241 | 242 | # Create a batch of policy index arrays for all trials 243 | rng_sample_keys = jax.random.split(rng_sample, num_repeats) 244 | policy_indices = jax.vmap(sample_policies)(rng_sample_keys) 245 | 246 | # Run bandit trials with policy subsampling 247 | # Pulls are the same for all trials, so we can just return the first one 248 | vmap_run_bandit = jax.vmap(run_bandit, in_axes=(None, 0, None, None, 0)) 249 | pulls, estimated_bests = vmap_run_bandit( 250 | returns_array, rng_trials, max_pulls, ucb_alpha, policy_indices 251 | ) 252 | return pulls[0], estimated_bests 253 | 254 | 255 | def bootstrap_confidence_interval( 256 | rng: jnp.ndarray, 257 | data: jnp.ndarray, 258 | n_bootstraps: int = 1000, 259 | confidence: float = 0.95, 260 | ) -> Tuple[float, float]: 261 | """Compute bootstrap confidence interval for mean of data. 262 | 263 | Args: 264 | rng: JAX PRNGKey 265 | data: Array of values to bootstrap 266 | n_bootstraps: Number of bootstrap samples 267 | confidence: Confidence level (between 0 and 1) 268 | 269 | Returns: 270 | Tuple of (lower_bound, upper_bound) 271 | """ 272 | 273 | @jax.vmap 274 | def bootstrap_mean(rng): 275 | samples = jax.random.choice(rng, data, shape=(data.shape[0],), replace=True) 276 | return samples.mean() 277 | 278 | bootstrap_means = bootstrap_mean(jax.random.split(rng, n_bootstraps)) 279 | lower_bound = jnp.percentile(bootstrap_means, 100 * (1 - confidence) / 2) 280 | upper_bound = jnp.percentile(bootstrap_means, 100 * (1 + confidence) / 2) 281 | return lower_bound, upper_bound 282 | 283 | 284 | def bootstrap_bandit_trials( 285 | returns_array: jnp.ndarray, 286 | seed: int = 17, 287 | num_subsample: int = 20, 288 | num_repeats: int = 1000, 289 | max_pulls: int = 200, 290 | ucb_alpha: float = 2.0, 291 | n_bootstraps: int = 1000, 292 | confidence: float = 0.95, 293 | ) -> Dict[str, np.ndarray]: 294 | """Run bandit trials and compute bootstrap confidence intervals.
295 | 296 | Args: 297 | returns_array: Array of returns for each policy and rollout has shape (num_policies, num_rollouts) 298 | seed: Random seed 299 | num_subsample: Number of policies to subsample 300 | num_repeats: Number of bandit trials to run 301 | max_pulls: Maximum number of pulls per trial 302 | ucb_alpha: UCB exploration coefficient 303 | n_bootstraps: Number of bootstrap samples 304 | confidence: Confidence level for intervals 305 | 306 | Returns: 307 | Dictionary with the following keys: 308 | - pulls: Number of pulls at each step 309 | - estimated_bests_mean: Mean of the currently estimated best returns across trials 310 | - estimated_bests_ci_low: Lower confidence bound for estimated best returns 311 | - estimated_bests_ci_high: Upper confidence bound for estimated best returns 312 | """ 313 | rng = jax.random.PRNGKey(seed) 314 | rng = jax.random.split(rng, max_pulls) 315 | 316 | pulls, estimated_bests = run_bandit_trials( 317 | returns_array, seed, num_subsample, num_repeats, max_pulls, ucb_alpha 318 | ) 319 | vmap_bootstrap = jax.vmap(bootstrap_confidence_interval, in_axes=(0, 1, None, None)) 320 | ci_low, ci_high = vmap_bootstrap(rng, estimated_bests, n_bootstraps, confidence) 321 | estimated_bests_mean = estimated_bests.mean(axis=0) 322 | 323 | return { 324 | "pulls": pulls, 325 | "estimated_bests_mean": estimated_bests_mean, 326 | "estimated_bests_ci_low": ci_low, 327 | "estimated_bests_ci_high": ci_high, 328 | } 329 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | black==24.4.2 3 | cython<3 4 | d4rl==1.1 5 | distrax==0.1.3 6 | dm-control==1.0.5 7 | flax==0.6.11 8 | gym==0.23.1 9 | jax[cuda12_pip]==0.4.16 10 | mujoco==2.2.1 11 | mujoco-py==2.1.2.14 12 | numpy==1.22.4 13 | optax==0.1.5 14 | orbax-checkpoint==0.4.4 15 | pre-commit==3.7.1 16 | scipy==1.12.0 17 | tensorflow==2.13.0 18 | tensorstore==0.1.51 19 | tyro==0.7.3 20 | wandb==0.17.3 21 | --------------------------------------------------------------------------------
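A minimal usage sketch for the utilities in evaluation.py, assuming it is run from the repository root after training runs have written .npz result files to final_returns/. The function names, DataFrame columns, and result keys below are taken from evaluation.py; the stacking step is an assumption that each DataFrame row holds the per-episode final_scores of one candidate policy, so that stacking rows yields the (num_policies, num_rollouts) array expected by bootstrap_bandit_trials.

import numpy as np

from evaluation import bootstrap_bandit_trials, load_results_dataframe

# Load every parsed result file (algorithm, dataset, timestamp, scores) into a DataFrame.
df = load_results_dataframe(results_dir="final_returns")

# Assumption: each row is one trained policy, so stacking its per-episode
# final_scores gives the (num_policies, num_rollouts) array used by the bandit.
subset = df[df["dataset"] == "halfcheetah-medium-v2"]
returns_array = np.stack(subset["final_scores"].to_list())

# UCB policy selection with bootstrapped confidence intervals
# (defaults: 1000 bandit trials, 1000 bootstrap samples).
results = bootstrap_bandit_trials(returns_array, num_subsample=20, max_pulls=200)
print(f"Estimated best score after {int(results['pulls'][-1])} pulls: "
      f"{float(results['estimated_bests_mean'][-1]):.1f}")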