├── .gitignore
├── .pylintrc
├── LICENSE
├── README.md
├── algorithms
│   ├── bc.py
│   ├── combo.py
│   ├── cql.py
│   ├── dynamics.py
│   ├── edac.py
│   ├── iql.py
│   ├── mopo.py
│   ├── morel.py
│   ├── rebrac.py
│   ├── sac_n.py
│   ├── td3_bc.py
│   ├── termination_fns.py
│   └── unifloral.py
├── configs
│   ├── algorithms
│   │   ├── bc.yaml
│   │   ├── combo.yaml
│   │   ├── cql.yaml
│   │   ├── edac.yaml
│   │   ├── iql.yaml
│   │   ├── mopo.yaml
│   │   ├── morel.yaml
│   │   ├── rebrac.yaml
│   │   ├── sac_n.yaml
│   │   └── td3_bc.yaml
│   ├── dynamics.yaml
│   └── unifloral
│       ├── bc.yaml
│       ├── edac.yaml
│       ├── iql.yaml
│       ├── mobrac.yaml
│       ├── rebrac.yaml
│       ├── sac_n.yaml
│       ├── td3_awr.yaml
│       └── td3_bc.yaml
├── evaluation.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | dynamics_models/
3 | final_returns/
4 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 |
3 | # Enable the message, report, category or checker with the given id(s). You can
4 | # either give multiple identifiers separated by commas (,) or use this option
5 | # multiple times.
6 | #enable=
7 |
8 | # Disable the message, report, category or checker with the given id(s). You
9 | # can either give multiple identifiers separated by commas (,) or use this option
10 | # multiple times (only on the command line, not in the configuration file, where
11 | # it should appear only once).
12 | disable=C0114,C0115,C0116,W0105,W0621,C3001
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 Matthew Jackson
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌹 Unifloral: Unified Offline Reinforcement Learning
2 |
3 |
4 |
5 |
6 |
7 |
8 | Unified implementations and rigorous evaluation for offline reinforcement learning - built by [Matthew Jackson](https://github.com/EmptyJackson), [Uljad Berdica](https://github.com/uljad), and [Jarek Liesen](https://github.com/keraJLi).
9 |
10 | ## 💡 Code Philosophy
11 |
12 | - ⚛️ **Single-file**: We implement algorithms as standalone Python files.
13 | - 🤏 **Minimal**: We only edit what is necessary between algorithms, making comparisons straightforward.
14 | - ⚡️ **GPU-accelerated**: We use JAX and end-to-end compile all training code, enabling lightning-fast training (see the sketch below).
15 |
16 | Inspired by [CORL](https://github.com/tinkoff-ai/CORL) and [CleanRL](https://github.com/vwxyzjn/cleanrl) - check them out!
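
For a flavour of what "end-to-end compiled" means in practice: each standalone script builds a pure `_train_step` function and hands it to `jax.lax.scan`, so the whole inner loop between evaluations runs as a single XLA program. Below is a minimal sketch of that pattern using a toy linear behaviour-cloning update (illustrative names and shapes only, not code from this repo; the real scripts use Flax `TrainState`s and the losses defined in each file):

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for the offline dataset and parameters (illustrative shapes only).
dataset = {"obs": jnp.zeros((1000, 17)), "action": jnp.zeros((1000, 6))}
params = {"w": jnp.zeros((17, 6))}


def make_train_step(dataset, batch_size=256, lr=3e-4):
    """Return a pure update function suitable for jax.lax.scan."""

    def _train_step(runner_state, _):
        rng, params = runner_state

        # --- Sample a minibatch from the static offline dataset ---
        rng, rng_batch = jax.random.split(rng)
        idxs = jax.random.randint(rng_batch, (batch_size,), 0, len(dataset["obs"]))
        batch = jax.tree_util.tree_map(lambda x: x[idxs], dataset)

        # --- One gradient step on a toy behaviour-cloning loss ---
        def loss_fn(p):
            return jnp.mean((batch["obs"] @ p["w"] - batch["action"]) ** 2)

        loss, grads = jax.value_and_grad(loss_fn)(params)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return (rng, params), loss

    return _train_step


# The entire loop between evaluations is traced once and executed on-device.
(rng, params), losses = jax.lax.scan(
    make_train_step(dataset), (jax.random.PRNGKey(0), params), None, 2500
)
```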
17 |
18 | ## 🤖 Algorithms
19 |
20 | We provide two types of algorithm implementation:
21 |
22 | 1. **Standalone**: Each algorithm is implemented as a [single file](algorithms) with minimal dependencies, making it easy to understand and modify.
23 | 2. **Unified**: Most algorithms are available as configs for our unified implementation [`unifloral.py`](algorithms/unifloral.py).
24 |
25 | After training, final evaluation results are saved to `.npz` files in [`final_returns/`](final_returns) for analysis using our evaluation protocol.
26 |
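Each run writes a single compressed archive named `{algorithm}_{dataset}_{timestamp}.npz` containing the raw returns and normalised scores (plus their means and standard deviations) and the run's arguments. A quick way to inspect one file directly (the filename below is a hypothetical example):

```python
import numpy as np

# allow_pickle is needed because the run's args are stored as a Python dict.
with np.load(
    "final_returns/bc_halfcheetah-medium-v2_2025-01-01_12-00-00.npz",
    allow_pickle=True,
) as data:
    print(sorted(data.files))  # final_returns, final_scores, *_mean, *_std, args
    print(data["final_scores_mean"], data["final_scores_std"])
    print(data["args"].item()["dataset"])
```
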
27 | All scripts support [D4RL](https://github.com/Farama-Foundation/D4RL) and use [Weights & Biases](https://wandb.ai) for logging, with configs provided as WandB sweep files.
28 |
29 | ### Model-free
30 |
31 | | Algorithm | Standalone | Unified | Extras |
32 | | --- | --- | --- | --- |
33 | | BC | [`bc.py`](algorithms/bc.py) | [`unifloral/bc.yaml`](configs/unifloral/bc.yaml) | - |
34 | | SAC-N | [`sac_n.py`](algorithms/sac_n.py) | [`unifloral/sac_n.yaml`](configs/unifloral/sac_n.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.01548) |
35 | | EDAC | [`edac.py`](algorithms/edac.py) | [`unifloral/edac.yaml`](configs/unifloral/edac.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.01548) |
36 | | CQL | [`cql.py`](algorithms/cql.py) | - | [[ArXiv]](https://arxiv.org/abs/2006.04779) |
37 | | IQL | [`iql.py`](algorithms/iql.py) | [`unifloral/iql.yaml`](configs/unifloral/iql.yaml) | [[ArXiv]](https://arxiv.org/abs/2110.06169) |
38 | | TD3-BC | [`td3_bc.py`](algorithms/td3_bc.py) | [`unifloral/td3_bc.yaml`](configs/unifloral/td3_bc.yaml) | [[ArXiv]](https://arxiv.org/abs/2106.06860) |
39 | | ReBRAC | [`rebrac.py`](algorithms/rebrac.py) | [`unifloral/rebrac.yaml`](configs/unifloral/rebrac.yaml) | [[ArXiv]](https://arxiv.org/abs/2305.09836) |
40 | | TD3-AWR | - | [`unifloral/td3_awr.yaml`](configs/unifloral/td3_awr.yaml) | [[ArXiv]](https://arxiv.org/abs/2504.11453) |
41 |
42 | ### Model-based
43 |
44 | We implement a single script for dynamics model training: [`dynamics.py`](algorithms/dynamics.py), with config [`dynamics.yaml`](configs/dynamics.yaml).
45 |
46 | | Algorithm | Standalone | Unified | Extras |
47 | | --- | --- | --- | --- |
48 | | MOPO | [`mopo.py`](algorithms/mopo.py) | - | [[ArXiv]](https://arxiv.org/abs/2005.13239) |
49 | | MOReL | [`morel.py`](algorithms/morel.py) | - | [[ArXiv]](https://arxiv.org/abs/2005.05951) |
50 | | COMBO | [`combo.py`](algorithms/combo.py) | - | [[ArXiv]](https://arxiv.org/abs/2102.08363) |
51 | | MoBRAC | - | [`unifloral/mobrac.yaml`](configs/unifloral/mobrac.yaml) | [[ArXiv]](https://arxiv.org/abs/2504.11453) |
52 |
53 | New ones coming soon 👀
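
Each model-based agent takes a dynamics checkpoint through its `--model_path` argument, loads it, attaches the offline dataset, and builds a rollout function that periodically refreshes a buffer of short synthetic rollouts under the current policy. A rough sketch of those calls, mirroring `combo.py` (the checkpoint path is a hypothetical example):

```python
from dynamics import (
    load_dynamics_model,
    EnsembleDynamics,  # required for loading dynamics model
    EnsembleDynamicsModel,  # required for loading dynamics model
)

# Hypothetical checkpoint path (dynamics_models/ is git-ignored in this repo).
dynamics_model = load_dynamics_model("dynamics_models/halfcheetah-medium-v2")
dynamics_model.dataset = dataset  # Transition tuple built from the D4RL dataset
rollout_fn = dynamics_model.make_rollout_fn(batch_size=50_000, rollout_length=5)

# Inside the train step: rollout_buffer = rollout_fn(rng, policy_fn, rollout_buffer)
```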
54 |
55 | ## 📊 Evaluation
56 |
57 | Our evaluation script ([`evaluation.py`](evaluation.py)) implements the protocol described in our paper: policy selection is framed as a UCB bandit problem, and we track the score of the estimated best policy as a function of the number of evaluation rollouts (pulls).
58 |
59 | ```python
60 | from evaluation import load_results_dataframe, bootstrap_bandit_trials
61 | import jax.numpy as jnp
62 |
63 | # Load all results from the final_returns directory
64 | df = load_results_dataframe("final_returns")
65 |
66 | # Run bandit trials with bootstrapped confidence intervals
67 | results = bootstrap_bandit_trials(
68 | returns_array=jnp.array(policy_returns), # Shape: (num_policies, num_rollouts)
69 | num_subsample=8, # Number of policies to subsample
70 | num_repeats=1000, # Number of bandit trials
71 | max_pulls=200, # Maximum pulls per trial
72 | ucb_alpha=2.0, # UCB exploration coefficient
73 | n_bootstraps=1000, # Bootstrap samples for confidence intervals
74 | confidence=0.95 # Confidence level
75 | )
76 |
77 | # Access results
78 | pulls = results["pulls"] # Number of pulls at each step
79 | means = results["estimated_bests_mean"] # Mean score of estimated best policy
80 | ci_low = results["estimated_bests_ci_low"] # Lower confidence bound
81 | ci_high = results["estimated_bests_ci_high"] # Upper confidence bound
82 | ```
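
In words: the protocol subsamples a set of trained policies, then a UCB bandit repeatedly chooses which policy to roll out next, and `estimated_bests_mean` tracks the score of the policy it would currently select. Plotting it (with the CI bounds) against `pulls` shows how quickly a good policy is identified as more evaluation rollouts are spent.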
83 |
84 | ## 📝 Cite us!
85 | ```bibtex
86 | @misc{jackson2025clean,
87 | title={A Clean Slate for Offline Reinforcement Learning},
88 | author={Matthew Thomas Jackson and Uljad Berdica and Jarek Liesen and Shimon Whiteson and Jakob Nicolaus Foerster},
89 | year={2025},
90 | eprint={2504.11453},
91 | archivePrefix={arXiv},
92 | primaryClass={cs.LG},
93 | url={https://arxiv.org/abs/2504.11453},
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/algorithms/bc.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | import os
5 | import warnings
6 |
7 | import distrax
8 | import d4rl
9 | import flax.linen as nn
10 | from flax.training.train_state import TrainState
11 | import gym
12 | import jax
13 | import jax.numpy as jnp
14 | import numpy as onp
15 | import optax
16 | import tyro
17 | import wandb
18 |
19 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
20 |
21 |
22 | @dataclass
23 | class Args:
24 | # --- Experiment ---
25 | seed: int = 0
26 | dataset: str = "halfcheetah-medium-v2"
27 | algorithm: str = "bc"
28 | num_updates: int = 1_000_000
29 | eval_interval: int = 2500
30 | eval_workers: int = 8
31 | eval_final_episodes: int = 1000
32 | # --- Logging ---
33 | log: bool = False
34 | wandb_project: str = "unifloral"
35 | wandb_team: str = "flair"
36 | wandb_group: str = "debug"
37 | # --- Generic optimization ---
38 | lr: float = 3e-4
39 | batch_size: int = 256
40 |
41 |
42 | r"""
43 | |\ __
44 | \| /_/
45 | \|
46 | ___|_____
47 | \ /
48 | \ /
49 | \___/ Preliminaries
50 | """
51 |
52 | AgentTrainState = namedtuple("AgentTrainState", "actor")
53 | Transition = namedtuple("Transition", "obs action reward next_obs done")
54 |
55 |
56 | class DeterministicTanhActor(nn.Module):
57 | num_actions: int
58 | obs_mean: jax.Array
59 | obs_std: jax.Array
60 |
61 | @nn.compact
62 | def __call__(self, x):
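# Normalise observations with dataset statistics; the small epsilon guards against zero std.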
63 | x = (x - self.obs_mean) / (self.obs_std + 1e-3)
64 | for _ in range(2):
65 | x = nn.Dense(256)(x)
66 | x = nn.relu(x)
67 | action = nn.Dense(self.num_actions)(x)
68 | pi = distrax.Transformed(
69 | distrax.Deterministic(action),
70 | distrax.Tanh(),
71 | )
72 | return pi
73 |
74 |
75 | def create_train_state(args, rng, network, dummy_input):
76 | return TrainState.create(
77 | apply_fn=network.apply,
78 | params=network.init(rng, *dummy_input),
79 | tx=optax.adam(args.lr, eps=1e-5),
80 | )
81 |
82 |
83 | def eval_agent(args, rng, env, agent_state):
84 | # --- Reset environment ---
85 | step = 0
86 | returned = onp.zeros(args.eval_workers).astype(bool)
87 | cum_reward = onp.zeros(args.eval_workers)
88 | rng, rng_reset = jax.random.split(rng)
89 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
90 | obs = env.reset()
91 |
92 | # --- Rollout agent ---
93 | @jax.jit
94 | @jax.vmap
95 | def _policy_step(rng, obs):
96 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
97 | action = pi.sample(seed=rng)
98 | return jnp.nan_to_num(action)
99 |
100 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
101 | while step < max_episode_steps and not returned.all():
102 | # --- Take step in environment ---
103 | step += 1
104 | rng, rng_step = jax.random.split(rng)
105 | rng_step = jax.random.split(rng_step, args.eval_workers)
106 | action = _policy_step(rng_step, jnp.array(obs))
107 | obs, reward, done, info = env.step(onp.array(action))
108 |
109 | # --- Track cumulative reward ---
110 | cum_reward += reward * ~returned
111 | returned |= done
112 |
113 | if step >= max_episode_steps and not returned.all():
114 | warnings.warn("Maximum steps reached before all episodes terminated")
115 | return cum_reward
116 |
117 |
118 | r"""
119 | __/)
120 | .-(__(=:
121 | |\ | \)
122 | \ ||
123 | \||
124 | \|
125 | ___|_____
126 | \ /
127 | \ /
128 | \___/ Agent
129 | """
130 |
131 |
132 | def make_train_step(args, actor_apply_fn, dataset):
133 | """Make JIT-compatible agent train step."""
134 |
135 | def _train_step(runner_state, _):
136 | rng, agent_state = runner_state
137 |
138 | # --- Sample batch ---
139 | rng, rng_batch = jax.random.split(rng)
140 | batch_indices = jax.random.randint(
141 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
142 | )
143 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
144 |
145 | # --- Update actor ---
146 | def _actor_loss_function(params):
147 | def _transition_loss(transition):
148 | pi = actor_apply_fn(params, transition.obs)
149 | pi_action = pi.sample(seed=None)
150 | bc_loss = jnp.square(pi_action - transition.action).mean()
151 | return bc_loss
152 |
153 | bc_loss = jax.vmap(_transition_loss)(batch)
154 | return bc_loss.mean()
155 |
156 | loss_fn = jax.value_and_grad(_actor_loss_function)
157 | actor_loss, actor_grad = loss_fn(agent_state.actor.params)
158 | agent_state = agent_state._replace(
159 | actor=agent_state.actor.apply_gradients(grads=actor_grad)
160 | )
161 |
162 | loss = {
163 | "actor_loss": actor_loss,
164 | "bc_loss": actor_loss,
165 | }
166 | return (rng, agent_state), loss
167 |
168 | return _train_step
169 |
170 |
171 | if __name__ == "__main__":
172 | # --- Parse arguments ---
173 | args = tyro.cli(Args)
174 | rng = jax.random.PRNGKey(args.seed)
175 |
176 | # --- Initialize logger ---
177 | if args.log:
178 | wandb.init(
179 | config=args,
180 | project=args.wandb_project,
181 | entity=args.wandb_team,
182 | group=args.wandb_group,
183 | job_type="train_agent",
184 | )
185 |
186 | # --- Initialize environment and dataset ---
187 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
188 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
189 | dataset = Transition(
190 | obs=jnp.array(dataset["observations"]),
191 | action=jnp.array(dataset["actions"]),
192 | reward=jnp.array(dataset["rewards"]),
193 | next_obs=jnp.array(dataset["next_observations"]),
194 | done=jnp.array(dataset["terminals"]),
195 | )
196 |
197 | # --- Initialize agent and value networks ---
198 | num_actions = env.single_action_space.shape[0]
199 | obs_mean = dataset.obs.mean(axis=0)
200 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0)
201 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
202 | dummy_action = jnp.zeros(num_actions)
203 | actor_net = DeterministicTanhActor(num_actions, obs_mean, obs_std)
204 |
205 | rng, rng_actor = jax.random.split(rng)
206 | agent_state = AgentTrainState(
207 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs])
208 | )
209 |
210 | # --- Make train step ---
211 | _agent_train_step_fn = make_train_step(args, actor_net.apply, dataset)
212 |
213 | num_evals = args.num_updates // args.eval_interval
214 | for eval_idx in range(num_evals):
215 | # --- Execute train loop ---
216 | (rng, agent_state), loss = jax.lax.scan(
217 | _agent_train_step_fn,
218 | (rng, agent_state),
219 | None,
220 | args.eval_interval,
221 | )
222 |
223 | # --- Evaluate agent ---
224 | rng, rng_eval = jax.random.split(rng)
225 | returns = eval_agent(args, rng_eval, env, agent_state)
226 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
227 |
228 | # --- Log metrics ---
229 | step = (eval_idx + 1) * args.eval_interval
230 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
231 | if args.log:
232 | log_dict = {
233 | "return": returns.mean(),
234 | "score": scores.mean(),
235 | "score_std": scores.std(),
236 | "num_updates": step,
237 | **{k: loss[k][-1] for k in loss},
238 | }
239 | wandb.log(log_dict)
240 |
241 | # --- Evaluate final agent ---
242 | if args.eval_final_episodes > 0:
243 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
244 | print(f"Evaluating final agent for {final_iters} iterations...")
245 | _rng = jax.random.split(rng, final_iters)
246 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
247 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
248 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
249 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
250 |
251 | # --- Write final returns to file ---
252 | os.makedirs("final_returns", exist_ok=True)
253 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
254 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
255 | with open(os.path.join("final_returns", filename), "wb") as f:
256 | onp.savez_compressed(f, **info, args=asdict(args))
257 |
258 | if args.log:
259 | wandb.save(os.path.join("final_returns", filename))
260 |
261 | if args.log:
262 | wandb.finish()
263 |
--------------------------------------------------------------------------------
/algorithms/combo.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | from dynamics import (
22 | Transition,
23 | load_dynamics_model,
24 | EnsembleDynamics, # required for loading dynamics model
25 | EnsembleDynamicsModel, # required for loading dynamics model
26 | )
27 |
28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
29 |
30 |
31 | @dataclass
32 | class Args:
33 | # --- Experiment ---
34 | seed: int = 0
35 | dataset: str = "halfcheetah-medium-v2"
36 | algorithm: str = "combo"
37 | num_updates: int = 3_000_000
38 | eval_interval: int = 2500
39 | eval_workers: int = 8
40 | eval_final_episodes: int = 1000
41 | # --- Logging ---
42 | log: bool = False
43 | wandb_project: str = "unifloral"
44 | wandb_team: str = "flair"
45 | wandb_group: str = "debug"
46 | # --- Generic optimization ---
47 | lr: float = 3e-4
48 | batch_size: int = 256
49 | gamma: float = 0.99
50 | polyak_step_size: float = 0.005
51 | # --- SAC-N ---
52 | num_critics: int = 10
53 | # --- World model ---
54 | model_path: str = ""
55 | rollout_interval: int = 1000
56 | rollout_length: int = 5
57 | rollout_batch_size: int = 50000
58 | model_retain_epochs: int = 5
59 | dataset_sample_ratio: float = 0.5
60 | # --- CQL ---
61 | actor_lr: float = 1e-4
62 | cql_temperature: float = 1.0
63 | cql_min_q_weight: float = 0.5
64 |
65 |
66 | r"""
67 | |\ __
68 | \| /_/
69 | \|
70 | ___|_____
71 | \ /
72 | \ /
73 | \___/ Preliminaries
74 | """
75 |
76 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
77 |
78 |
79 | def sym(scale):
80 | def _init(*args, **kwargs):
81 | return uniform(2 * scale)(*args, **kwargs) - scale
82 |
83 | return _init
84 |
85 |
86 | class SoftQNetwork(nn.Module):
87 | @nn.compact
88 | def __call__(self, obs, action):
89 | x = jnp.concatenate([obs, action], axis=-1)
90 | for _ in range(3):
91 | x = nn.Dense(256, bias_init=constant(0.1))(x)
92 | x = nn.relu(x)
93 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
94 | return q.squeeze(-1)
95 |
96 |
97 | class VectorQ(nn.Module):
98 | num_critics: int
99 |
100 | @nn.compact
101 | def __call__(self, obs, action):
102 | vmap_critic = nn.vmap(
103 | SoftQNetwork,
104 | variable_axes={"params": 0}, # Parameters not shared between critics
105 | split_rngs={"params": True, "dropout": True}, # Different initializations
106 | in_axes=None,
107 | out_axes=-1,
108 | axis_size=self.num_critics,
109 | )
110 | q_values = vmap_critic()(obs, action)
111 | return q_values
112 |
113 |
114 | class TanhGaussianActor(nn.Module):
115 | num_actions: int
116 | log_std_max: float = 2.0
117 | log_std_min: float = -5.0
118 |
119 | @nn.compact
120 | def __call__(self, x):
121 | for _ in range(3):
122 | x = nn.Dense(256, bias_init=constant(0.1))(x)
123 | x = nn.relu(x)
124 | log_std = nn.Dense(
125 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
126 | )(x)
127 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
128 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
129 | pi = distrax.Transformed(
130 | distrax.Normal(mean, std),
131 | distrax.Tanh(),
132 | )
133 | return pi
134 |
135 |
136 | class EntropyCoef(nn.Module):
137 | ent_coef_init: float = 1.0
138 |
139 | @nn.compact
140 | def __call__(self) -> jnp.ndarray:
141 | log_ent_coef = self.param(
142 | "log_ent_coef",
143 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
144 | )
145 | return log_ent_coef
146 |
147 |
148 | def create_train_state(args, rng, network, dummy_input, lr=None):
149 | return TrainState.create(
150 | apply_fn=network.apply,
151 | params=network.init(rng, *dummy_input),
152 | tx=optax.adam(lr if lr is not None else args.lr, eps=1e-5),
153 | )
154 |
155 |
156 | def eval_agent(args, rng, env, agent_state):
157 | # --- Reset environment ---
158 | step = 0
159 | returned = onp.zeros(args.eval_workers).astype(bool)
160 | cum_reward = onp.zeros(args.eval_workers)
161 | rng, rng_reset = jax.random.split(rng)
162 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
163 | obs = env.reset()
164 |
165 | # --- Rollout agent ---
166 | @jax.jit
167 | @jax.vmap
168 | def _policy_step(rng, obs):
169 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
170 | action = pi.sample(seed=rng)
171 | return jnp.nan_to_num(action)
172 |
173 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
174 | while step < max_episode_steps and not returned.all():
175 | # --- Take step in environment ---
176 | step += 1
177 | rng, rng_step = jax.random.split(rng)
178 | rng_step = jax.random.split(rng_step, args.eval_workers)
179 | action = _policy_step(rng_step, jnp.array(obs))
180 | obs, reward, done, info = env.step(onp.array(action))
181 |
182 | # --- Track cumulative reward ---
183 | cum_reward += reward * ~returned
184 | returned |= done
185 |
186 | if step >= max_episode_steps and not returned.all():
187 | warnings.warn("Maximum steps reached before all episodes terminated")
188 | return cum_reward
189 |
190 |
191 | def sample_from_buffer(buffer, batch_size, rng):
192 | """Sample a batch from the buffer."""
193 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs))
194 | return jax.tree_map(lambda x: x[idxs], buffer)
195 |
196 |
197 | r"""
198 | __/)
199 | .-(__(=:
200 | |\ | \)
201 | \ ||
202 | \||
203 | \|
204 | ___|_____
205 | \ /
206 | \ /
207 | \___/ Agent
208 | """
209 |
210 |
211 | def make_train_step(
212 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn
213 | ):
214 | """Make JIT-compatible agent train step with model-based rollouts."""
215 |
216 | def _train_step(runner_state, _):
217 | rng, agent_state, rollout_buffer = runner_state
218 |
219 | # --- Update model buffer ---
220 | params = agent_state.actor.params
221 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng)
222 | rng, rng_buffer = jax.random.split(rng)
223 | rollout_buffer = jax.lax.cond(
224 | agent_state.actor.step % args.rollout_interval == 0,
225 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer),
226 | lambda: rollout_buffer,
227 | )
228 |
229 | # --- Sample batch ---
230 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3)
231 | dataset_size = int(args.batch_size * args.dataset_sample_ratio)
232 | rollout_size = args.batch_size - dataset_size
233 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset)
234 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout)
235 | batch = jax.tree_map(
236 | lambda x, y: jnp.concatenate([x, y]), dataset_batch, rollout_batch
237 | )
238 |
239 | # --- Update alpha ---
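# SAC-style automatic entropy tuning: minimising log_alpha * (entropy - target_entropy)
# raises alpha when policy entropy falls below target_entropy (= -action_dim) and
# lowers it otherwise.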
240 | @jax.value_and_grad
241 | def _alpha_loss_fn(params, rng):
242 | def _compute_entropy(rng, transition):
243 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
244 | _, log_pi = pi.sample_and_log_prob(seed=rng)
245 | return -log_pi.sum()
246 |
247 | log_alpha = alpha_apply_fn(params)
248 | rng = jax.random.split(rng, args.batch_size)
249 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
250 | target_entropy = -batch.action.shape[-1]
251 | return log_alpha * (entropy - target_entropy)
252 |
253 | rng, rng_alpha = jax.random.split(rng)
254 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
255 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
256 | agent_state = agent_state._replace(alpha=updated_alpha)
257 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
258 |
259 | # --- Update actor ---
260 | @partial(jax.value_and_grad, has_aux=True)
261 | def _actor_loss_function(params, rng):
262 | def _compute_loss(rng, transition):
263 | pi = actor_apply_fn(params, transition.obs)
264 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
265 | log_pi = log_pi.sum()
266 | q_values = q_apply_fn(
267 | agent_state.vec_q.params, transition.obs, sampled_action
268 | )
269 | q_min = jnp.min(q_values)
270 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
271 |
272 | rng = jax.random.split(rng, args.batch_size)
273 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
274 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
275 |
276 | rng, rng_actor = jax.random.split(rng)
277 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
278 | agent_state.actor.params, rng_actor
279 | )
280 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
281 | agent_state = agent_state._replace(actor=updated_actor)
282 |
283 | # --- Update Q target network ---
284 | updated_q_target_params = optax.incremental_update(
285 | agent_state.vec_q.params,
286 | agent_state.vec_q_target.params,
287 | args.polyak_step_size,
288 | )
289 | updated_q_target = agent_state.vec_q_target.replace(
290 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
291 | )
292 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
293 |
294 | # --- Compute targets ---
295 | def _sample_next_v(rng, transition):
296 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
297 | # Note: Important to use sample_and_log_prob here for numerical stability
298 | # See https://github.com/deepmind/distrax/issues/7
299 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
300 | # Minimum of the target Q-values
301 | next_q = q_apply_fn(
302 | agent_state.vec_q_target.params, transition.next_obs, next_action
303 | )
304 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
305 |
306 | rng, rng_next_v = jax.random.split(rng)
307 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
308 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
309 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
310 |
311 | # --- Sample actions for CQL ---
312 | def _sample_actions(rng, obs):
313 | pi = actor_apply_fn(agent_state.actor.params, obs)
314 | return pi.sample(seed=rng)
315 |
316 | rng, rng_pi, rng_next = jax.random.split(rng, 3)
317 | pi_actions = _sample_actions(rng_pi, batch.obs)
318 | pi_next_actions = _sample_actions(rng_next, batch.next_obs)
319 | rng, rng_random = jax.random.split(rng)
320 | cql_random_actions = jax.random.uniform(
321 | rng_random, shape=batch.action.shape, minval=-1.0, maxval=1.0
322 | )
323 |
324 | # --- Update critics ---
325 | @jax.value_and_grad
326 | def _q_loss_fn(params):
327 | q_pred = q_apply_fn(params, batch.obs, batch.action)
328 | critic_loss = jnp.square((q_pred - jnp.expand_dims(target, -1)))
329 | critic_loss = critic_loss.sum(-1).mean()
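# COMBO anchors its conservative penalty on Q-values of *real* dataset transitions
# (dataset_batch), so Q-values on model-generated states in the mixed batch are
# pushed down relative to Q-values on real data.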
330 | q_pred_combo = q_apply_fn(params, dataset_batch.obs, dataset_batch.action)
331 | q_pred_combo = q_pred_combo.mean()
332 |
333 | rand_q = q_apply_fn(params, batch.obs, cql_random_actions)
334 | pi_q = q_apply_fn(params, batch.obs, pi_actions)
335 | # Note: Source implementation erroneously uses current obs in next_pi_q
336 | next_pi_q = q_apply_fn(params, batch.next_obs, pi_next_actions)
337 | all_qs = jnp.concatenate([rand_q, pi_q, next_pi_q, q_pred], axis=1)
338 | q_ood = jax.scipy.special.logsumexp(all_qs / args.cql_temperature, axis=1)
339 | q_ood = jax.lax.stop_gradient(q_ood * args.cql_temperature)
340 | q_diff = (jnp.expand_dims(q_ood, 1) - q_pred_combo).mean()
341 | min_q_loss = q_diff * args.cql_min_q_weight
342 |
343 | critic_loss += min_q_loss.mean()
344 | return critic_loss
345 |
346 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params)
347 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
348 | agent_state = agent_state._replace(vec_q=updated_q)
349 |
350 | num_done = jnp.sum(batch.done)
351 | loss = {
352 | "critic_loss": critic_loss,
353 | "actor_loss": actor_loss,
354 | "alpha_loss": alpha_loss,
355 | "entropy": entropy,
356 | "alpha": alpha,
357 | "q_min": q_min,
358 | "q_std": q_std,
359 | "terminations/num_done": num_done,
360 | "terminations/done_ratio": num_done / batch.done.shape[0],
361 | }
362 | return (rng, agent_state, rollout_buffer), loss
363 |
364 | return _train_step
365 |
366 |
367 | if __name__ == "__main__":
368 | # --- Parse arguments ---
369 | args = tyro.cli(Args)
370 | rng = jax.random.PRNGKey(args.seed)
371 |
372 | # --- Initialize logger ---
373 | if args.log:
374 | wandb.init(
375 | config=args,
376 | project=args.wandb_project,
377 | entity=args.wandb_team,
378 | group=args.wandb_group,
379 | job_type="train_agent",
380 | )
381 |
382 | # --- Initialize environment and dataset ---
383 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
384 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
385 | dataset = Transition(
386 | obs=jnp.array(dataset["observations"]),
387 | action=jnp.array(dataset["actions"]),
388 | reward=jnp.array(dataset["rewards"]),
389 | next_obs=jnp.array(dataset["next_observations"]),
390 | done=jnp.array(dataset["terminals"]),
391 | next_action=jnp.roll(dataset["actions"], -1, axis=0),
392 | )
393 |
394 | # --- Initialize agent and value networks ---
395 | num_actions = env.single_action_space.shape[0]
396 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
397 | dummy_action = jnp.zeros(num_actions)
398 | actor_net = TanhGaussianActor(num_actions)
399 | q_net = VectorQ(args.num_critics)
400 | alpha_net = EntropyCoef()
401 |
402 | # Target networks share seeds to match initialization
403 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
404 | actor_lr = args.actor_lr if args.actor_lr is not None else args.lr
405 | agent_state = AgentTrainState(
406 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs], actor_lr),
407 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
408 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
409 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
410 | )
411 |
412 | # --- Initialize buffer and rollout function ---
413 | assert args.model_path, "Model path must be provided for model-based methods"
414 | dynamics_model = load_dynamics_model(args.model_path)
415 | dynamics_model.dataset = dataset
416 | max_buffer_size = args.rollout_batch_size * args.rollout_length
417 | max_buffer_size *= args.model_retain_epochs
418 | rollout_buffer = jax.tree_map(
419 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])),
420 | dataset,
421 | )
422 | rollout_fn = dynamics_model.make_rollout_fn(
423 | batch_size=args.rollout_batch_size, rollout_length=args.rollout_length
424 | )
425 |
426 | # --- Make train step ---
427 | _agent_train_step_fn = make_train_step(
428 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn
429 | )
430 |
431 | num_evals = args.num_updates // args.eval_interval
432 | for eval_idx in range(num_evals):
433 | # --- Execute train loop ---
434 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan(
435 | _agent_train_step_fn,
436 | (rng, agent_state, rollout_buffer),
437 | None,
438 | args.eval_interval,
439 | )
440 |
441 | # --- Evaluate agent ---
442 | rng, rng_eval = jax.random.split(rng)
443 | returns = eval_agent(args, rng_eval, env, agent_state)
444 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
445 |
446 | # --- Log metrics ---
447 | step = (eval_idx + 1) * args.eval_interval
448 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
449 | if args.log:
450 | log_dict = {
451 | "return": returns.mean(),
452 | "score": scores.mean(),
453 | "score_std": scores.std(),
454 | "num_updates": step,
455 | **{k: loss[k][-1] for k in loss},
456 | }
457 | wandb.log(log_dict)
458 |
459 | # --- Evaluate final agent ---
460 | if args.eval_final_episodes > 0:
461 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
462 | print(f"Evaluating final agent for {final_iters} iterations...")
463 | _rng = jax.random.split(rng, final_iters)
464 | rets = onp.concat([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
465 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
466 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
467 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
468 |
469 | # --- Write final returns to file ---
470 | os.makedirs("final_returns", exist_ok=True)
471 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
472 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
473 | with open(os.path.join("final_returns", filename), "wb") as f:
474 | onp.savez_compressed(f, **info, args=asdict(args))
475 |
476 | if args.log:
477 | wandb.save(os.path.join("final_returns", filename))
478 |
479 | if args.log:
480 | wandb.finish()
481 |
--------------------------------------------------------------------------------
/algorithms/cql.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
22 |
23 |
24 | @dataclass
25 | class Args:
26 | # --- Experiment ---
27 | seed: int = 0
28 | dataset: str = "halfcheetah-medium-v2"
29 | algorithm: str = "cql"
30 | num_updates: int = 1_000_000
31 | eval_interval: int = 2500
32 | eval_workers: int = 8
33 | eval_final_episodes: int = 1000
34 | # --- Logging ---
35 | log: bool = False
36 | wandb_project: str = "unifloral"
37 | wandb_team: str = "flair"
38 | wandb_group: str = "debug"
39 | # --- Generic optimization ---
40 | lr: float = 3e-4
41 | batch_size: int = 256
42 | gamma: float = 0.99
43 | polyak_step_size: float = 0.005
44 | # --- SAC-N ---
45 | num_critics: int = 10
46 | # --- CQL---
47 | actor_lr: float = 3e-5
48 | cql_temperature: float = 1.0
49 | cql_min_q_weight: float = 10.0
50 |
51 |
52 | r"""
53 | |\ __
54 | \| /_/
55 | \|
56 | ___|_____
57 | \ /
58 | \ /
59 | \___/ Preliminaries
60 | """
61 |
62 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
63 | Transition = namedtuple("Transition", "obs action reward next_obs done")
64 |
65 |
66 | def sym(scale):
67 | def _init(*args, **kwargs):
68 | return uniform(2 * scale)(*args, **kwargs) - scale
69 |
70 | return _init
71 |
72 |
73 | class SoftQNetwork(nn.Module):
74 | @nn.compact
75 | def __call__(self, obs, action):
76 | x = jnp.concatenate([obs, action], axis=-1)
77 | for _ in range(3):
78 | x = nn.Dense(256, bias_init=constant(0.1))(x)
79 | x = nn.relu(x)
80 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
81 | return q.squeeze(-1)
82 |
83 |
84 | class VectorQ(nn.Module):
85 | num_critics: int
86 |
87 | @nn.compact
88 | def __call__(self, obs, action):
89 | vmap_critic = nn.vmap(
90 | SoftQNetwork,
91 | variable_axes={"params": 0}, # Parameters not shared between critics
92 | split_rngs={"params": True, "dropout": True}, # Different initializations
93 | in_axes=None,
94 | out_axes=-1,
95 | axis_size=self.num_critics,
96 | )
97 | q_values = vmap_critic()(obs, action)
98 | return q_values
99 |
100 |
101 | class TanhGaussianActor(nn.Module):
102 | num_actions: int
103 | log_std_max: float = 2.0
104 | log_std_min: float = -5.0
105 |
106 | @nn.compact
107 | def __call__(self, x):
108 | for _ in range(3):
109 | x = nn.Dense(256, bias_init=constant(0.1))(x)
110 | x = nn.relu(x)
111 | log_std = nn.Dense(
112 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
113 | )(x)
114 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
115 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
116 | pi = distrax.Transformed(
117 | distrax.Normal(mean, std),
118 | distrax.Tanh(),
119 | )
120 | return pi
121 |
122 |
123 | class EntropyCoef(nn.Module):
124 | ent_coef_init: float = 1.0
125 |
126 | @nn.compact
127 | def __call__(self) -> jnp.ndarray:
128 | log_ent_coef = self.param(
129 | "log_ent_coef",
130 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
131 | )
132 | return log_ent_coef
133 |
134 |
135 | def create_train_state(args, rng, network, dummy_input, lr=None):
136 | return TrainState.create(
137 | apply_fn=network.apply,
138 | params=network.init(rng, *dummy_input),
139 | tx=optax.adam(lr if lr is not None else args.lr, eps=1e-5),
140 | )
141 |
142 |
143 | def eval_agent(args, rng, env, agent_state):
144 | # --- Reset environment ---
145 | step = 0
146 | returned = onp.zeros(args.eval_workers).astype(bool)
147 | cum_reward = onp.zeros(args.eval_workers)
148 | rng, rng_reset = jax.random.split(rng)
149 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
150 | obs = env.reset()
151 |
152 | # --- Rollout agent ---
153 | @jax.jit
154 | @jax.vmap
155 | def _policy_step(rng, obs):
156 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
157 | action = pi.sample(seed=rng)
158 | return jnp.nan_to_num(action)
159 |
160 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
161 | while step < max_episode_steps and not returned.all():
162 | # --- Take step in environment ---
163 | step += 1
164 | rng, rng_step = jax.random.split(rng)
165 | rng_step = jax.random.split(rng_step, args.eval_workers)
166 | action = _policy_step(rng_step, jnp.array(obs))
167 | obs, reward, done, info = env.step(onp.array(action))
168 |
169 | # --- Track cumulative reward ---
170 | cum_reward += reward * ~returned
171 | returned |= done
172 |
173 | if step >= max_episode_steps and not returned.all():
174 | warnings.warn("Maximum steps reached before all episodes terminated")
175 | return cum_reward
176 |
177 |
178 | r"""
179 | __/)
180 | .-(__(=:
181 | |\ | \)
182 | \ ||
183 | \||
184 | \|
185 | ___|_____
186 | \ /
187 | \ /
188 | \___/ Agent
189 | """
190 |
191 |
192 | def make_train_step(args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset):
193 | """Make JIT-compatible agent train step."""
194 |
195 | def _train_step(runner_state, _):
196 | rng, agent_state = runner_state
197 |
198 | # --- Sample batch ---
199 | rng, rng_batch = jax.random.split(rng)
200 | batch_indices = jax.random.randint(
201 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
202 | )
203 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
204 |
205 | # --- Update alpha ---
206 | @jax.value_and_grad
207 | def _alpha_loss_fn(params, rng):
208 | def _compute_entropy(rng, transition):
209 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
210 | _, log_pi = pi.sample_and_log_prob(seed=rng)
211 | return -log_pi.sum()
212 |
213 | log_alpha = alpha_apply_fn(params)
214 | rng = jax.random.split(rng, args.batch_size)
215 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
216 | target_entropy = -batch.action.shape[-1]
217 | return log_alpha * (entropy - target_entropy)
218 |
219 | rng, rng_alpha = jax.random.split(rng)
220 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
221 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
222 | agent_state = agent_state._replace(alpha=updated_alpha)
223 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
224 |
225 | # --- Update actor ---
226 | @partial(jax.value_and_grad, has_aux=True)
227 | def _actor_loss_function(params, rng):
228 | def _compute_loss(rng, transition):
229 | pi = actor_apply_fn(params, transition.obs)
230 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
231 | log_pi = log_pi.sum()
232 | q_values = q_apply_fn(
233 | agent_state.vec_q.params, transition.obs, sampled_action
234 | )
235 | q_min = jnp.min(q_values)
236 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
237 |
238 | rng = jax.random.split(rng, args.batch_size)
239 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
240 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
241 |
242 | rng, rng_actor = jax.random.split(rng)
243 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
244 | agent_state.actor.params, rng_actor
245 | )
246 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
247 | agent_state = agent_state._replace(actor=updated_actor)
248 |
249 | # --- Update Q target network ---
250 | updated_q_target_params = optax.incremental_update(
251 | agent_state.vec_q.params,
252 | agent_state.vec_q_target.params,
253 | args.polyak_step_size,
254 | )
255 | updated_q_target = agent_state.vec_q_target.replace(
256 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
257 | )
258 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
259 |
260 | # --- Compute targets ---
261 | def _sample_next_v(rng, transition):
262 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
263 | # Note: Important to use sample_and_log_prob here for numerical stability
264 | # See https://github.com/deepmind/distrax/issues/7
265 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
266 | # Minimum of the target Q-values
267 | next_q = q_apply_fn(
268 | agent_state.vec_q_target.params, transition.next_obs, next_action
269 | )
270 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
271 |
272 | rng, rng_next_v = jax.random.split(rng)
273 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
274 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
275 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
276 |
277 | # --- Sample actions for CQL ---
278 | def _sample_actions(rng, obs):
279 | pi = actor_apply_fn(agent_state.actor.params, obs)
280 | return pi.sample(seed=rng)
281 |
282 | rng, rng_pi, rng_next = jax.random.split(rng, 3)
283 | pi_actions = _sample_actions(rng_pi, batch.obs)
284 | pi_next_actions = _sample_actions(rng_next, batch.next_obs)
285 | rng, rng_random = jax.random.split(rng)
286 | cql_random_actions = jax.random.uniform(
287 | rng_random, shape=batch.action.shape, minval=-1.0, maxval=1.0
288 | )
289 |
290 | # --- Update critics ---
291 | @jax.value_and_grad
292 | def _q_loss_fn(params):
293 | q_pred = q_apply_fn(params, batch.obs, batch.action)
294 | critic_loss = jnp.square((q_pred - jnp.expand_dims(target, -1)))
295 | critic_loss = critic_loss.sum(-1).mean()
296 |
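# CQL penalty: the logsumexp below acts as a soft maximum over Q-values of random,
# policy, next-policy and dataset actions; subtracting Q on dataset actions (scaled
# by cql_min_q_weight) pushes down Q-values of out-of-distribution actions.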
297 | rand_q = q_apply_fn(params, batch.obs, cql_random_actions)
298 | pi_q = q_apply_fn(params, batch.obs, pi_actions)
299 | # Note: Source implementation erroneously uses current obs in next_pi_q
300 | next_pi_q = q_apply_fn(params, batch.next_obs, pi_next_actions)
301 | all_qs = jnp.concatenate([rand_q, pi_q, next_pi_q, q_pred], axis=1)
302 | q_ood = jax.scipy.special.logsumexp(all_qs / args.cql_temperature, axis=1)
303 | q_ood = jax.lax.stop_gradient(q_ood * args.cql_temperature)
304 | q_diff = (jnp.expand_dims(q_ood, 1) - q_pred).mean()
305 | min_q_loss = q_diff * args.cql_min_q_weight
306 |
307 | critic_loss += min_q_loss.mean()
308 | return critic_loss
309 |
310 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params)
311 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
312 | agent_state = agent_state._replace(vec_q=updated_q)
313 |
314 | loss = {
315 | "critic_loss": critic_loss,
316 | "actor_loss": actor_loss,
317 | "alpha_loss": alpha_loss,
318 | "entropy": entropy,
319 | "alpha": alpha,
320 | "q_min": q_min,
321 | "q_std": q_std,
322 | }
323 | return (rng, agent_state), loss
324 |
325 | return _train_step
326 |
327 |
328 | if __name__ == "__main__":
329 | # --- Parse arguments ---
330 | args = tyro.cli(Args)
331 | rng = jax.random.PRNGKey(args.seed)
332 |
333 | # --- Initialize logger ---
334 | if args.log:
335 | wandb.init(
336 | config=args,
337 | project=args.wandb_project,
338 | entity=args.wandb_team,
339 | group=args.wandb_group,
340 | job_type="train_agent",
341 | )
342 |
343 | # --- Initialize environment and dataset ---
344 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
345 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
346 | dataset = Transition(
347 | obs=jnp.array(dataset["observations"]),
348 | action=jnp.array(dataset["actions"]),
349 | reward=jnp.array(dataset["rewards"]),
350 | next_obs=jnp.array(dataset["next_observations"]),
351 | done=jnp.array(dataset["terminals"]),
352 | )
353 |
354 | # --- Initialize agent and value networks ---
355 | num_actions = env.single_action_space.shape[0]
356 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
357 | dummy_action = jnp.zeros(num_actions)
358 | actor_net = TanhGaussianActor(num_actions)
359 | q_net = VectorQ(args.num_critics)
360 | alpha_net = EntropyCoef()
361 |
362 | # Target networks share seeds to match initialization
363 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
364 | actor_lr = args.actor_lr if args.actor_lr is not None else args.lr
365 | agent_state = AgentTrainState(
366 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs], actor_lr),
367 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
368 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
369 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
370 | )
371 |
372 | # --- Make train step ---
373 | _agent_train_step_fn = make_train_step(
374 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset
375 | )
376 |
377 | num_evals = args.num_updates // args.eval_interval
378 | for eval_idx in range(num_evals):
379 | # --- Execute train loop ---
380 | (rng, agent_state), loss = jax.lax.scan(
381 | _agent_train_step_fn,
382 | (rng, agent_state),
383 | None,
384 | args.eval_interval,
385 | )
386 |
387 | # --- Evaluate agent ---
388 | rng, rng_eval = jax.random.split(rng)
389 | returns = eval_agent(args, rng_eval, env, agent_state)
390 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
391 |
392 | # --- Log metrics ---
393 | step = (eval_idx + 1) * args.eval_interval
394 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
395 | if args.log:
396 | log_dict = {
397 | "return": returns.mean(),
398 | "score": scores.mean(),
399 | "score_std": scores.std(),
400 | "num_updates": step,
401 | **{k: loss[k][-1] for k in loss},
402 | }
403 | wandb.log(log_dict)
404 |
405 | # --- Evaluate final agent ---
406 | if args.eval_final_episodes > 0:
407 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
408 | print(f"Evaluating final agent for {final_iters} iterations...")
409 | _rng = jax.random.split(rng, final_iters)
410 |         rets = onp.concatenate([eval_agent(args, r, env, agent_state) for r in _rng])
411 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
412 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
413 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
414 |
415 | # --- Write final returns to file ---
416 | os.makedirs("final_returns", exist_ok=True)
417 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
418 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
419 | with open(os.path.join("final_returns", filename), "wb") as f:
420 | onp.savez_compressed(f, **info, args=asdict(args))
421 |
422 | if args.log:
423 | wandb.save(os.path.join("final_returns", filename))
424 |
425 | if args.log:
426 | wandb.finish()
427 |
--------------------------------------------------------------------------------
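Note (cql.py): the critic update above augments the TD loss with CQL's conservative penalty: a logsumexp over the Q-values of random, policy, and next-policy actions is pushed down towards the Q-values of dataset actions, scaled by cql_min_q_weight. Below is a minimal standalone sketch of that penalty; the function name, shapes, and toy inputs are illustrative and not part of the repository.

import jax
import jax.numpy as jnp

def cql_penalty(q_data, q_rand, q_pi, q_next_pi, temperature=1.0, min_q_weight=5.0):
    # Each q_* has shape (batch, num_critics); stack along the sample axis
    all_qs = jnp.concatenate([q_rand, q_pi, q_next_pi, q_data], axis=1)
    q_ood = temperature * jax.scipy.special.logsumexp(all_qs / temperature, axis=1)
    # Push out-of-distribution action values down relative to dataset action values
    return min_q_weight * (q_ood[:, None] - q_data).mean()

keys = jax.random.split(jax.random.PRNGKey(0), 4)
q_data, q_rand, q_pi, q_next_pi = (jax.random.normal(k, (256, 10)) for k in keys)
print(float(cql_penalty(q_data, q_rand, q_pi, q_next_pi)))

Raising cql_temperature softens the maximum implied by the logsumexp, while cql_min_q_weight trades conservatism off against the TD loss.
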
/algorithms/edac.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
22 |
23 |
24 | @dataclass
25 | class Args:
26 | # --- Experiment ---
27 | seed: int = 0
28 | dataset: str = "halfcheetah-medium-v2"
29 | algorithm: str = "edac"
30 | num_updates: int = 3_000_000
31 | eval_interval: int = 2500
32 | eval_workers: int = 8
33 | eval_final_episodes: int = 1000
34 | # --- Logging ---
35 | log: bool = False
36 | wandb_project: str = "unifloral"
37 | wandb_team: str = "flair"
38 | wandb_group: str = "debug"
39 | # --- Generic optimization ---
40 | lr: float = 3e-4
41 | batch_size: int = 256
42 | gamma: float = 0.99
43 | polyak_step_size: float = 0.005
44 | # --- SAC-N ---
45 | num_critics: int = 10
46 | # --- EDAC ---
47 | eta: float = 1.0
48 |
49 |
50 | r"""
51 | |\ __
52 | \| /_/
53 | \|
54 | ___|_____
55 | \ /
56 | \ /
57 | \___/ Preliminaries
58 | """
59 |
60 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
61 | Transition = namedtuple("Transition", "obs action reward next_obs done")
62 |
63 |
64 | def sym(scale):
65 | def _init(*args, **kwargs):
66 | return uniform(2 * scale)(*args, **kwargs) - scale
67 |
68 | return _init
69 |
70 |
71 | class SoftQNetwork(nn.Module):
72 | @nn.compact
73 | def __call__(self, obs, action):
74 | x = jnp.concatenate([obs, action], axis=-1)
75 | for _ in range(3):
76 | x = nn.Dense(256, bias_init=constant(0.1))(x)
77 | x = nn.relu(x)
78 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
79 | return q.squeeze(-1)
80 |
81 |
82 | class VectorQ(nn.Module):
83 | num_critics: int
84 |
85 | @nn.compact
86 | def __call__(self, obs, action):
87 | vmap_critic = nn.vmap(
88 | SoftQNetwork,
89 | variable_axes={"params": 0}, # Parameters not shared between critics
90 | split_rngs={"params": True, "dropout": True}, # Different initializations
91 | in_axes=None,
92 | out_axes=-1,
93 | axis_size=self.num_critics,
94 | )
95 | q_values = vmap_critic()(obs, action)
96 | return q_values
97 |
98 |
99 | class TanhGaussianActor(nn.Module):
100 | num_actions: int
101 | log_std_max: float = 2.0
102 | log_std_min: float = -5.0
103 |
104 | @nn.compact
105 | def __call__(self, x):
106 | for _ in range(3):
107 | x = nn.Dense(256, bias_init=constant(0.1))(x)
108 | x = nn.relu(x)
109 | log_std = nn.Dense(
110 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
111 | )(x)
112 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
113 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
114 | pi = distrax.Transformed(
115 | distrax.Normal(mean, std),
116 | distrax.Tanh(),
117 | )
118 | return pi
119 |
120 |
121 | class EntropyCoef(nn.Module):
122 | ent_coef_init: float = 1.0
123 |
124 | @nn.compact
125 | def __call__(self) -> jnp.ndarray:
126 | log_ent_coef = self.param(
127 | "log_ent_coef",
128 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
129 | )
130 | return log_ent_coef
131 |
132 |
133 | def create_train_state(args, rng, network, dummy_input):
134 | return TrainState.create(
135 | apply_fn=network.apply,
136 | params=network.init(rng, *dummy_input),
137 | tx=optax.adam(args.lr, eps=1e-5),
138 | )
139 |
140 |
141 | def eval_agent(args, rng, env, agent_state):
142 | # --- Reset environment ---
143 | step = 0
144 | returned = onp.zeros(args.eval_workers).astype(bool)
145 | cum_reward = onp.zeros(args.eval_workers)
146 | rng, rng_reset = jax.random.split(rng)
147 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
148 | obs = env.reset()
149 |
150 | # --- Rollout agent ---
151 | @jax.jit
152 | @jax.vmap
153 | def _policy_step(rng, obs):
154 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
155 | action = pi.sample(seed=rng)
156 | return jnp.nan_to_num(action)
157 |
158 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
159 | while step < max_episode_steps and not returned.all():
160 | # --- Take step in environment ---
161 | step += 1
162 | rng, rng_step = jax.random.split(rng)
163 | rng_step = jax.random.split(rng_step, args.eval_workers)
164 | action = _policy_step(rng_step, jnp.array(obs))
165 | obs, reward, done, info = env.step(onp.array(action))
166 |
167 | # --- Track cumulative reward ---
168 | cum_reward += reward * ~returned
169 | returned |= done
170 |
171 | if step >= max_episode_steps and not returned.all():
172 | warnings.warn("Maximum steps reached before all episodes terminated")
173 | return cum_reward
174 |
175 |
176 | r"""
177 | __/)
178 | .-(__(=:
179 | |\ | \)
180 | \ ||
181 | \||
182 | \|
183 | ___|_____
184 | \ /
185 | \ /
186 | \___/ Agent
187 | """
188 |
189 |
190 | def make_train_step(args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset):
191 | """Make JIT-compatible agent train step."""
192 |
193 | def _train_step(runner_state, _):
194 | rng, agent_state = runner_state
195 |
196 | # --- Sample batch ---
197 | rng, rng_batch = jax.random.split(rng)
198 | batch_indices = jax.random.randint(
199 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
200 | )
201 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
202 |
203 | # --- Update alpha ---
204 | @jax.value_and_grad
205 | def _alpha_loss_fn(params, rng):
206 | def _compute_entropy(rng, transition):
207 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
208 | _, log_pi = pi.sample_and_log_prob(seed=rng)
209 | return -log_pi.sum()
210 |
211 | log_alpha = alpha_apply_fn(params)
212 | rng = jax.random.split(rng, args.batch_size)
213 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
214 | target_entropy = -batch.action.shape[-1]
215 | return log_alpha * (entropy - target_entropy)
216 |
217 | rng, rng_alpha = jax.random.split(rng)
218 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
219 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
220 | agent_state = agent_state._replace(alpha=updated_alpha)
221 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
222 |
223 | # --- Update actor ---
224 | @partial(jax.value_and_grad, has_aux=True)
225 | def _actor_loss_function(params, rng):
226 | def _compute_loss(rng, transition):
227 | pi = actor_apply_fn(params, transition.obs)
228 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
229 | log_pi = log_pi.sum()
230 | q_values = q_apply_fn(
231 | agent_state.vec_q.params, transition.obs, sampled_action
232 | )
233 | q_min = jnp.min(q_values)
234 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
235 |
236 | rng = jax.random.split(rng, args.batch_size)
237 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
238 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
239 |
240 | rng, rng_actor = jax.random.split(rng)
241 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
242 | agent_state.actor.params, rng_actor
243 | )
244 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
245 | agent_state = agent_state._replace(actor=updated_actor)
246 |
247 | # --- Update Q target network ---
248 | updated_q_target_params = optax.incremental_update(
249 | agent_state.vec_q.params,
250 | agent_state.vec_q_target.params,
251 | args.polyak_step_size,
252 | )
253 | updated_q_target = agent_state.vec_q_target.replace(
254 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
255 | )
256 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
257 |
258 | # --- Compute targets ---
259 | def _sample_next_v(rng, transition):
260 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
261 | # Note: Important to use sample_and_log_prob here for numerical stability
262 | # See https://github.com/deepmind/distrax/issues/7
263 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
264 | # Minimum of the target Q-values
265 | next_q = q_apply_fn(
266 | agent_state.vec_q_target.params, transition.next_obs, next_action
267 | )
268 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
269 |
270 | rng, rng_next_v = jax.random.split(rng)
271 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
272 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
273 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
274 |
275 | # --- Update critics ---
276 | @partial(jax.value_and_grad, has_aux=True)
277 | def _q_loss_fn(params):
278 | q_pred = q_apply_fn(params, batch.obs, batch.action)
279 | value_loss = jnp.square((q_pred - jnp.expand_dims(target, -1)))
280 | value_loss = value_loss.sum(-1).mean()
281 |
282 | def _diversity_loss_fn(obs, action):
283 | action_jac = jax.jacrev(q_apply_fn, argnums=2)(params, obs, action)
284 | action_jac /= jnp.linalg.norm(action_jac, axis=-1, keepdims=True) + 1e-6
285 | div_loss = action_jac @ action_jac.T
286 | div_loss *= 1.0 - jnp.eye(args.num_critics)
287 | return div_loss.sum()
288 |
289 | diversity_loss = jax.vmap(_diversity_loss_fn)(batch.obs, batch.action)
290 | diversity_loss = diversity_loss.mean() / (args.num_critics - 1)
291 | critic_loss = value_loss + args.eta * diversity_loss
292 | return critic_loss, (value_loss, diversity_loss)
293 |
294 | (critic_loss, (value_loss, diversity_loss)), critic_grad = _q_loss_fn(
295 | agent_state.vec_q.params
296 | )
297 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
298 | agent_state = agent_state._replace(vec_q=updated_q)
299 |
300 | loss = {
301 | "critic_loss": critic_loss,
302 | "value_loss": value_loss,
303 | "diversity_loss": diversity_loss,
304 | "actor_loss": actor_loss,
305 | "alpha_loss": alpha_loss,
306 | "entropy": entropy,
307 | "alpha": alpha,
308 | "q_min": q_min,
309 | "q_std": q_std,
310 | }
311 | return (rng, agent_state), loss
312 |
313 | return _train_step
314 |
315 |
316 | if __name__ == "__main__":
317 | # --- Parse arguments ---
318 | args = tyro.cli(Args)
319 | rng = jax.random.PRNGKey(args.seed)
320 |
321 | # --- Initialize logger ---
322 | if args.log:
323 | wandb.init(
324 | config=args,
325 | project=args.wandb_project,
326 | entity=args.wandb_team,
327 | group=args.wandb_group,
328 | job_type="train_agent",
329 | )
330 |
331 | # --- Initialize environment and dataset ---
332 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
333 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
334 | dataset = Transition(
335 | obs=jnp.array(dataset["observations"]),
336 | action=jnp.array(dataset["actions"]),
337 | reward=jnp.array(dataset["rewards"]),
338 | next_obs=jnp.array(dataset["next_observations"]),
339 | done=jnp.array(dataset["terminals"]),
340 | )
341 |
342 | # --- Initialize agent and value networks ---
343 | num_actions = env.single_action_space.shape[0]
344 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
345 | dummy_action = jnp.zeros(num_actions)
346 | actor_net = TanhGaussianActor(num_actions)
347 | q_net = VectorQ(args.num_critics)
348 | alpha_net = EntropyCoef()
349 |
350 | # Target networks share seeds to match initialization
351 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
352 | agent_state = AgentTrainState(
353 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
354 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
355 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
356 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
357 | )
358 |
359 | # --- Make train step ---
360 | _agent_train_step_fn = make_train_step(
361 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset
362 | )
363 |
364 | num_evals = args.num_updates // args.eval_interval
365 | for eval_idx in range(num_evals):
366 | # --- Execute train loop ---
367 | (rng, agent_state), loss = jax.lax.scan(
368 | _agent_train_step_fn,
369 | (rng, agent_state),
370 | None,
371 | args.eval_interval,
372 | )
373 |
374 | # --- Evaluate agent ---
375 | rng, rng_eval = jax.random.split(rng)
376 | returns = eval_agent(args, rng_eval, env, agent_state)
377 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
378 |
379 | # --- Log metrics ---
380 | step = (eval_idx + 1) * args.eval_interval
381 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
382 | if args.log:
383 | log_dict = {
384 | "return": returns.mean(),
385 | "score": scores.mean(),
386 | "score_std": scores.std(),
387 | "num_updates": step,
388 | **{k: loss[k][-1] for k in loss},
389 | }
390 | wandb.log(log_dict)
391 |
392 | # --- Evaluate final agent ---
393 | if args.eval_final_episodes > 0:
394 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
395 | print(f"Evaluating final agent for {final_iters} iterations...")
396 | _rng = jax.random.split(rng, final_iters)
397 |         rets = onp.concatenate([eval_agent(args, r, env, agent_state) for r in _rng])
398 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
399 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
400 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
401 |
402 | # --- Write final returns to file ---
403 | os.makedirs("final_returns", exist_ok=True)
404 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
405 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
406 | with open(os.path.join("final_returns", filename), "wb") as f:
407 | onp.savez_compressed(f, **info, args=asdict(args))
408 |
409 | if args.log:
410 | wandb.save(os.path.join("final_returns", filename))
411 |
412 | if args.log:
413 | wandb.finish()
414 |
--------------------------------------------------------------------------------
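Note (edac.py): the distinctive piece here is the diversity term in _q_loss_fn, which penalises alignment between the per-critic gradients of Q with respect to the action (the off-diagonal entries of their cosine-similarity Gram matrix). A self-contained sketch under assumed shapes, using a toy linear ensemble in place of VectorQ:

import jax
import jax.numpy as jnp

def diversity_loss(q_fn, params, obs, action, num_critics):
    # Jacobian of the (num_critics,) Q-vector w.r.t. the action: (num_critics, act_dim)
    jac = jax.jacrev(q_fn, argnums=2)(params, obs, action)
    jac = jac / (jnp.linalg.norm(jac, axis=-1, keepdims=True) + 1e-6)
    gram = jac @ jac.T                          # pairwise cosine similarities
    gram = gram * (1.0 - jnp.eye(num_critics))  # zero out self-similarity
    return gram.sum() / (num_critics - 1)

def toy_q(params, obs, action):
    # Stand-in for VectorQ: a linear critic ensemble returning (num_critics,) values
    return params @ jnp.concatenate([obs, action])

params = jax.random.normal(jax.random.PRNGKey(0), (10, 6 + 3))
print(diversity_loss(toy_q, params, jnp.ones(6), jnp.ones(3), num_critics=10))

With eta = 0 the critic loss reduces to the plain SAC-N value loss; larger eta pushes the critics' action-gradients apart, which is how EDAC maintains pessimism on out-of-distribution actions with a modest ensemble size.
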
/algorithms/iql.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.training.train_state import TrainState
12 | import gym
13 | import jax
14 | import jax.numpy as jnp
15 | import numpy as onp
16 | import optax
17 | import tyro
18 | import wandb
19 |
20 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
21 |
22 |
23 | @dataclass
24 | class Args:
25 | # --- Experiment ---
26 | seed: int = 0
27 | dataset: str = "halfcheetah-medium-v2"
28 | algorithm: str = "iql"
29 | num_updates: int = 1_000_000
30 | eval_interval: int = 2500
31 | eval_workers: int = 8
32 | eval_final_episodes: int = 1000
33 | # --- Logging ---
34 | log: bool = False
35 | wandb_project: str = "unifloral"
36 | wandb_team: str = "flair"
37 | wandb_group: str = "debug"
38 | # --- Generic optimization ---
39 | lr: float = 3e-4
40 | batch_size: int = 256
41 | gamma: float = 0.99
42 | polyak_step_size: float = 0.005
43 | # --- IQL ---
44 | beta: float = 3.0
45 | iql_tau: float = 0.7
46 | exp_adv_clip: float = 100.0
47 |
48 |
49 | r"""
50 | |\ __
51 | \| /_/
52 | \|
53 | ___|_____
54 | \ /
55 | \ /
56 | \___/ Preliminaries
57 | """
58 |
59 | AgentTrainState = namedtuple("AgentTrainState", "actor dual_q dual_q_target value")
60 | Transition = namedtuple("Transition", "obs action reward next_obs done")
61 |
62 |
63 | class SoftQNetwork(nn.Module):
64 | obs_mean: jax.Array
65 | obs_std: jax.Array
66 |
67 | @nn.compact
68 | def __call__(self, obs, action):
69 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3)
70 | x = jnp.concatenate([obs, action], axis=-1)
71 | for _ in range(2):
72 | x = nn.Dense(256)(x)
73 | x = nn.relu(x)
74 | q = nn.Dense(1)(x)
75 | return q.squeeze(-1)
76 |
77 |
78 | class DualQNetwork(nn.Module):
79 | obs_mean: jax.Array
80 | obs_std: jax.Array
81 |
82 | @nn.compact
83 | def __call__(self, obs, action):
84 | vmap_critic = nn.vmap(
85 | SoftQNetwork,
86 | variable_axes={"params": 0}, # Parameters not shared between critics
87 | split_rngs={"params": True, "dropout": True}, # Different initializations
88 | in_axes=None,
89 | out_axes=-1,
90 | axis_size=2, # Two Q networks
91 | )
92 | q_values = vmap_critic(self.obs_mean, self.obs_std)(obs, action)
93 | return q_values
94 |
95 |
96 | class StateValueFunction(nn.Module):
97 | obs_mean: jax.Array
98 | obs_std: jax.Array
99 |
100 | @nn.compact
101 | def __call__(self, x):
102 | x = (x - self.obs_mean) / (self.obs_std + 1e-3)
103 | for _ in range(2):
104 | x = nn.Dense(256)(x)
105 | x = nn.relu(x)
106 | v = nn.Dense(1)(x)
107 | return v.squeeze(-1)
108 |
109 |
110 | class TanhGaussianActor(nn.Module):
111 | num_actions: int
112 | obs_mean: jax.Array
113 | obs_std: jax.Array
114 | log_std_max: float = 2.0
115 | log_std_min: float = -20.0
116 |
117 | @nn.compact
118 | def __call__(self, x, eval=False):
119 | x = (x - self.obs_mean) / (self.obs_std + 1e-3)
120 | for _ in range(2):
121 | x = nn.Dense(256)(x)
122 | x = nn.relu(x)
123 | x = nn.Dense(self.num_actions)(x)
124 | x = nn.tanh(x)
125 | if eval:
126 | return distrax.Deterministic(x)
127 | logstd = self.param(
128 | "logstd",
129 | init_fn=lambda key: jnp.zeros(self.num_actions, dtype=jnp.float32),
130 | )
131 | std = jnp.exp(jnp.clip(logstd, self.log_std_min, self.log_std_max))
132 | return distrax.Normal(x, std)
133 |
134 |
135 | def create_train_state(args, rng, network, dummy_input):
136 | lr_schedule = optax.cosine_decay_schedule(args.lr, args.num_updates)
137 | return TrainState.create(
138 | apply_fn=network.apply,
139 | params=network.init(rng, *dummy_input),
140 | tx=optax.adam(lr_schedule, eps=1e-5),
141 | )
142 |
143 |
144 | def eval_agent(args, rng, env, agent_state):
145 | # --- Reset environment ---
146 | step = 0
147 | returned = onp.zeros(args.eval_workers).astype(bool)
148 | cum_reward = onp.zeros(args.eval_workers)
149 | rng, rng_reset = jax.random.split(rng)
150 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
151 | obs = env.reset()
152 |
153 | # --- Rollout agent ---
154 | @jax.jit
155 | @jax.vmap
156 | def _policy_step(rng, obs):
157 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs, eval=True)
158 | action = pi.sample(seed=rng)
159 | return jnp.nan_to_num(action)
160 |
161 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
162 | while step < max_episode_steps and not returned.all():
163 | # --- Take step in environment ---
164 | step += 1
165 | rng, rng_step = jax.random.split(rng)
166 | rng_step = jax.random.split(rng_step, args.eval_workers)
167 | action = _policy_step(rng_step, jnp.array(obs))
168 | obs, reward, done, info = env.step(onp.array(action))
169 |
170 | # --- Track cumulative reward ---
171 | cum_reward += reward * ~returned
172 | returned |= done
173 |
174 | if step >= max_episode_steps and not returned.all():
175 | warnings.warn("Maximum steps reached before all episodes terminated")
176 | return cum_reward
177 |
178 |
179 | r"""
180 | __/)
181 | .-(__(=:
182 | |\ | \)
183 | \ ||
184 | \||
185 | \|
186 | ___|_____
187 | \ /
188 | \ /
189 | \___/ Agent
190 | """
191 |
192 |
193 | def make_train_step(args, actor_apply_fn, q_apply_fn, value_apply_fn, dataset):
194 | """Make JIT-compatible agent train step."""
195 |
196 | def _train_step(runner_state, _):
197 | rng, agent_state = runner_state
198 |
199 | # --- Sample batch ---
200 | rng, rng_batch = jax.random.split(rng)
201 | batch_indices = jax.random.randint(
202 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
203 | )
204 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
205 |
206 | # --- Update Q target network ---
207 | updated_q_target_params = optax.incremental_update(
208 | agent_state.dual_q.params,
209 | agent_state.dual_q_target.params,
210 | args.polyak_step_size,
211 | )
212 | updated_q_target = agent_state.dual_q_target.replace(
213 | step=agent_state.dual_q_target.step + 1, params=updated_q_target_params
214 | )
215 | agent_state = agent_state._replace(dual_q_target=updated_q_target)
216 |
217 | # --- Compute targets ---
218 | v_target = q_apply_fn(agent_state.dual_q_target.params, batch.obs, batch.action)
219 | v_target = v_target.min(-1)
220 | next_v_target = value_apply_fn(agent_state.value.params, batch.next_obs)
221 | q_targets = batch.reward + args.gamma * (1 - batch.done) * next_v_target
222 |
223 | # --- Update Q and value functions ---
224 | def _q_loss_fn(params):
225 | # Compute loss for both critics
226 | q_pred = q_apply_fn(params, batch.obs, batch.action)
227 | q_loss = jnp.square(q_pred - jnp.expand_dims(q_targets, axis=-1)).mean()
228 | return q_loss
229 |
230 | @partial(jax.value_and_grad, has_aux=True)
231 | def _value_loss_fn(params):
232 | adv = v_target - value_apply_fn(params, batch.obs)
233 | # Asymmetric L2 loss
234 | value_loss = jnp.abs(args.iql_tau - (adv < 0.0).astype(float)) * (adv**2)
235 | return jnp.mean(value_loss), adv
236 |
237 | q_loss, q_grad = jax.value_and_grad(_q_loss_fn)(agent_state.dual_q.params)
238 | (value_loss, adv), value_grad = _value_loss_fn(agent_state.value.params)
239 | agent_state = agent_state._replace(
240 | dual_q=agent_state.dual_q.apply_gradients(grads=q_grad),
241 | value=agent_state.value.apply_gradients(grads=value_grad),
242 | )
243 |
244 | # --- Update actor ---
245 | exp_adv = jnp.exp(adv * args.beta).clip(max=args.exp_adv_clip)
246 |
247 | @jax.value_and_grad
248 | def _actor_loss_function(params):
249 | def _compute_loss(transition, exp_adv):
250 | pi = actor_apply_fn(params, transition.obs)
251 | bc_loss = -pi.log_prob(transition.action)
252 | return exp_adv * bc_loss.sum()
253 |
254 | actor_loss = jax.vmap(_compute_loss)(batch, exp_adv)
255 | return actor_loss.mean()
256 |
257 | actor_loss, actor_grad = _actor_loss_function(agent_state.actor.params)
258 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
259 | agent_state = agent_state._replace(actor=updated_actor)
260 |
261 | loss = {
262 | "value_loss": value_loss,
263 | "q_loss": q_loss,
264 | "actor_loss": actor_loss,
265 | }
266 | return (rng, agent_state), loss
267 |
268 | return _train_step
269 |
270 |
271 | if __name__ == "__main__":
272 | # --- Parse arguments ---
273 | args = tyro.cli(Args)
274 | rng = jax.random.PRNGKey(args.seed)
275 |
276 | # --- Initialize logger ---
277 | if args.log:
278 | wandb.init(
279 | config=args,
280 | project=args.wandb_project,
281 | entity=args.wandb_team,
282 | group=args.wandb_group,
283 | job_type="train_agent",
284 | )
285 |
286 | # --- Initialize environment and dataset ---
287 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
288 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
289 | dataset = Transition(
290 | obs=jnp.array(dataset["observations"]),
291 | action=jnp.array(dataset["actions"]),
292 | reward=jnp.array(dataset["rewards"]),
293 | next_obs=jnp.array(dataset["next_observations"]),
294 | done=jnp.array(dataset["terminals"]),
295 | )
296 |
297 | # --- Initialize agent and value networks ---
298 | num_actions = env.single_action_space.shape[0]
299 | obs_mean = dataset.obs.mean(axis=0)
300 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0)
301 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
302 | dummy_action = jnp.zeros(num_actions)
303 | actor_net = TanhGaussianActor(num_actions, obs_mean, obs_std)
304 | q_net = DualQNetwork(obs_mean, obs_std)
305 | value_net = StateValueFunction(obs_mean, obs_std)
306 |
307 | # Target networks share seeds to match initialization
308 | rng, rng_actor, rng_q, rng_value = jax.random.split(rng, 4)
309 | agent_state = AgentTrainState(
310 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
311 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
312 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
313 | value=create_train_state(args, rng_value, value_net, [dummy_obs]),
314 | )
315 |
316 | # --- Make train step ---
317 | _agent_train_step_fn = make_train_step(
318 | args, actor_net.apply, q_net.apply, value_net.apply, dataset
319 | )
320 |
321 | num_evals = args.num_updates // args.eval_interval
322 | for eval_idx in range(num_evals):
323 | # --- Execute train loop ---
324 | (rng, agent_state), loss = jax.lax.scan(
325 | _agent_train_step_fn,
326 | (rng, agent_state),
327 | None,
328 | args.eval_interval,
329 | )
330 |
331 | # --- Evaluate agent ---
332 | rng, rng_eval = jax.random.split(rng)
333 | returns = eval_agent(args, rng_eval, env, agent_state)
334 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
335 |
336 | # --- Log metrics ---
337 | step = (eval_idx + 1) * args.eval_interval
338 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
339 | if args.log:
340 | log_dict = {
341 | "return": returns.mean(),
342 | "score": scores.mean(),
343 | "score_std": scores.std(),
344 | "num_updates": step,
345 | **{k: loss[k][-1] for k in loss},
346 | }
347 | wandb.log(log_dict)
348 |
349 | # --- Evaluate final agent ---
350 | if args.eval_final_episodes > 0:
351 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
352 | print(f"Evaluating final agent for {final_iters} iterations...")
353 | _rng = jax.random.split(rng, final_iters)
354 |         rets = onp.concatenate([eval_agent(args, r, env, agent_state) for r in _rng])
355 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
356 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
357 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
358 |
359 | # --- Write final returns to file ---
360 | os.makedirs("final_returns", exist_ok=True)
361 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
362 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
363 | with open(os.path.join("final_returns", filename), "wb") as f:
364 | onp.savez_compressed(f, **info, args=asdict(args))
365 |
366 | if args.log:
367 | wandb.save(os.path.join("final_returns", filename))
368 |
369 | if args.log:
370 | wandb.finish()
371 |
--------------------------------------------------------------------------------
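Note (iql.py): the value function is trained with an expectile (asymmetric L2) regression against the minimum of the two target critics, and the resulting advantage is exponentiated and clipped to weight a behaviour-cloning actor update. A minimal sketch of those two ingredients (illustrative names; shapes assumed):

import jax.numpy as jnp

def expectile_loss(adv, tau=0.7):
    # Asymmetric L2: weight tau where adv >= 0 and (1 - tau) where adv < 0
    weight = jnp.abs(tau - (adv < 0.0).astype(jnp.float32))
    return (weight * adv**2).mean()

def awr_weight(adv, beta=3.0, clip=100.0):
    # Advantage-weighted regression coefficient used to scale the BC loss
    return jnp.minimum(jnp.exp(beta * adv), clip)

adv = jnp.array([-1.0, 0.5, 2.0])
print(expectile_loss(adv), awr_weight(adv))

With iql_tau = 0.5 this reduces to (half) the ordinary squared error; values closer to 1 push the value estimate towards an upper expectile of the Q-distribution, which is what lets IQL improve on the behaviour policy without querying out-of-distribution actions.
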
/algorithms/mopo.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | from dynamics import (
22 | Transition,
23 | load_dynamics_model,
24 | EnsembleDynamics, # required for loading dynamics model
25 | EnsembleDynamicsModel, # required for loading dynamics model
26 | )
27 |
28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
29 |
30 |
31 | @dataclass
32 | class Args:
33 | # --- Experiment ---
34 | seed: int = 0
35 | dataset: str = "halfcheetah-medium-v2"
36 | algorithm: str = "mopo"
37 | num_updates: int = 3_000_000
38 | eval_interval: int = 2500
39 | eval_workers: int = 8
40 | eval_final_episodes: int = 1000
41 | # --- Logging ---
42 | log: bool = False
43 | wandb_project: str = "unifloral"
44 | wandb_team: str = "flair"
45 | wandb_group: str = "debug"
46 | # --- Generic optimization ---
47 | lr: float = 1e-4
48 | batch_size: int = 256
49 | gamma: float = 0.99
50 | polyak_step_size: float = 0.005
51 | # --- SAC-N ---
52 | num_critics: int = 10
53 | # --- World model ---
54 | model_path: str = ""
55 | rollout_interval: int = 1000
56 | rollout_length: int = 5
57 | rollout_batch_size: int = 50000
58 | model_retain_epochs: int = 5
59 | dataset_sample_ratio: float = 0.05
60 | # --- MOPO ---
61 | step_penalty_coef: float = 0.5
62 |
63 |
64 | r"""
65 | |\ __
66 | \| /_/
67 | \|
68 | ___|_____
69 | \ /
70 | \ /
71 | \___/ Preliminaries
72 | """
73 |
74 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
75 |
76 |
77 | def sym(scale):
78 | def _init(*args, **kwargs):
79 | return uniform(2 * scale)(*args, **kwargs) - scale
80 |
81 | return _init
82 |
83 |
84 | class SoftQNetwork(nn.Module):
85 | @nn.compact
86 | def __call__(self, obs, action):
87 | x = jnp.concatenate([obs, action], axis=-1)
88 | for _ in range(3):
89 | x = nn.Dense(256, bias_init=constant(0.1))(x)
90 | x = nn.relu(x)
91 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
92 | return q.squeeze(-1)
93 |
94 |
95 | class VectorQ(nn.Module):
96 | num_critics: int
97 |
98 | @nn.compact
99 | def __call__(self, obs, action):
100 | vmap_critic = nn.vmap(
101 | SoftQNetwork,
102 | variable_axes={"params": 0}, # Parameters not shared between critics
103 | split_rngs={"params": True, "dropout": True}, # Different initializations
104 | in_axes=None,
105 | out_axes=-1,
106 | axis_size=self.num_critics,
107 | )
108 | q_values = vmap_critic()(obs, action)
109 | return q_values
110 |
111 |
112 | class TanhGaussianActor(nn.Module):
113 | num_actions: int
114 | log_std_max: float = 2.0
115 | log_std_min: float = -5.0
116 |
117 | @nn.compact
118 | def __call__(self, x):
119 | for _ in range(3):
120 | x = nn.Dense(256, bias_init=constant(0.1))(x)
121 | x = nn.relu(x)
122 | log_std = nn.Dense(
123 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
124 | )(x)
125 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
126 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
127 | pi = distrax.Transformed(
128 | distrax.Normal(mean, std),
129 | distrax.Tanh(),
130 | )
131 | return pi
132 |
133 |
134 | class EntropyCoef(nn.Module):
135 | ent_coef_init: float = 1.0
136 |
137 | @nn.compact
138 | def __call__(self) -> jnp.ndarray:
139 | log_ent_coef = self.param(
140 | "log_ent_coef",
141 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
142 | )
143 | return log_ent_coef
144 |
145 |
146 | def create_train_state(args, rng, network, dummy_input):
147 | return TrainState.create(
148 | apply_fn=network.apply,
149 | params=network.init(rng, *dummy_input),
150 | tx=optax.adam(args.lr, eps=1e-5),
151 | )
152 |
153 |
154 | def eval_agent(args, rng, env, agent_state):
155 | # --- Reset environment ---
156 | step = 0
157 | returned = onp.zeros(args.eval_workers).astype(bool)
158 | cum_reward = onp.zeros(args.eval_workers)
159 | rng, rng_reset = jax.random.split(rng)
160 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
161 | obs = env.reset()
162 |
163 | # --- Rollout agent ---
164 | @jax.jit
165 | @jax.vmap
166 | def _policy_step(rng, obs):
167 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
168 | action = pi.sample(seed=rng)
169 | return jnp.nan_to_num(action)
170 |
171 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
172 | while step < max_episode_steps and not returned.all():
173 | # --- Take step in environment ---
174 | step += 1
175 | rng, rng_step = jax.random.split(rng)
176 | rng_step = jax.random.split(rng_step, args.eval_workers)
177 | action = _policy_step(rng_step, jnp.array(obs))
178 | obs, reward, done, info = env.step(onp.array(action))
179 |
180 | # --- Track cumulative reward ---
181 | cum_reward += reward * ~returned
182 | returned |= done
183 |
184 | if step >= max_episode_steps and not returned.all():
185 | warnings.warn("Maximum steps reached before all episodes terminated")
186 | return cum_reward
187 |
188 |
189 | def sample_from_buffer(buffer, batch_size, rng):
190 | """Sample a batch from the buffer."""
191 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs))
192 | return jax.tree_map(lambda x: x[idxs], buffer)
193 |
194 |
195 | r"""
196 | __/)
197 | .-(__(=:
198 | |\ | \)
199 | \ ||
200 | \||
201 | \|
202 | ___|_____
203 | \ /
204 | \ /
205 | \___/ Agent
206 | """
207 |
208 |
209 | def make_train_step(
210 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn
211 | ):
212 | """Make JIT-compatible agent train step with model-based rollouts."""
213 |
214 | def _train_step(runner_state, _):
215 | rng, agent_state, rollout_buffer = runner_state
216 |
217 | # --- Update model buffer ---
218 | params = agent_state.actor.params
219 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng)
220 | rng, rng_buffer = jax.random.split(rng)
221 | rollout_buffer = jax.lax.cond(
222 | agent_state.actor.step % args.rollout_interval == 0,
223 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer),
224 | lambda: rollout_buffer,
225 | )
226 |
227 | # --- Sample batch ---
228 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3)
229 | dataset_size = int(args.batch_size * args.dataset_sample_ratio)
230 | rollout_size = args.batch_size - dataset_size
231 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset)
232 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout)
233 | batch = jax.tree_map(
234 | lambda x, y: jnp.concatenate([x, y]), dataset_batch, rollout_batch
235 | )
236 |
237 | # --- Update alpha ---
238 | @jax.value_and_grad
239 | def _alpha_loss_fn(params, rng):
240 | def _compute_entropy(rng, transition):
241 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
242 | _, log_pi = pi.sample_and_log_prob(seed=rng)
243 | return -log_pi.sum()
244 |
245 | log_alpha = alpha_apply_fn(params)
246 | rng = jax.random.split(rng, args.batch_size)
247 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
248 | target_entropy = -batch.action.shape[-1]
249 | return log_alpha * (entropy - target_entropy)
250 |
251 | rng, rng_alpha = jax.random.split(rng)
252 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
253 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
254 | agent_state = agent_state._replace(alpha=updated_alpha)
255 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
256 |
257 | # --- Update actor ---
258 | @partial(jax.value_and_grad, has_aux=True)
259 | def _actor_loss_function(params, rng):
260 | def _compute_loss(rng, transition):
261 | pi = actor_apply_fn(params, transition.obs)
262 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
263 | log_pi = log_pi.sum()
264 | q_values = q_apply_fn(
265 | agent_state.vec_q.params, transition.obs, sampled_action
266 | )
267 | q_min = jnp.min(q_values)
268 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
269 |
270 | rng = jax.random.split(rng, args.batch_size)
271 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
272 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
273 |
274 | rng, rng_actor = jax.random.split(rng)
275 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
276 | agent_state.actor.params, rng_actor
277 | )
278 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
279 | agent_state = agent_state._replace(actor=updated_actor)
280 |
281 | # --- Update Q target network ---
282 | updated_q_target_params = optax.incremental_update(
283 | agent_state.vec_q.params,
284 | agent_state.vec_q_target.params,
285 | args.polyak_step_size,
286 | )
287 | updated_q_target = agent_state.vec_q_target.replace(
288 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
289 | )
290 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
291 |
292 | # --- Compute targets ---
293 | def _sample_next_v(rng, transition):
294 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
295 | # Note: Important to use sample_and_log_prob here for numerical stability
296 | # See https://github.com/deepmind/distrax/issues/7
297 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
298 | # Minimum of the target Q-values
299 | next_q = q_apply_fn(
300 | agent_state.vec_q_target.params, transition.next_obs, next_action
301 | )
302 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
303 |
304 | rng, rng_next_v = jax.random.split(rng)
305 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
306 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
307 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
308 |
309 | # --- Update critics ---
310 | @jax.value_and_grad
311 | def _q_loss_fn(params):
312 | q_pred = q_apply_fn(params, batch.obs, batch.action)
313 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean()
314 |
315 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params)
316 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
317 | agent_state = agent_state._replace(vec_q=updated_q)
318 |
319 | num_done = jnp.sum(batch.done)
320 | loss = {
321 | "critic_loss": critic_loss,
322 | "actor_loss": actor_loss,
323 | "alpha_loss": alpha_loss,
324 | "entropy": entropy,
325 | "alpha": alpha,
326 | "q_min": q_min,
327 | "q_std": q_std,
328 | "terminations/num_done": num_done,
329 | "terminations/done_ratio": num_done / batch.done.shape[0],
330 | }
331 | return (rng, agent_state, rollout_buffer), loss
332 |
333 | return _train_step
334 |
335 |
336 | if __name__ == "__main__":
337 | # --- Parse arguments ---
338 | args = tyro.cli(Args)
339 | rng = jax.random.PRNGKey(args.seed)
340 |
341 | # --- Initialize logger ---
342 | if args.log:
343 | wandb.init(
344 | config=args,
345 | project=args.wandb_project,
346 | entity=args.wandb_team,
347 | group=args.wandb_group,
348 | job_type="train_agent",
349 | )
350 |
351 | # --- Initialize environment and dataset ---
352 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
353 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
354 | dataset = Transition(
355 | obs=jnp.array(dataset["observations"]),
356 | action=jnp.array(dataset["actions"]),
357 | reward=jnp.array(dataset["rewards"]),
358 | next_obs=jnp.array(dataset["next_observations"]),
359 | done=jnp.array(dataset["terminals"]),
360 | next_action=jnp.roll(dataset["actions"], -1, axis=0),
361 | )
362 |
363 | # --- Initialize agent and value networks ---
364 | num_actions = env.single_action_space.shape[0]
365 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
366 | dummy_action = jnp.zeros(num_actions)
367 | actor_net = TanhGaussianActor(num_actions)
368 | q_net = VectorQ(args.num_critics)
369 | alpha_net = EntropyCoef()
370 |
371 | # Target networks share seeds to match initialization
372 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
373 | agent_state = AgentTrainState(
374 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
375 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
376 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
377 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
378 | )
379 |
380 | # --- Initialize buffer and rollout function ---
381 | assert args.model_path, "Model path must be provided for model-based methods"
382 | dynamics_model = load_dynamics_model(args.model_path)
383 | dynamics_model.dataset = dataset
384 | max_buffer_size = args.rollout_batch_size * args.rollout_length
385 | max_buffer_size *= args.model_retain_epochs
386 | rollout_buffer = jax.tree_map(
387 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])),
388 | dataset,
389 | )
390 | rollout_fn = dynamics_model.make_rollout_fn(
391 | batch_size=args.rollout_batch_size,
392 | rollout_length=args.rollout_length,
393 | step_penalty_coef=args.step_penalty_coef,
394 | )
395 |
396 | # --- Make train step ---
397 | _agent_train_step_fn = make_train_step(
398 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn
399 | )
400 |
401 | num_evals = args.num_updates // args.eval_interval
402 | for eval_idx in range(num_evals):
403 | # --- Execute train loop ---
404 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan(
405 | _agent_train_step_fn,
406 | (rng, agent_state, rollout_buffer),
407 | None,
408 | args.eval_interval,
409 | )
410 |
411 | # --- Evaluate agent ---
412 | rng, rng_eval = jax.random.split(rng)
413 | returns = eval_agent(args, rng_eval, env, agent_state)
414 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
415 |
416 | # --- Log metrics ---
417 | step = (eval_idx + 1) * args.eval_interval
418 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
419 | if args.log:
420 | log_dict = {
421 | "return": returns.mean(),
422 | "score": scores.mean(),
423 | "score_std": scores.std(),
424 | "num_updates": step,
425 | **{k: loss[k][-1] for k in loss},
426 | }
427 | wandb.log(log_dict)
428 |
429 | # --- Evaluate final agent ---
430 | if args.eval_final_episodes > 0:
431 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
432 | print(f"Evaluating final agent for {final_iters} iterations...")
433 | _rng = jax.random.split(rng, final_iters)
434 |         rets = onp.concatenate([eval_agent(args, r, env, agent_state) for r in _rng])
435 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
436 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
437 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
438 |
439 | # --- Write final returns to file ---
440 | os.makedirs("final_returns", exist_ok=True)
441 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
442 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
443 | with open(os.path.join("final_returns", filename), "wb") as f:
444 | onp.savez_compressed(f, **info, args=asdict(args))
445 |
446 | if args.log:
447 | wandb.save(os.path.join("final_returns", filename))
448 |
449 | if args.log:
450 | wandb.finish()
451 |
--------------------------------------------------------------------------------
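Note (mopo.py): each train step mixes a small fraction of real transitions (dataset_sample_ratio) with samples from the model rollout buffer, which is periodically refilled by rollout_fn; the MOPO reward penalty itself (scaled by step_penalty_coef) is applied inside the dynamics model's rollout function, defined in dynamics.py and not shown here. The following sketch illustrates only the batch mixing, using hypothetical dict-of-arrays buffers rather than the repository's Transition tuples:

import jax
import jax.numpy as jnp

def sample_mixed_batch(rng, dataset, rollout_buffer, batch_size=256, real_ratio=0.05):
    # Draw real_ratio * batch_size real transitions and fill the rest from model rollouts
    rng_real, rng_model = jax.random.split(rng)
    n_real = int(batch_size * real_ratio)
    real_idx = jax.random.randint(rng_real, (n_real,), 0, len(dataset["obs"]))
    model_idx = jax.random.randint(rng_model, (batch_size - n_real,), 0, len(rollout_buffer["obs"]))
    real = jax.tree_util.tree_map(lambda x: x[real_idx], dataset)
    model = jax.tree_util.tree_map(lambda x: x[model_idx], rollout_buffer)
    return jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b]), real, model)

real_data = {"obs": jnp.ones((1000, 17)), "reward": jnp.zeros(1000)}
synthetic = {"obs": jnp.zeros((500, 17)), "reward": jnp.zeros(500)}
batch = sample_mixed_batch(jax.random.PRNGKey(0), real_data, synthetic)
print(batch["obs"].shape)  # (256, 17)
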
/algorithms/morel.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | from dynamics import (
22 | Transition,
23 | load_dynamics_model,
24 | EnsembleDynamics, # required for loading dynamics model
25 | EnsembleDynamicsModel, # required for loading dynamics model
26 | )
27 |
28 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
29 |
30 |
31 | @dataclass
32 | class Args:
33 | # --- Experiment ---
34 | seed: int = 0
35 | dataset: str = "halfcheetah-medium-v2"
36 | algorithm: str = "morel"
37 | num_updates: int = 3_000_000
38 | eval_interval: int = 2500
39 | eval_workers: int = 8
40 | eval_final_episodes: int = 1000
41 | # --- Logging ---
42 | log: bool = False
43 | wandb_project: str = "unifloral"
44 | wandb_team: str = "flair"
45 | wandb_group: str = "debug"
46 | # --- Generic optimization ---
47 | lr: float = 1e-4
48 | batch_size: int = 256
49 | gamma: float = 0.99
50 | polyak_step_size: float = 0.005
51 | # --- SAC-N ---
52 | num_critics: int = 10
53 | # --- World model ---
54 | model_path: str = ""
55 | rollout_interval: int = 1000
56 | rollout_length: int = 5
57 | rollout_batch_size: int = 50000
58 | model_retain_epochs: int = 5
59 | dataset_sample_ratio: float = 0.01
60 | # --- MOReL ---
61 | threshold_coef: float = 1.0
62 | term_penalty_offset: float = -200
63 |
64 |
65 | r"""
66 | |\ __
67 | \| /_/
68 | \|
69 | ___|_____
70 | \ /
71 | \ /
72 | \___/ Preliminaries
73 | """
74 |
75 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
76 |
77 |
78 | def sym(scale):
79 | def _init(*args, **kwargs):
80 | return uniform(2 * scale)(*args, **kwargs) - scale
81 |
82 | return _init
83 |
84 |
85 | class SoftQNetwork(nn.Module):
86 | @nn.compact
87 | def __call__(self, obs, action):
88 | x = jnp.concatenate([obs, action], axis=-1)
89 | for _ in range(3):
90 | x = nn.Dense(256, bias_init=constant(0.1))(x)
91 | x = nn.relu(x)
92 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
93 | return q.squeeze(-1)
94 |
95 |
96 | class VectorQ(nn.Module):
97 | num_critics: int
98 |
99 | @nn.compact
100 | def __call__(self, obs, action):
101 | vmap_critic = nn.vmap(
102 | SoftQNetwork,
103 | variable_axes={"params": 0}, # Parameters not shared between critics
104 | split_rngs={"params": True, "dropout": True}, # Different initializations
105 | in_axes=None,
106 | out_axes=-1,
107 | axis_size=self.num_critics,
108 | )
109 | q_values = vmap_critic()(obs, action)
110 | return q_values
111 |
112 |
113 | class TanhGaussianActor(nn.Module):
114 | num_actions: int
115 | log_std_max: float = 2.0
116 | log_std_min: float = -5.0
117 |
118 | @nn.compact
119 | def __call__(self, x):
120 | for _ in range(3):
121 | x = nn.Dense(256, bias_init=constant(0.1))(x)
122 | x = nn.relu(x)
123 | log_std = nn.Dense(
124 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
125 | )(x)
126 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
127 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
128 | pi = distrax.Transformed(
129 | distrax.Normal(mean, std),
130 | distrax.Tanh(),
131 | )
132 | return pi
133 |
134 |
135 | class EntropyCoef(nn.Module):
136 | ent_coef_init: float = 1.0
137 |
138 | @nn.compact
139 | def __call__(self) -> jnp.ndarray:
140 | log_ent_coef = self.param(
141 | "log_ent_coef",
142 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
143 | )
144 | return log_ent_coef
145 |
146 |
147 | def create_train_state(args, rng, network, dummy_input):
148 | return TrainState.create(
149 | apply_fn=network.apply,
150 | params=network.init(rng, *dummy_input),
151 | tx=optax.adam(args.lr, eps=1e-5),
152 | )
153 |
154 |
155 | def eval_agent(args, rng, env, agent_state):
156 | # --- Reset environment ---
157 | step = 0
158 | returned = onp.zeros(args.eval_workers).astype(bool)
159 | cum_reward = onp.zeros(args.eval_workers)
160 | rng, rng_reset = jax.random.split(rng)
161 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
162 | obs = env.reset()
163 |
164 | # --- Rollout agent ---
165 | @jax.jit
166 | @jax.vmap
167 | def _policy_step(rng, obs):
168 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
169 | action = pi.sample(seed=rng)
170 | return jnp.nan_to_num(action)
171 |
172 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
173 | while step < max_episode_steps and not returned.all():
174 | # --- Take step in environment ---
175 | step += 1
176 | rng, rng_step = jax.random.split(rng)
177 | rng_step = jax.random.split(rng_step, args.eval_workers)
178 | action = _policy_step(rng_step, jnp.array(obs))
179 | obs, reward, done, info = env.step(onp.array(action))
180 |
181 | # --- Track cumulative reward ---
182 | cum_reward += reward * ~returned
183 | returned |= done
184 |
185 | if step >= max_episode_steps and not returned.all():
186 | warnings.warn("Maximum steps reached before all episodes terminated")
187 | return cum_reward
188 |
189 |
190 | def sample_from_buffer(buffer, batch_size, rng):
191 | """Sample a batch from the buffer."""
192 | idxs = jax.random.randint(rng, (batch_size,), 0, len(buffer.obs))
193 | return jax.tree_map(lambda x: x[idxs], buffer)
194 |
195 |
196 | r"""
197 | __/)
198 | .-(__(=:
199 | |\ | \)
200 | \ ||
201 | \||
202 | \|
203 | ___|_____
204 | \ /
205 | \ /
206 | \___/ Agent
207 | """
208 |
209 |
210 | def make_train_step(
211 | args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset, rollout_fn
212 | ):
213 | """Make JIT-compatible agent train step with model-based rollouts."""
214 |
215 | def _train_step(runner_state, _):
216 | rng, agent_state, rollout_buffer = runner_state
217 |
218 | # --- Update model buffer ---
219 | params = agent_state.actor.params
220 | policy_fn = lambda obs, rng: actor_apply_fn(params, obs).sample(seed=rng)
221 | rng, rng_buffer = jax.random.split(rng)
222 | rollout_buffer = jax.lax.cond(
223 | agent_state.actor.step % args.rollout_interval == 0,
224 | lambda: rollout_fn(rng_buffer, policy_fn, rollout_buffer),
225 | lambda: rollout_buffer,
226 | )
227 |
228 | # --- Sample batch ---
229 | rng, rng_dataset, rng_rollout = jax.random.split(rng, 3)
230 | dataset_size = int(args.batch_size * args.dataset_sample_ratio)
231 | rollout_size = args.batch_size - dataset_size
232 | dataset_batch = sample_from_buffer(dataset, dataset_size, rng_dataset)
233 | rollout_batch = sample_from_buffer(rollout_buffer, rollout_size, rng_rollout)
234 | batch = jax.tree_map(
235 | lambda x, y: jnp.concatenate([x, y]), dataset_batch, rollout_batch
236 | )
237 |
238 | # --- Update alpha ---
239 | @jax.value_and_grad
240 | def _alpha_loss_fn(params, rng):
241 | def _compute_entropy(rng, transition):
242 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
243 | _, log_pi = pi.sample_and_log_prob(seed=rng)
244 | return -log_pi.sum()
245 |
246 | log_alpha = alpha_apply_fn(params)
247 | rng = jax.random.split(rng, args.batch_size)
248 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
249 | target_entropy = -batch.action.shape[-1]
250 | return log_alpha * (entropy - target_entropy)
251 |
252 | rng, rng_alpha = jax.random.split(rng)
253 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
254 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
255 | agent_state = agent_state._replace(alpha=updated_alpha)
256 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
257 |
258 | # --- Update actor ---
259 | @partial(jax.value_and_grad, has_aux=True)
260 | def _actor_loss_function(params, rng):
261 | def _compute_loss(rng, transition):
262 | pi = actor_apply_fn(params, transition.obs)
263 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
264 | log_pi = log_pi.sum()
265 | q_values = q_apply_fn(
266 | agent_state.vec_q.params, transition.obs, sampled_action
267 | )
268 | q_min = jnp.min(q_values)
269 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
270 |
271 | rng = jax.random.split(rng, args.batch_size)
272 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
273 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
274 |
275 | rng, rng_actor = jax.random.split(rng)
276 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
277 | agent_state.actor.params, rng_actor
278 | )
279 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
280 | agent_state = agent_state._replace(actor=updated_actor)
281 |
282 | # --- Update Q target network ---
283 | updated_q_target_params = optax.incremental_update(
284 | agent_state.vec_q.params,
285 | agent_state.vec_q_target.params,
286 | args.polyak_step_size,
287 | )
288 | updated_q_target = agent_state.vec_q_target.replace(
289 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
290 | )
291 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
292 |
293 | # --- Compute targets ---
294 | def _sample_next_v(rng, transition):
295 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
296 | # Note: Important to use sample_and_log_prob here for numerical stability
297 | # See https://github.com/deepmind/distrax/issues/7
298 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
299 | # Minimum of the target Q-values
300 | next_q = q_apply_fn(
301 | agent_state.vec_q_target.params, transition.next_obs, next_action
302 | )
303 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
304 |
305 | rng, rng_next_v = jax.random.split(rng)
306 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
307 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
308 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
309 |
310 | # --- Update critics ---
311 | @jax.value_and_grad
312 | def _q_loss_fn(params):
313 | q_pred = q_apply_fn(params, batch.obs, batch.action)
314 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean()
315 |
316 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params)
317 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
318 | agent_state = agent_state._replace(vec_q=updated_q)
319 |
320 | num_done = jnp.sum(batch.done)
321 | loss = {
322 | "critic_loss": critic_loss,
323 | "actor_loss": actor_loss,
324 | "alpha_loss": alpha_loss,
325 | "entropy": entropy,
326 | "alpha": alpha,
327 | "q_min": q_min,
328 | "q_std": q_std,
329 | "terminations/num_done": num_done,
330 | "terminations/done_ratio": num_done / batch.done.shape[0],
331 | }
332 | return (rng, agent_state, rollout_buffer), loss
333 |
334 | return _train_step
335 |
336 |
337 | if __name__ == "__main__":
338 | # --- Parse arguments ---
339 | args = tyro.cli(Args)
340 | rng = jax.random.PRNGKey(args.seed)
341 |
342 | # --- Initialize logger ---
343 | if args.log:
344 | wandb.init(
345 | config=args,
346 | project=args.wandb_project,
347 | entity=args.wandb_team,
348 | group=args.wandb_group,
349 | job_type="train_agent",
350 | )
351 |
352 | # --- Initialize environment and dataset ---
353 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
354 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
355 | dataset = Transition(
356 | obs=jnp.array(dataset["observations"]),
357 | action=jnp.array(dataset["actions"]),
358 | reward=jnp.array(dataset["rewards"]),
359 | next_obs=jnp.array(dataset["next_observations"]),
360 | done=jnp.array(dataset["terminals"]),
361 | next_action=jnp.roll(dataset["actions"], -1, axis=0),
362 | )
363 |
364 | # --- Initialize agent and value networks ---
365 | num_actions = env.single_action_space.shape[0]
366 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
367 | dummy_action = jnp.zeros(num_actions)
368 | actor_net = TanhGaussianActor(num_actions)
369 | q_net = VectorQ(args.num_critics)
370 | alpha_net = EntropyCoef()
371 |
372 | # Target networks share seeds to match initialization
373 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
374 | agent_state = AgentTrainState(
375 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
376 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
377 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
378 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
379 | )
380 |
381 | # --- Initialize buffer and rollout function ---
382 | assert args.model_path, "Model path must be provided for model-based methods"
383 | dynamics_model = load_dynamics_model(args.model_path)
384 | assert (
385 | dynamics_model.discrepancy is not None and dynamics_model.min_r is not None
386 | ), "MOReL requires a dynamics model with precomputed statistics (train with --precompute_term_stats)"
387 | dynamics_model.dataset = dataset
388 | max_buffer_size = args.rollout_batch_size * args.rollout_length
389 | max_buffer_size *= args.model_retain_epochs
390 |     rollout_buffer = jax.tree_util.tree_map(
391 | lambda x: jnp.zeros((max_buffer_size, *x.shape[1:])),
392 | dataset,
393 | )
394 | rollout_fn = dynamics_model.make_rollout_fn(
395 | batch_size=args.rollout_batch_size,
396 | rollout_length=args.rollout_length,
397 | term_penalty_offset=args.term_penalty_offset,
398 | threshold_coef=args.threshold_coef,
399 | )
400 |
401 | # --- Make train step ---
402 | _agent_train_step_fn = make_train_step(
403 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset, rollout_fn
404 | )
405 |
406 | num_evals = args.num_updates // args.eval_interval
407 | for eval_idx in range(num_evals):
408 | # --- Execute train loop ---
409 | (rng, agent_state, rollout_buffer), loss = jax.lax.scan(
410 | _agent_train_step_fn,
411 | (rng, agent_state, rollout_buffer),
412 | None,
413 | args.eval_interval,
414 | )
415 |
416 | # --- Evaluate agent ---
417 | rng, rng_eval = jax.random.split(rng)
418 | returns = eval_agent(args, rng_eval, env, agent_state)
419 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
420 |
421 | # --- Log metrics ---
422 | step = (eval_idx + 1) * args.eval_interval
423 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
424 | if args.log:
425 | log_dict = {
426 | "return": returns.mean(),
427 | "score": scores.mean(),
428 | "score_std": scores.std(),
429 | "num_updates": step,
430 | **{k: loss[k][-1] for k in loss},
431 | }
432 | wandb.log(log_dict)
433 |
434 | # --- Evaluate final agent ---
435 | if args.eval_final_episodes > 0:
436 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
437 | print(f"Evaluating final agent for {final_iters} iterations...")
438 | _rng = jax.random.split(rng, final_iters)
439 |         rets = onp.concatenate([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
440 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
441 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
442 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
443 |
444 | # --- Write final returns to file ---
445 | os.makedirs("final_returns", exist_ok=True)
446 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
447 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
448 | with open(os.path.join("final_returns", filename), "wb") as f:
449 | onp.savez_compressed(f, **info, args=asdict(args))
450 |
451 | if args.log:
452 | wandb.save(os.path.join("final_returns", filename))
453 |
454 | if args.log:
455 | wandb.finish()
456 |
--------------------------------------------------------------------------------
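
The final-evaluation block above writes raw returns, D4RL-normalized scores, their aggregates, and the run arguments to a compressed NumPy archive under final_returns/. A minimal sketch of loading one of these archives back for analysis (the filename is hypothetical; evaluation.py in the repository root presumably consumes these files):

    import numpy as onp

    # Filenames follow f"{algorithm}_{dataset}_{timestamp}.npz"; this one is hypothetical
    path = "final_returns/morel_halfcheetah-medium-v2_2025-01-01_00-00-00.npz"
    data = onp.load(path, allow_pickle=True)  # allow_pickle for the stored args dict
    print(data["final_scores_mean"], data["final_scores_std"])  # aggregate normalized scores
    print(data["final_returns"].shape)  # one raw return per evaluation episode
    print(data["args"].item()["dataset"])  # run configuration saved alongside the results
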
/algorithms/rebrac.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | import os
5 | import warnings
6 |
7 | import distrax
8 | import d4rl
9 | import flax.linen as nn
10 | from flax.linen.initializers import constant, uniform
11 | from flax.training.train_state import TrainState
12 | import gym
13 | import jax
14 | import jax.numpy as jnp
15 | import numpy as onp
16 | import optax
17 | import tyro
18 | import wandb
19 |
20 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
21 |
22 |
23 | @dataclass
24 | class Args:
25 | # --- Experiment ---
26 | seed: int = 0
27 | dataset: str = "halfcheetah-medium-v2"
28 | algorithm: str = "rebrac"
29 | num_updates: int = 1_000_000
30 | eval_interval: int = 2500
31 | eval_workers: int = 8
32 | eval_final_episodes: int = 1000
33 | # --- Logging ---
34 | log: bool = False
35 | wandb_project: str = "unifloral"
36 | wandb_team: str = "flair"
37 | wandb_group: str = "debug"
38 | # --- Generic optimization ---
39 | lr: float = 1e-3
40 | batch_size: int = 1024
41 | gamma: float = 0.99
42 | polyak_step_size: float = 0.005
43 | # --- TD3+BC ---
44 | noise_clip: float = 0.5
45 | policy_noise: float = 0.2
46 | num_critic_updates_per_step: int = 2
47 | # --- REBRAC ---
48 | critic_bc_coef: float = 0.01
49 | actor_bc_coef: float = 0.001
50 | actor_ln: bool = False
51 | critic_ln: bool = True
52 | norm_obs: bool = False
53 |
54 |
55 | r"""
56 | |\ __
57 | \| /_/
58 | \|
59 | ___|_____
60 | \ /
61 | \ /
62 | \___/ Preliminaries
63 | """
64 |
65 | AgentTrainState = namedtuple(
66 | "AgentTrainState", "actor actor_target dual_q dual_q_target"
67 | )
68 | Transition = namedtuple("Transition", "obs action reward next_obs next_action done")
69 |
70 |
71 | def sym(scale):
72 | def _init(*args, **kwargs):
73 | return uniform(2 * scale)(*args, **kwargs) - scale
74 |
75 | return _init
76 |
77 |
78 | class SoftQNetwork(nn.Module):
79 | obs_mean: jax.Array
80 | obs_std: jax.Array
81 | use_ln: bool
82 | norm_obs: bool
83 |
84 | @nn.compact
85 | def __call__(self, obs, action):
86 | if self.norm_obs:
87 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3)
88 | x = jnp.concatenate([obs, action], axis=-1)
89 | for _ in range(3):
90 | x = nn.Dense(256, bias_init=constant(0.1))(x)
91 | x = nn.relu(x)
92 | x = nn.LayerNorm()(x) if self.use_ln else x
93 | q = nn.Dense(1, bias_init=sym(3e-3), kernel_init=sym(3e-3))(x)
94 | return q.squeeze(-1)
95 |
96 |
97 | class DualQNetwork(nn.Module):
98 | obs_mean: jax.Array
99 | obs_std: jax.Array
100 | use_ln: bool
101 | norm_obs: bool
102 |
103 | @nn.compact
104 | def __call__(self, obs, action):
105 | vmap_critic = nn.vmap(
106 | SoftQNetwork,
107 | variable_axes={"params": 0}, # Parameters not shared between critics
108 | split_rngs={"params": True, "dropout": True}, # Different initializations
109 | in_axes=None,
110 | out_axes=-1,
111 | axis_size=2, # Two Q networks
112 | )
113 | q_fn = vmap_critic(self.obs_mean, self.obs_std, self.use_ln, self.norm_obs)
114 | q_values = q_fn(obs, action)
115 | return q_values
116 |
117 |
118 | class DeterministicTanhActor(nn.Module):
119 | num_actions: int
120 | obs_mean: jax.Array
121 | obs_std: jax.Array
122 | use_ln: bool
123 | norm_obs: bool
124 |
125 | @nn.compact
126 | def __call__(self, x):
127 | if self.norm_obs:
128 | x = (x - self.obs_mean) / (self.obs_std + 1e-3)
129 | for _ in range(3):
130 | x = nn.Dense(256, bias_init=constant(0.1))(x)
131 | x = nn.relu(x)
132 | x = nn.LayerNorm()(x) if self.use_ln else x
133 | init_fn = sym(1e-3)
134 | action = nn.Dense(self.num_actions, bias_init=init_fn, kernel_init=init_fn)(x)
135 | pi = distrax.Transformed(
136 | distrax.Deterministic(action),
137 | distrax.Tanh(),
138 | )
139 | return pi
140 |
141 |
142 | def create_train_state(args, rng, network, dummy_input):
143 | return TrainState.create(
144 | apply_fn=network.apply,
145 | params=network.init(rng, *dummy_input),
146 | tx=optax.adam(args.lr, eps=1e-5),
147 | )
148 |
149 |
150 | def eval_agent(args, rng, env, agent_state):
151 | # --- Reset environment ---
152 | step = 0
153 | returned = onp.zeros(args.eval_workers).astype(bool)
154 | cum_reward = onp.zeros(args.eval_workers)
155 | rng, rng_reset = jax.random.split(rng)
156 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
157 | obs = env.reset()
158 |
159 | # --- Rollout agent ---
160 | @jax.jit
161 | @jax.vmap
162 | def _policy_step(rng, obs):
163 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
164 | action = pi.sample(seed=rng)
165 | return jnp.nan_to_num(action)
166 |
167 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
168 | while step < max_episode_steps and not returned.all():
169 | # --- Take step in environment ---
170 | step += 1
171 | rng, rng_step = jax.random.split(rng)
172 | rng_step = jax.random.split(rng_step, args.eval_workers)
173 | action = _policy_step(rng_step, jnp.array(obs))
174 | obs, reward, done, info = env.step(onp.array(action))
175 |
176 | # --- Track cumulative reward ---
177 | cum_reward += reward * ~returned
178 | returned |= done
179 |
180 | if step >= max_episode_steps and not returned.all():
181 | warnings.warn("Maximum steps reached before all episodes terminated")
182 | return cum_reward
183 |
184 |
185 | r"""
186 | __/)
187 | .-(__(=:
188 | |\ | \)
189 | \ ||
190 | \||
191 | \|
192 | ___|_____
193 | \ /
194 | \ /
195 | \___/ Agent
196 | """
197 |
198 |
199 | def make_train_step(args, actor_apply_fn, q_apply_fn, dataset):
200 | """Make JIT-compatible agent train step."""
201 |
202 | def _train_step(runner_state, _):
203 | rng, agent_state = runner_state
204 |
205 | # --- Sample batch ---
206 | rng, rng_batch = jax.random.split(rng)
207 | batch_indices = jax.random.randint(
208 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
209 | )
210 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
211 |
212 | # --- Update critics ---
213 | def _update_critics(runner_state, _):
214 | rng, agent_state = runner_state
215 |
216 | def _compute_target(rng, transition):
217 | next_obs = transition.next_obs
218 |
219 | # --- Sample noised action ---
220 | next_pi = actor_apply_fn(agent_state.actor_target.params, next_obs)
221 | rng, rng_action, rng_noise = jax.random.split(rng, 3)
222 | action = next_pi.sample(seed=rng_action)
223 | noise = jax.random.normal(rng_noise, shape=action.shape)
224 | noise *= args.policy_noise
225 | noise = jnp.clip(noise, -args.noise_clip, args.noise_clip)
226 | action = jnp.clip(action + noise, -1, 1)
227 | bc_loss = jnp.square(action - transition.next_action).sum()
228 |
229 | # --- Compute targets ---
230 | target_q = q_apply_fn(
231 | agent_state.dual_q_target.params, next_obs, action
232 | )
233 | next_q_value = jnp.min(target_q) - args.critic_bc_coef * bc_loss
234 | next_q_value = (1.0 - transition.done) * next_q_value
235 | return transition.reward + args.gamma * next_q_value, bc_loss
236 |
237 | rng, rng_targets = jax.random.split(rng)
238 | rng_targets = jax.random.split(rng_targets, args.batch_size)
239 | target_fn = jax.vmap(_compute_target)
240 | targets, bc_loss = target_fn(rng_targets, batch)
241 |
242 | # --- Compute critic loss ---
243 | @jax.value_and_grad
244 | def _q_loss_fn(params):
245 | q_pred = q_apply_fn(params, batch.obs, batch.action)
246 | q_loss = jnp.square(q_pred - jnp.expand_dims(targets, axis=-1)).sum(-1)
247 | return q_loss.mean()
248 |
249 | q_loss, q_grad = _q_loss_fn(agent_state.dual_q.params)
250 | updated_q_state = agent_state.dual_q.apply_gradients(grads=q_grad)
251 | agent_state = agent_state._replace(dual_q=updated_q_state)
252 | return (rng, agent_state), (q_loss, bc_loss)
253 |
254 | # --- Iterate critic update ---
255 | (rng, agent_state), (q_loss, critic_bc_loss) = jax.lax.scan(
256 | _update_critics,
257 | (rng, agent_state),
258 | None,
259 | length=args.num_critic_updates_per_step,
260 | )
261 |
262 | # --- Update actor ---
263 | def _actor_loss_function(params):
264 | def _transition_loss(transition):
265 | pi = actor_apply_fn(params, transition.obs)
266 | pi_action = pi.sample(seed=None)
267 | q = q_apply_fn(agent_state.dual_q.params, transition.obs, pi_action)
268 | bc_loss = jnp.square(pi_action - transition.action).sum()
269 | return q.min(), bc_loss
270 |
271 | q, bc_loss = jax.vmap(_transition_loss)(batch)
272 | lambda_ = 1.0 / (jnp.abs(q).mean() + 1e-7)
273 | lambda_ = jax.lax.stop_gradient(lambda_)
274 | actor_loss = -lambda_ * q.mean() + args.actor_bc_coef * bc_loss.mean()
275 | return actor_loss.mean(), (q.mean(), lambda_.mean(), bc_loss.mean())
276 |
277 | loss_fn = jax.value_and_grad(_actor_loss_function, has_aux=True)
278 | (actor_loss, (q_mean, lambda_, bc_loss)), actor_grad = loss_fn(
279 | agent_state.actor.params
280 | )
281 | agent_state = agent_state._replace(
282 | actor=agent_state.actor.apply_gradients(grads=actor_grad)
283 | )
284 |
285 | # --- Update target networks ---
286 | def _update_target(state, target_state):
287 | new_target_params = optax.incremental_update(
288 | state.params, target_state.params, args.polyak_step_size
289 | )
290 | return target_state.replace(
291 | step=target_state.step + 1, params=new_target_params
292 | )
293 |
294 | agent_state = agent_state._replace(
295 | actor_target=_update_target(agent_state.actor, agent_state.actor_target),
296 | dual_q_target=_update_target(agent_state.dual_q, agent_state.dual_q_target),
297 | )
298 |
299 | loss = {
300 | "actor_loss": actor_loss,
301 | "q_loss": q_loss.mean(),
302 | "q_mean": q_mean,
303 | "lambda": lambda_,
304 | "bc_loss": bc_loss,
305 | "critic_bc_loss": critic_bc_loss.mean(),
306 | }
307 | return (rng, agent_state), loss
308 |
309 | return _train_step
310 |
311 |
312 | if __name__ == "__main__":
313 | # --- Parse arguments ---
314 | args = tyro.cli(Args)
315 | rng = jax.random.PRNGKey(args.seed)
316 |
317 | # --- Initialize logger ---
318 | if args.log:
319 | wandb.init(
320 | config=args,
321 | project=args.wandb_project,
322 | entity=args.wandb_team,
323 | group=args.wandb_group,
324 | job_type="train_agent",
325 | )
326 |
327 | # --- Initialize environment and dataset ---
328 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
329 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
330 | dataset = Transition(
331 | obs=jnp.array(dataset["observations"]),
332 | action=jnp.array(dataset["actions"]),
333 | reward=jnp.array(dataset["rewards"]),
334 | next_obs=jnp.array(dataset["next_observations"]),
335 | next_action=jnp.roll(dataset["actions"], -1, axis=0),
336 | done=jnp.array(dataset["terminals"]),
337 | )
338 |
339 | # --- Initialize agent and value networks ---
340 | num_actions = env.single_action_space.shape[0]
341 | obs_mean = dataset.obs.mean(axis=0)
342 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0)
343 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
344 | dummy_action = jnp.zeros(num_actions)
345 | actor_cls = DeterministicTanhActor
346 | actor_net = actor_cls(num_actions, obs_mean, obs_std, args.actor_ln, args.norm_obs)
347 | q_net = DualQNetwork(obs_mean, obs_std, args.critic_ln, args.norm_obs)
348 |
349 | # Target networks share seeds to match initialization
350 | rng, rng_actor, rng_q = jax.random.split(rng, 3)
351 | agent_state = AgentTrainState(
352 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
353 | actor_target=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
354 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
355 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
356 | )
357 |
358 | # --- Make train step ---
359 | _agent_train_step_fn = make_train_step(args, actor_net.apply, q_net.apply, dataset)
360 |
361 | num_evals = args.num_updates // args.eval_interval
362 | for eval_idx in range(num_evals):
363 | # --- Execute train loop ---
364 | (rng, agent_state), loss = jax.lax.scan(
365 | _agent_train_step_fn,
366 | (rng, agent_state),
367 | None,
368 | args.eval_interval,
369 | )
370 |
371 | # --- Evaluate agent ---
372 | rng, rng_eval = jax.random.split(rng)
373 | returns = eval_agent(args, rng_eval, env, agent_state)
374 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
375 |
376 | # --- Log metrics ---
377 | step = (eval_idx + 1) * args.eval_interval
378 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
379 | if args.log:
380 | log_dict = {
381 | "return": returns.mean(),
382 | "score": scores.mean(),
383 | "score_std": scores.std(),
384 | "num_updates": step,
385 | **{k: loss[k][-1] for k in loss},
386 | }
387 | wandb.log(log_dict)
388 |
389 | # --- Evaluate final agent ---
390 | if args.eval_final_episodes > 0:
391 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
392 | print(f"Evaluating final agent for {final_iters} iterations...")
393 | _rng = jax.random.split(rng, final_iters)
394 |         rets = onp.concatenate([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
395 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
396 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
397 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
398 |
399 | # --- Write final returns to file ---
400 | os.makedirs("final_returns", exist_ok=True)
401 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
402 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
403 | with open(os.path.join("final_returns", filename), "wb") as f:
404 | onp.savez_compressed(f, **info, args=asdict(args))
405 |
406 | if args.log:
407 | wandb.save(os.path.join("final_returns", filename))
408 |
409 | if args.log:
410 | wandb.finish()
411 |
--------------------------------------------------------------------------------
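
ReBRAC's main departure from TD3+BC in this file is the second behavioral-cloning penalty applied inside the critic target (critic_bc_coef), in addition to the usual actor penalty (actor_bc_coef). A small numeric sketch of the target computation in _compute_target above, with hypothetical values:

    import jax.numpy as jnp

    gamma, critic_bc_coef = 0.99, 0.01
    reward, done = 1.0, 0.0
    target_q = jnp.array([50.0, 52.0])  # hypothetical outputs of the dual target critic
    # Squared error between the noised next action and the dataset's next action
    bc_penalty = jnp.square(jnp.array([0.1, -0.2, 0.05])).sum()
    next_q = jnp.min(target_q) - critic_bc_coef * bc_penalty
    target = reward + gamma * (1.0 - done) * next_q  # regression target for both critics
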
/algorithms/sac_n.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | from functools import partial
5 | import os
6 | import warnings
7 |
8 | import distrax
9 | import d4rl
10 | import flax.linen as nn
11 | from flax.linen.initializers import constant, uniform
12 | from flax.training.train_state import TrainState
13 | import gym
14 | import jax
15 | import jax.numpy as jnp
16 | import numpy as onp
17 | import optax
18 | import tyro
19 | import wandb
20 |
21 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
22 |
23 |
24 | @dataclass
25 | class Args:
26 | # --- Experiment ---
27 | seed: int = 0
28 | dataset: str = "halfcheetah-medium-v2"
29 | algorithm: str = "sac_n"
30 | num_updates: int = 3_000_000
31 | eval_interval: int = 2500
32 | eval_workers: int = 8
33 | eval_final_episodes: int = 1000
34 | # --- Logging ---
35 | log: bool = False
36 | wandb_project: str = "unifloral"
37 | wandb_team: str = "flair"
38 | wandb_group: str = "debug"
39 | # --- Generic optimization ---
40 | lr: float = 3e-4
41 | batch_size: int = 256
42 | gamma: float = 0.99
43 | polyak_step_size: float = 0.005
44 | # --- SAC-N ---
45 | num_critics: int = 10
46 |
47 |
48 | r"""
49 | |\ __
50 | \| /_/
51 | \|
52 | ___|_____
53 | \ /
54 | \ /
55 | \___/ Preliminaries
56 | """
57 |
58 | AgentTrainState = namedtuple("AgentTrainState", "actor vec_q vec_q_target alpha")
59 | Transition = namedtuple("Transition", "obs action reward next_obs done")
60 |
61 |
62 | def sym(scale):
63 | def _init(*args, **kwargs):
64 | return uniform(2 * scale)(*args, **kwargs) - scale
65 |
66 | return _init
67 |
68 |
69 | class SoftQNetwork(nn.Module):
70 | @nn.compact
71 | def __call__(self, obs, action):
72 | x = jnp.concatenate([obs, action], axis=-1)
73 | for _ in range(3):
74 | x = nn.Dense(256, bias_init=constant(0.1))(x)
75 | x = nn.relu(x)
76 | q = nn.Dense(1, kernel_init=sym(3e-3), bias_init=sym(3e-3))(x)
77 | return q.squeeze(-1)
78 |
79 |
80 | class VectorQ(nn.Module):
81 | num_critics: int
82 |
83 | @nn.compact
84 | def __call__(self, obs, action):
85 | vmap_critic = nn.vmap(
86 | SoftQNetwork,
87 | variable_axes={"params": 0}, # Parameters not shared between critics
88 | split_rngs={"params": True, "dropout": True}, # Different initializations
89 | in_axes=None,
90 | out_axes=-1,
91 | axis_size=self.num_critics,
92 | )
93 | q_values = vmap_critic()(obs, action)
94 | return q_values
95 |
96 |
97 | class TanhGaussianActor(nn.Module):
98 | num_actions: int
99 | log_std_max: float = 2.0
100 | log_std_min: float = -5.0
101 |
102 | @nn.compact
103 | def __call__(self, x):
104 | for _ in range(3):
105 | x = nn.Dense(256, bias_init=constant(0.1))(x)
106 | x = nn.relu(x)
107 | log_std = nn.Dense(
108 | self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3)
109 | )(x)
110 | std = jnp.exp(jnp.clip(log_std, self.log_std_min, self.log_std_max))
111 | mean = nn.Dense(self.num_actions, kernel_init=sym(1e-3), bias_init=sym(1e-3))(x)
112 | pi = distrax.Transformed(
113 | distrax.Normal(mean, std),
114 | distrax.Tanh(),
115 | )
116 | return pi
117 |
118 |
119 | class EntropyCoef(nn.Module):
120 | ent_coef_init: float = 1.0
121 |
122 | @nn.compact
123 | def __call__(self) -> jnp.ndarray:
124 | log_ent_coef = self.param(
125 | "log_ent_coef",
126 | init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
127 | )
128 | return log_ent_coef
129 |
130 |
131 | def create_train_state(args, rng, network, dummy_input):
132 | return TrainState.create(
133 | apply_fn=network.apply,
134 | params=network.init(rng, *dummy_input),
135 | tx=optax.adam(args.lr, eps=1e-5),
136 | )
137 |
138 |
139 | def eval_agent(args, rng, env, agent_state):
140 | # --- Reset environment ---
141 | step = 0
142 | returned = onp.zeros(args.eval_workers).astype(bool)
143 | cum_reward = onp.zeros(args.eval_workers)
144 | rng, rng_reset = jax.random.split(rng)
145 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
146 | obs = env.reset()
147 |
148 | # --- Rollout agent ---
149 | @jax.jit
150 | @jax.vmap
151 | def _policy_step(rng, obs):
152 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
153 | action = pi.sample(seed=rng)
154 | return jnp.nan_to_num(action)
155 |
156 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
157 | while step < max_episode_steps and not returned.all():
158 | # --- Take step in environment ---
159 | step += 1
160 | rng, rng_step = jax.random.split(rng)
161 | rng_step = jax.random.split(rng_step, args.eval_workers)
162 | action = _policy_step(rng_step, jnp.array(obs))
163 | obs, reward, done, info = env.step(onp.array(action))
164 |
165 | # --- Track cumulative reward ---
166 | cum_reward += reward * ~returned
167 | returned |= done
168 |
169 | if step >= max_episode_steps and not returned.all():
170 | warnings.warn("Maximum steps reached before all episodes terminated")
171 | return cum_reward
172 |
173 |
174 | r"""
175 | __/)
176 | .-(__(=:
177 | |\ | \)
178 | \ ||
179 | \||
180 | \|
181 | ___|_____
182 | \ /
183 | \ /
184 | \___/ Agent
185 | """
186 |
187 |
188 | def make_train_step(args, actor_apply_fn, q_apply_fn, alpha_apply_fn, dataset):
189 | """Make JIT-compatible agent train step."""
190 |
191 | def _train_step(runner_state, _):
192 | rng, agent_state = runner_state
193 |
194 | # --- Sample batch ---
195 | rng, rng_batch = jax.random.split(rng)
196 | batch_indices = jax.random.randint(
197 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
198 | )
199 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
200 |
201 | # --- Update alpha ---
202 | @jax.value_and_grad
203 | def _alpha_loss_fn(params, rng):
204 | def _compute_entropy(rng, transition):
205 | pi = actor_apply_fn(agent_state.actor.params, transition.obs)
206 | _, log_pi = pi.sample_and_log_prob(seed=rng)
207 | return -log_pi.sum()
208 |
209 | log_alpha = alpha_apply_fn(params)
210 | rng = jax.random.split(rng, args.batch_size)
211 | entropy = jax.vmap(_compute_entropy)(rng, batch).mean()
212 | target_entropy = -batch.action.shape[-1]
213 | return log_alpha * (entropy - target_entropy)
214 |
215 | rng, rng_alpha = jax.random.split(rng)
216 | alpha_loss, alpha_grad = _alpha_loss_fn(agent_state.alpha.params, rng_alpha)
217 | updated_alpha = agent_state.alpha.apply_gradients(grads=alpha_grad)
218 | agent_state = agent_state._replace(alpha=updated_alpha)
219 | alpha = jnp.exp(alpha_apply_fn(agent_state.alpha.params))
220 |
221 | # --- Update actor ---
222 | @partial(jax.value_and_grad, has_aux=True)
223 | def _actor_loss_function(params, rng):
224 | def _compute_loss(rng, transition):
225 | pi = actor_apply_fn(params, transition.obs)
226 | sampled_action, log_pi = pi.sample_and_log_prob(seed=rng)
227 | log_pi = log_pi.sum()
228 | q_values = q_apply_fn(
229 | agent_state.vec_q.params, transition.obs, sampled_action
230 | )
231 | q_min = jnp.min(q_values)
232 | return -q_min + alpha * log_pi, -log_pi, q_min, q_values.std()
233 |
234 | rng = jax.random.split(rng, args.batch_size)
235 | loss, entropy, q_min, q_std = jax.vmap(_compute_loss)(rng, batch)
236 | return loss.mean(), (entropy.mean(), q_min.mean(), q_std.mean())
237 |
238 | rng, rng_actor = jax.random.split(rng)
239 | (actor_loss, (entropy, q_min, q_std)), actor_grad = _actor_loss_function(
240 | agent_state.actor.params, rng_actor
241 | )
242 | updated_actor = agent_state.actor.apply_gradients(grads=actor_grad)
243 | agent_state = agent_state._replace(actor=updated_actor)
244 |
245 | # --- Update Q target network ---
246 | updated_q_target_params = optax.incremental_update(
247 | agent_state.vec_q.params,
248 | agent_state.vec_q_target.params,
249 | args.polyak_step_size,
250 | )
251 | updated_q_target = agent_state.vec_q_target.replace(
252 | step=agent_state.vec_q_target.step + 1, params=updated_q_target_params
253 | )
254 | agent_state = agent_state._replace(vec_q_target=updated_q_target)
255 |
256 | # --- Compute targets ---
257 | def _sample_next_v(rng, transition):
258 | next_pi = actor_apply_fn(agent_state.actor.params, transition.next_obs)
259 | # Note: Important to use sample_and_log_prob here for numerical stability
260 | # See https://github.com/deepmind/distrax/issues/7
261 | next_action, log_next_pi = next_pi.sample_and_log_prob(seed=rng)
262 | # Minimum of the target Q-values
263 | next_q = q_apply_fn(
264 | agent_state.vec_q_target.params, transition.next_obs, next_action
265 | )
266 | return next_q.min(-1) - alpha * log_next_pi.sum(-1)
267 |
268 | rng, rng_next_v = jax.random.split(rng)
269 | rng_next_v = jax.random.split(rng_next_v, args.batch_size)
270 | next_v_target = jax.vmap(_sample_next_v)(rng_next_v, batch)
271 | target = batch.reward + args.gamma * (1 - batch.done) * next_v_target
272 |
273 | # --- Update critics ---
274 | @jax.value_and_grad
275 | def _q_loss_fn(params):
276 | q_pred = q_apply_fn(params, batch.obs, batch.action)
277 | return jnp.square((q_pred - jnp.expand_dims(target, -1))).sum(-1).mean()
278 |
279 | critic_loss, critic_grad = _q_loss_fn(agent_state.vec_q.params)
280 | updated_q = agent_state.vec_q.apply_gradients(grads=critic_grad)
281 | agent_state = agent_state._replace(vec_q=updated_q)
282 |
283 | loss = {
284 | "critic_loss": critic_loss,
285 | "actor_loss": actor_loss,
286 | "alpha_loss": alpha_loss,
287 | "entropy": entropy,
288 | "alpha": alpha,
289 | "q_min": q_min,
290 | "q_std": q_std,
291 | }
292 | return (rng, agent_state), loss
293 |
294 | return _train_step
295 |
296 |
297 | if __name__ == "__main__":
298 | # --- Parse arguments ---
299 | args = tyro.cli(Args)
300 | rng = jax.random.PRNGKey(args.seed)
301 |
302 | # --- Initialize logger ---
303 | if args.log:
304 | wandb.init(
305 | config=args,
306 | project=args.wandb_project,
307 | entity=args.wandb_team,
308 | group=args.wandb_group,
309 | job_type="train_agent",
310 | )
311 |
312 | # --- Initialize environment and dataset ---
313 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
314 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
315 | dataset = Transition(
316 | obs=jnp.array(dataset["observations"]),
317 | action=jnp.array(dataset["actions"]),
318 | reward=jnp.array(dataset["rewards"]),
319 | next_obs=jnp.array(dataset["next_observations"]),
320 | done=jnp.array(dataset["terminals"]),
321 | )
322 |
323 | # --- Initialize agent and value networks ---
324 | num_actions = env.single_action_space.shape[0]
325 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
326 | dummy_action = jnp.zeros(num_actions)
327 | actor_net = TanhGaussianActor(num_actions)
328 | q_net = VectorQ(args.num_critics)
329 | alpha_net = EntropyCoef()
330 |
331 | # Target networks share seeds to match initialization
332 | rng, rng_actor, rng_q, rng_alpha = jax.random.split(rng, 4)
333 | agent_state = AgentTrainState(
334 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
335 | vec_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
336 | vec_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
337 | alpha=create_train_state(args, rng_alpha, alpha_net, []),
338 | )
339 |
340 | # --- Make train step ---
341 | _agent_train_step_fn = make_train_step(
342 | args, actor_net.apply, q_net.apply, alpha_net.apply, dataset
343 | )
344 |
345 | num_evals = args.num_updates // args.eval_interval
346 | for eval_idx in range(num_evals):
347 | # --- Execute train loop ---
348 | (rng, agent_state), loss = jax.lax.scan(
349 | _agent_train_step_fn,
350 | (rng, agent_state),
351 | None,
352 | args.eval_interval,
353 | )
354 |
355 | # --- Evaluate agent ---
356 | rng, rng_eval = jax.random.split(rng)
357 | returns = eval_agent(args, rng_eval, env, agent_state)
358 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
359 |
360 | # --- Log metrics ---
361 | step = (eval_idx + 1) * args.eval_interval
362 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
363 | if args.log:
364 | log_dict = {
365 | "return": returns.mean(),
366 | "score": scores.mean(),
367 | "score_std": scores.std(),
368 | "num_updates": step,
369 | **{k: loss[k][-1] for k in loss},
370 | }
371 | wandb.log(log_dict)
372 |
373 | # --- Evaluate final agent ---
374 | if args.eval_final_episodes > 0:
375 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
376 | print(f"Evaluating final agent for {final_iters} iterations...")
377 | _rng = jax.random.split(rng, final_iters)
378 |         rets = onp.concatenate([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
379 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
380 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
381 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
382 |
383 | # --- Write final returns to file ---
384 | os.makedirs("final_returns", exist_ok=True)
385 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
386 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
387 | with open(os.path.join("final_returns", filename), "wb") as f:
388 | onp.savez_compressed(f, **info, args=asdict(args))
389 |
390 | if args.log:
391 | wandb.save(os.path.join("final_returns", filename))
392 |
393 | if args.log:
394 | wandb.finish()
395 |
--------------------------------------------------------------------------------
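
The actor above is a tanh-squashed Gaussian, and the in-code note stresses using sample_and_log_prob rather than sampling and then calling log_prob, since the latter inverts the tanh and is unstable for actions saturated near +/-1. A minimal sketch of the distrax pattern, with per-dimension log-probabilities summed as in the actor and target computations:

    import distrax
    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    mean, std = jnp.zeros(3), jnp.ones(3)  # illustrative 3-dimensional action space
    pi = distrax.Transformed(distrax.Normal(mean, std), distrax.Tanh())
    # Evaluates the density at the pre-tanh sample, avoiding atanh of values near +/-1
    action, log_prob = pi.sample_and_log_prob(seed=rng)
    total_log_prob = log_prob.sum()  # joint log-probability over action dimensions
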
/algorithms/td3_bc.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from dataclasses import dataclass, asdict
3 | from datetime import datetime
4 | import os
5 | import warnings
6 |
7 | import distrax
8 | import d4rl
9 | import flax.linen as nn
10 | from flax.training.train_state import TrainState
11 | import gym
12 | import jax
13 | import jax.numpy as jnp
14 | import numpy as onp
15 | import optax
16 | import tyro
17 | import wandb
18 |
19 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
20 |
21 |
22 | @dataclass
23 | class Args:
24 | # --- Experiment ---
25 | seed: int = 0
26 | dataset: str = "halfcheetah-medium-v2"
27 | algorithm: str = "td3_bc"
28 | num_updates: int = 1_000_000
29 | eval_interval: int = 2500
30 | eval_workers: int = 8
31 | eval_final_episodes: int = 1000
32 | # --- Logging ---
33 | log: bool = False
34 | wandb_project: str = "unifloral"
35 | wandb_team: str = "flair"
36 | wandb_group: str = "debug"
37 | # --- Generic optimization ---
38 | lr: float = 3e-4
39 | batch_size: int = 256
40 | gamma: float = 0.99
41 | polyak_step_size: float = 0.005
42 | # --- TD3+BC ---
43 | td3_alpha: float = 2.5
44 | noise_clip: float = 0.5
45 | policy_noise: float = 0.2
46 | num_critic_updates_per_step: int = 2
47 |
48 |
49 | r"""
50 | |\ __
51 | \| /_/
52 | \|
53 | ___|_____
54 | \ /
55 | \ /
56 | \___/ Preliminaries
57 | """
58 |
59 | AgentTrainState = namedtuple(
60 | "AgentTrainState", "actor actor_target dual_q dual_q_target"
61 | )
62 | Transition = namedtuple("Transition", "obs action reward next_obs done")
63 |
64 |
65 | class SoftQNetwork(nn.Module):
66 | obs_mean: jax.Array
67 | obs_std: jax.Array
68 |
69 | @nn.compact
70 | def __call__(self, obs, action):
71 | obs = (obs - self.obs_mean) / (self.obs_std + 1e-3)
72 | x = jnp.concatenate([obs, action], axis=-1)
73 | for _ in range(2):
74 | x = nn.Dense(256)(x)
75 | x = nn.relu(x)
76 | q = nn.Dense(1)(x)
77 | return q.squeeze(-1)
78 |
79 |
80 | class DualQNetwork(nn.Module):
81 | obs_mean: jax.Array
82 | obs_std: jax.Array
83 |
84 | @nn.compact
85 | def __call__(self, obs, action):
86 | vmap_critic = nn.vmap(
87 | SoftQNetwork,
88 | variable_axes={"params": 0}, # Parameters not shared between critics
89 | split_rngs={"params": True, "dropout": True}, # Different initializations
90 | in_axes=None,
91 | out_axes=-1,
92 | axis_size=2, # Two Q networks
93 | )
94 | q_values = vmap_critic(self.obs_mean, self.obs_std)(obs, action)
95 | return q_values
96 |
97 |
98 | class DeterministicTanhActor(nn.Module):
99 | num_actions: int
100 | obs_mean: jax.Array
101 | obs_std: jax.Array
102 |
103 | @nn.compact
104 | def __call__(self, x):
105 | x = (x - self.obs_mean) / (self.obs_std + 1e-3)
106 | for _ in range(2):
107 | x = nn.Dense(256)(x)
108 | x = nn.relu(x)
109 | action = nn.Dense(self.num_actions)(x)
110 | pi = distrax.Transformed(
111 | distrax.Deterministic(action),
112 | distrax.Tanh(),
113 | )
114 | return pi
115 |
116 |
117 | def create_train_state(args, rng, network, dummy_input):
118 | return TrainState.create(
119 | apply_fn=network.apply,
120 | params=network.init(rng, *dummy_input),
121 | tx=optax.adam(args.lr, eps=1e-5),
122 | )
123 |
124 |
125 | def eval_agent(args, rng, env, agent_state):
126 | # --- Reset environment ---
127 | step = 0
128 | returned = onp.zeros(args.eval_workers).astype(bool)
129 | cum_reward = onp.zeros(args.eval_workers)
130 | rng, rng_reset = jax.random.split(rng)
131 | rng_reset = jax.random.split(rng_reset, args.eval_workers)
132 | obs = env.reset()
133 |
134 | # --- Rollout agent ---
135 | @jax.jit
136 | @jax.vmap
137 | def _policy_step(rng, obs):
138 | pi = agent_state.actor.apply_fn(agent_state.actor.params, obs)
139 | action = pi.sample(seed=rng)
140 | return jnp.nan_to_num(action)
141 |
142 | max_episode_steps = env.env_fns[0]().spec.max_episode_steps
143 | while step < max_episode_steps and not returned.all():
144 | # --- Take step in environment ---
145 | step += 1
146 | rng, rng_step = jax.random.split(rng)
147 | rng_step = jax.random.split(rng_step, args.eval_workers)
148 | action = _policy_step(rng_step, jnp.array(obs))
149 | obs, reward, done, info = env.step(onp.array(action))
150 |
151 | # --- Track cumulative reward ---
152 | cum_reward += reward * ~returned
153 | returned |= done
154 |
155 | if step >= max_episode_steps and not returned.all():
156 | warnings.warn("Maximum steps reached before all episodes terminated")
157 | return cum_reward
158 |
159 |
160 | r"""
161 | __/)
162 | .-(__(=:
163 | |\ | \)
164 | \ ||
165 | \||
166 | \|
167 | ___|_____
168 | \ /
169 | \ /
170 | \___/ Agent
171 | """
172 |
173 |
174 | def make_train_step(args, actor_apply_fn, q_apply_fn, dataset):
175 | """Make JIT-compatible agent train step."""
176 |
177 | def _train_step(runner_state, _):
178 | rng, agent_state = runner_state
179 |
180 | # --- Sample batch ---
181 | rng, rng_batch = jax.random.split(rng)
182 | batch_indices = jax.random.randint(
183 | rng_batch, (args.batch_size,), 0, len(dataset.obs)
184 | )
185 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)
186 |
187 | # --- Update critics ---
188 | def _update_critics(runner_state, _):
189 | rng, agent_state = runner_state
190 |
191 | def _compute_target(rng, transition):
192 | next_obs = transition.next_obs
193 |
194 | # --- Sample noised action ---
195 | next_pi = actor_apply_fn(agent_state.actor_target.params, next_obs)
196 | rng, rng_action, rng_noise = jax.random.split(rng, 3)
197 | action = next_pi.sample(seed=rng_action)
198 | noise = jax.random.normal(rng_noise, shape=action.shape)
199 | noise *= args.policy_noise
200 | noise = jnp.clip(noise, -args.noise_clip, args.noise_clip)
201 | action = jnp.clip(action + noise, -1, 1)
202 |
203 | # --- Compute targets ---
204 | target_q = q_apply_fn(
205 | agent_state.dual_q_target.params, next_obs, action
206 | )
207 | next_q_value = (1.0 - transition.done) * jnp.min(target_q)
208 | return transition.reward + args.gamma * next_q_value
209 |
210 | rng, rng_targets = jax.random.split(rng)
211 | rng_targets = jax.random.split(rng_targets, args.batch_size)
212 | targets = jax.vmap(_compute_target)(rng_targets, batch)
213 |
214 | # --- Compute critic loss ---
215 | @jax.value_and_grad
216 | def _q_loss_fn(params):
217 | q_pred = q_apply_fn(params, batch.obs, batch.action)
218 | return jnp.square(q_pred - jnp.expand_dims(targets, axis=-1)).mean()
219 |
220 | q_loss, q_grad = _q_loss_fn(agent_state.dual_q.params)
221 | updated_q_state = agent_state.dual_q.apply_gradients(grads=q_grad)
222 | agent_state = agent_state._replace(dual_q=updated_q_state)
223 | return (rng, agent_state), q_loss
224 |
225 | # --- Iterate critic update ---
226 | (rng, agent_state), q_loss = jax.lax.scan(
227 | _update_critics,
228 | (rng, agent_state),
229 | None,
230 | length=args.num_critic_updates_per_step,
231 | )
232 |
233 | # --- Update actor ---
234 | def _actor_loss_function(params):
235 | def _transition_loss(transition):
236 | pi = actor_apply_fn(params, transition.obs)
237 | pi_action = pi.sample(seed=None)
238 | q = q_apply_fn(agent_state.dual_q.params, transition.obs, pi_action)
239 | bc_loss = jnp.square(pi_action - transition.action).mean()
240 | return q[0], bc_loss
241 |
242 | q, bc_loss = jax.vmap(_transition_loss)(batch)
243 | lambda_ = args.td3_alpha / (jnp.abs(q).mean() + 1e-7)
244 | lambda_ = jax.lax.stop_gradient(lambda_)
245 | actor_loss = -lambda_ * q.mean() + bc_loss.mean()
246 | return actor_loss.mean(), (q.mean(), lambda_.mean(), bc_loss.mean())
247 |
248 | loss_fn = jax.value_and_grad(_actor_loss_function, has_aux=True)
249 | (actor_loss, (q_mean, lambda_, bc_loss)), actor_grad = loss_fn(
250 | agent_state.actor.params
251 | )
252 | agent_state = agent_state._replace(
253 | actor=agent_state.actor.apply_gradients(grads=actor_grad)
254 | )
255 |
256 | # --- Update target networks ---
257 | def _update_target(state, target_state):
258 | new_target_params = optax.incremental_update(
259 | state.params, target_state.params, args.polyak_step_size
260 | )
261 | return target_state.replace(
262 | step=target_state.step + 1, params=new_target_params
263 | )
264 |
265 | agent_state = agent_state._replace(
266 | actor_target=_update_target(agent_state.actor, agent_state.actor_target),
267 | dual_q_target=_update_target(agent_state.dual_q, agent_state.dual_q_target),
268 | )
269 |
270 | loss = {
271 | "actor_loss": actor_loss,
272 | "q_loss": q_loss.mean(),
273 | "q_mean": q_mean,
274 | "lambda": lambda_,
275 | "bc_loss": bc_loss,
276 | }
277 | return (rng, agent_state), loss
278 |
279 | return _train_step
280 |
281 |
282 | if __name__ == "__main__":
283 | # --- Parse arguments ---
284 | args = tyro.cli(Args)
285 | rng = jax.random.PRNGKey(args.seed)
286 |
287 | # --- Initialize logger ---
288 | if args.log:
289 | wandb.init(
290 | config=args,
291 | project=args.wandb_project,
292 | entity=args.wandb_team,
293 | group=args.wandb_group,
294 | job_type="train_agent",
295 | )
296 |
297 | # --- Initialize environment and dataset ---
298 | env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
299 | dataset = d4rl.qlearning_dataset(gym.make(args.dataset))
300 | dataset = Transition(
301 | obs=jnp.array(dataset["observations"]),
302 | action=jnp.array(dataset["actions"]),
303 | reward=jnp.array(dataset["rewards"]),
304 | next_obs=jnp.array(dataset["next_observations"]),
305 | done=jnp.array(dataset["terminals"]),
306 | )
307 |
308 | # --- Initialize agent and value networks ---
309 | num_actions = env.single_action_space.shape[0]
310 | obs_mean = dataset.obs.mean(axis=0)
311 | obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0)
312 | dummy_obs = jnp.zeros(env.single_observation_space.shape)
313 | dummy_action = jnp.zeros(num_actions)
314 | actor_net = DeterministicTanhActor(num_actions, obs_mean, obs_std)
315 | q_net = DualQNetwork(obs_mean, obs_std)
316 |
317 | # Target networks share seeds to match initialization
318 | rng, rng_actor, rng_q = jax.random.split(rng, 3)
319 | agent_state = AgentTrainState(
320 | actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
321 | actor_target=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
322 | dual_q=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
323 | dual_q_target=create_train_state(args, rng_q, q_net, [dummy_obs, dummy_action]),
324 | )
325 |
326 | # --- Make train step ---
327 | _agent_train_step_fn = make_train_step(args, actor_net.apply, q_net.apply, dataset)
328 |
329 | num_evals = args.num_updates // args.eval_interval
330 | for eval_idx in range(num_evals):
331 | # --- Execute train loop ---
332 | (rng, agent_state), loss = jax.lax.scan(
333 | _agent_train_step_fn,
334 | (rng, agent_state),
335 | None,
336 | args.eval_interval,
337 | )
338 |
339 | # --- Evaluate agent ---
340 | rng, rng_eval = jax.random.split(rng)
341 | returns = eval_agent(args, rng_eval, env, agent_state)
342 | scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0
343 |
344 | # --- Log metrics ---
345 | step = (eval_idx + 1) * args.eval_interval
346 | print("Step:", step, f"\t Score: {scores.mean():.2f}")
347 | if args.log:
348 | log_dict = {
349 | "return": returns.mean(),
350 | "score": scores.mean(),
351 | "score_std": scores.std(),
352 | "num_updates": step,
353 | **{k: loss[k][-1] for k in loss},
354 | }
355 | wandb.log(log_dict)
356 |
357 | # --- Evaluate final agent ---
358 | if args.eval_final_episodes > 0:
359 | final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
360 | print(f"Evaluating final agent for {final_iters} iterations...")
361 | _rng = jax.random.split(rng, final_iters)
362 |         rets = onp.concatenate([eval_agent(args, _rng, env, agent_state) for _rng in _rng])
363 | scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0
364 | agg_fn = lambda x, k: {k: x, f"{k}_mean": x.mean(), f"{k}_std": x.std()}
365 | info = agg_fn(rets, "final_returns") | agg_fn(scores, "final_scores")
366 |
367 | # --- Write final returns to file ---
368 | os.makedirs("final_returns", exist_ok=True)
369 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
370 | filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
371 | with open(os.path.join("final_returns", filename), "wb") as f:
372 | onp.savez_compressed(f, **info, args=asdict(args))
373 |
374 | if args.log:
375 | wandb.save(os.path.join("final_returns", filename))
376 |
377 | if args.log:
378 | wandb.finish()
379 |
--------------------------------------------------------------------------------
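
The actor update above rescales the Q term by lambda = td3_alpha / mean|Q| so that the behavioral-cloning penalty stays on a comparable scale regardless of the dataset's value magnitudes. A small numeric sketch with hypothetical batch statistics:

    import jax.numpy as jnp

    q = jnp.array([95.0, 105.0, 100.0])  # hypothetical Q-values over a batch
    bc = jnp.array([0.02, 0.01, 0.03])   # hypothetical per-transition BC errors
    td3_alpha = 2.5
    lam = td3_alpha / (jnp.abs(q).mean() + 1e-7)  # 2.5 / 100 = 0.025
    actor_loss = -lam * q.mean() + bc.mean()      # Q term ~ -2.5, BC term ~ 0.02
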
/algorithms/termination_fns.py:
--------------------------------------------------------------------------------
1 | """
2 | The world models we use in model-based RL don't predict termination, so we need to
3 | define termination functions for each task.
4 | This code is adapted from
5 | https://github.com/yihaosun1124/OfflineRL-Kit/blob/6e578d13568fa934096baa2ca96e38e1fa44a233/offlinerlkit/utils/termination_fns.py#L123
6 | Thanks to the authors!
7 | """
8 |
9 | import jax.numpy as jnp
10 |
11 |
12 | def termination_fn_halfcheetah(obs, act, next_obs):
13 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
14 |
15 | not_done = jnp.logical_and(jnp.all(next_obs > -100), jnp.all(next_obs < 100))
16 | done = ~not_done
17 | return done
18 |
19 |
20 | def termination_fn_hopper(obs, act, next_obs):
21 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
22 |
23 | height = next_obs[0]
24 | angle = next_obs[1]
25 | not_done = (
26 | jnp.isfinite(next_obs).all()
27 |         * (jnp.abs(next_obs[1:]) < 100).all()
28 | * (height > 0.7)
29 | * (jnp.abs(angle) < 0.2)
30 | )
31 |
32 | done = ~not_done
33 | return done
34 |
35 |
36 | def termination_fn_halfcheetahveljump(obs, act, next_obs):
37 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
38 |
39 | done = jnp.array(False)
40 | return done
41 |
42 |
43 | def termination_fn_antangle(obs, act, next_obs):
44 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
45 |
46 | x = next_obs[0]
47 | not_done = jnp.isfinite(next_obs).all() * (x >= 0.2) * (x <= 1.0)
48 |
49 | done = ~not_done
50 | return done
51 |
52 |
53 | def termination_fn_ant(obs, act, next_obs):
54 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
55 |
56 | x = next_obs[0]
57 | not_done = jnp.isfinite(next_obs).all() * (x >= 0.2) * (x <= 1.0)
58 |
59 | done = ~not_done
60 | return done
61 |
62 |
63 | def termination_fn_walker2d(obs, act, next_obs):
64 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
65 |
66 | height = next_obs[0]
67 | angle = next_obs[1]
68 | not_done = (
69 | jnp.logical_and(jnp.all(next_obs > -100), jnp.all(next_obs < 100))
70 | * (height > 0.8)
71 | * (height < 2.0)
72 | * (angle > -1.0)
73 | * (angle < 1.0)
74 | )
75 | done = ~not_done
76 | return done
77 |
78 |
79 | def termination_fn_point2denv(obs, act, next_obs):
80 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
81 |
82 | done = jnp.array(False)
83 | return done
84 |
85 |
86 | def termination_fn_point2dwallenv(obs, act, next_obs):
87 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
88 |
89 | done = jnp.array(False)
90 | return done
91 |
92 |
93 | def termination_fn_pendulum(obs, act, next_obs):
94 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
95 |
96 | done = jnp.array(False)
97 | return done
98 |
99 |
100 | def termination_fn_humanoid(obs, act, next_obs):
101 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
102 |
103 | z = next_obs[0]
104 | done = (z < 1.0) + (z > 2.0)
105 |
106 | return done
107 |
108 |
109 | def termination_fn_pen(obs, act, next_obs):
110 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
111 |
112 | obj_pos = next_obs[24:27]
113 | done = obj_pos[2] < 0.075
114 |
115 | return done
116 |
117 |
118 | def termination_fn_door(obs, act, next_obs):
119 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
120 |
121 | done = jnp.array(False)
122 |
123 | return done
124 |
125 |
126 | def termination_fn_relocate(obs, act, next_obs):
127 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
128 |
129 | done = jnp.array(False)
130 |
131 | return done
132 |
133 |
134 | def maze2d_open_termination_fn(obs, act, next_obs):
135 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
136 |
137 | agent_location = jnp.array([obs[0], obs[1]])
138 | goal_location = jnp.array([2, 3])
139 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5
140 | return done
141 |
142 |
143 | def maze2d_umaze_termination_fn(obs, act, next_obs):
144 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
145 |
146 | agent_location = jnp.array([obs[0], obs[1]])
147 | goal_location = jnp.array([1, 1])
148 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5
149 | return done
150 |
151 |
152 | def maze2d_medium_termination_fn(obs, act, next_obs):
153 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
154 |
155 | agent_location = jnp.array([obs[0], obs[1]])
156 | goal_location = jnp.array([6, 6])
157 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5
158 | return done
159 |
160 |
161 | def maze2d_large_termination_fn(obs, act, next_obs):
162 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
163 |
164 | agent_location = jnp.array([obs[0], obs[1]])
165 | goal_location = jnp.array([7, 9])
166 | done = jnp.linalg.norm(agent_location - goal_location) < 0.5
167 | return done
168 |
169 |
170 | def termination_fn_kitchen(obs, act, next_obs):
171 | assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 1
172 |
173 |     # Implementing a termination function is tricky, since it's unclear how it
174 |     # works in the original code (see d4rl/kitchen/kitchen_envs.py).
175 |     # Never terminating also works; gym even exposes an argument for this.
176 | done = jnp.array(False)
177 | return done
178 |
179 |
180 | def get_termination_fn(task):
181 | if "halfcheetahvel" in task:
182 | return termination_fn_halfcheetahveljump
183 | elif "halfcheetah" in task:
184 | return termination_fn_halfcheetah
185 | elif "hopper" in task:
186 | return termination_fn_hopper
187 | elif "antangle" in task:
188 | return termination_fn_antangle
189 | elif "ant" in task:
190 | return termination_fn_ant
191 | elif "walker2d" in task:
192 | return termination_fn_walker2d
193 | elif "point2denv" in task:
194 | return termination_fn_point2denv
195 | elif "point2dwallenv" in task:
196 | return termination_fn_point2dwallenv
197 | elif "pendulum" in task:
198 | return termination_fn_pendulum
199 | elif "humanoid" in task:
200 | return termination_fn_humanoid
201 | elif "maze2d-open" in task:
202 | return maze2d_open_termination_fn
203 | elif "maze2d-umaze" in task:
204 | return maze2d_umaze_termination_fn
205 | elif "maze2d-medium" in task:
206 | return maze2d_medium_termination_fn
207 | elif "maze2d-large" in task:
208 | return maze2d_large_termination_fn
209 | elif "pen" in task:
210 | return termination_fn_pen
211 | elif "door" in task:
212 |         return termination_fn_door
213 | elif "relocate" in task:
214 | return termination_fn_relocate
215 | elif "kitchen" in task:
216 | return termination_fn_kitchen
217 | else:
218 | raise ValueError(f"Unknown task: {task}")
219 |
--------------------------------------------------------------------------------
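
The termination functions above operate on single (unbatched) transitions, so model rollouts need to wrap them with jax.vmap. A minimal sketch, assuming the module is importable from the working directory (e.g. when running from algorithms/) and using Hopper's observation and action sizes:

    import jax
    import jax.numpy as jnp
    from termination_fns import get_termination_fn

    term_fn = get_termination_fn("hopper-medium-v2")
    batched_term_fn = jax.vmap(term_fn)  # maps over a batch of transitions

    obs = jnp.zeros((8, 11))       # hypothetical batch of Hopper observations
    act = jnp.zeros((8, 3))
    next_obs = jnp.zeros((8, 11))
    done = batched_term_fn(obs, act, next_obs)  # boolean array of shape (8,)
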
/configs/algorithms/bc.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: BC
8 | program: algorithms/bc.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: bc
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | batch_size:
43 | value: 256
44 |
--------------------------------------------------------------------------------
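
Each file under configs/algorithms/ is a wandb sweep definition (command, program, method, and the parameter grid). A minimal sketch of registering and running one from Python, assuming the standard wandb sweep workflow; the CLI equivalent is "wandb sweep configs/algorithms/bc.yaml" followed by "wandb agent <sweep_id>":

    import wandb
    import yaml

    with open("configs/algorithms/bc.yaml") as f:
        sweep_config = yaml.safe_load(f)

    sweep_id = wandb.sweep(sweep_config, entity="flair", project="unifloral")
    wandb.agent(sweep_id, count=1)  # launches the command defined in the sweep config
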
/configs/algorithms/combo.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: COMBO
8 | program: algorithms/combo.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | value: 0 # Fixed seed since we're sweeping over models
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: combo
20 | num_updates:
21 | value: 3000000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | values:
42 | - 1e-4
43 | - 3e-4
44 | batch_size:
45 | value: 256
46 | gamma:
47 | value: 0.99
48 | polyak_step_size:
49 | value: 0.005
50 |
51 | # --- Model-based specific ---
52 | model_path:
53 | values: # This will be replaced with actual model paths for each environment
54 | - "PLACEHOLDER_MODEL_PATH"
55 | model_retain_epochs:
56 | value: 5
57 | num_critics:
58 | value: 10
59 | rollout_batch_size:
60 | value: 50000
61 | rollout_interval:
62 | value: 1000
63 | rollout_length:
64 | values: [1, 5, 25]
65 | dataset_sample_ratio:
66 | values: [0.5, 0.8]
67 |
68 | # --- CQL ---
69 | actor_lr:
70 | values: [1e-5, 3e-5, 1e-4]
71 | cql_temperature:
72 | value: 1.0
73 | cql_min_q_weight:
74 | values: [0.5, 1.0, 5.0]
75 |
--------------------------------------------------------------------------------
/configs/algorithms/cql.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: CQL
8 | program: algorithms/cql.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: cql
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | values: [1e-4, 3e-4]
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- SAC-N ---
50 | num_critics:
51 | value: 10
52 |
53 | # --- CQL ---
54 | actor_lr:
55 | values: [3e-5, 1e-4, 3e-4]
56 | cql_temperature:
57 | value: 1.0
58 | cql_min_q_weight:
59 | value: 10.0
60 |
--------------------------------------------------------------------------------
/configs/algorithms/edac.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: EDAC
8 | program: algorithms/edac.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: edac
20 | num_updates:
21 | value: 3000000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 0.0003
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- SAC-N ---
50 | num_critics:
51 | values: [10, 20, 50] # 100 has high GPU memory usage
52 |
53 | # --- EDAC ---
54 | eta:
55 | values: [0.0, 1.0, 5.0, 10.0, 100.0, 200.0, 500.0, 1000.0]
56 |
--------------------------------------------------------------------------------
/configs/algorithms/iql.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: IQL
8 | program: algorithms/iql.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: iql
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 0.0003
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- IQL specific ---
50 | beta:
51 | values: [0.5, 3.0, 10.0]
52 | iql_tau:
53 | values: [0.5, 0.7, 0.9]
54 | exp_adv_clip:
55 | value: 100.0
56 |
--------------------------------------------------------------------------------
/configs/algorithms/mopo.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: MOPO
8 | program: algorithms/mopo.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | value: 0 # Fixed seed since we're sweeping over models
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: mopo
20 | num_updates:
21 | value: 3000000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-4
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- Model-based specific ---
50 | model_path:
51 | values: # This should be replaced with actual model paths for each environment
52 | - "PLACEHOLDER_MODEL_PATH"
53 | model_retain_epochs:
54 | value: 5
55 | num_critics:
56 | value: 10
57 | rollout_batch_size:
58 | value: 50000
59 | rollout_interval:
60 | value: 1000
61 | rollout_length:
62 | values: [1, 5]
63 | dataset_sample_ratio:
64 | value: 0.05
65 |
66 | # --- MOPO specific ---
67 | step_penalty_coef:
68 | values: [1.0, 5.0]
69 |
--------------------------------------------------------------------------------
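step_penalty_coef above is MOPO's pessimism knob. A sketch of the assumed reward shaping applied to synthetic rollouts (the exact uncertainty measure used in mopo.py may differ):

import jax.numpy as jnp

def penalised_reward(reward, ensemble_std, step_penalty_coef):
    # ensemble_std: per-dimension predictive std from the dynamics ensemble, shape (batch, obs_dim)
    uncertainty = jnp.linalg.norm(ensemble_std, axis=-1)
    return reward - step_penalty_coef * uncertainty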
/configs/algorithms/morel.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: MOReL
8 | program: algorithms/morel.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | value: 0 # Fixed seed since we're sweeping over models
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: morel
20 | num_updates:
21 | value: 3000000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-4
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- Model-based specific ---
50 | model_path:
51 | values: # This should be replaced with actual model paths for each environment
52 | - "PLACEHOLDER_MODEL_PATH"
53 | model_retain_epochs:
54 | value: 5
55 | num_critics:
56 | value: 10
57 | rollout_batch_size:
58 | value: 50000
59 | rollout_interval:
60 | value: 1000
61 | rollout_length:
62 | value: 5
63 | dataset_sample_ratio:
64 | value: 0.01
65 |
66 | # --- MOREL specific ---
67 | threshold_coef:
68 | values: [0, 5, 10, 15, 20, 25]
69 | term_penalty_offset:
70 | values: [-30, -50, -100, -200]
71 |
--------------------------------------------------------------------------------
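threshold_coef and term_penalty_offset above implement MOReL-style pessimistic termination. A sketch of the assumed mechanics, in which rollout steps with high ensemble disagreement are terminated and penalised (the exact thresholding in morel.py may differ):

import jax.numpy as jnp

def morel_termination(reward, disagreement, threshold, term_penalty_offset):
    # disagreement: per-step dynamics-ensemble disagreement, shape (batch,)
    unknown = disagreement > threshold
    penalised = jnp.where(unknown, reward + term_penalty_offset, reward)
    return penalised, unknown  # (shaped reward, done flag)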
/configs/algorithms/rebrac.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: ReBRAC
8 | program: algorithms/rebrac.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: rebrac
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-3
42 | batch_size:
43 | value: 1024
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- TD3+BC ---
50 | noise_clip:
51 | value: 0.5
52 | policy_noise:
53 | value: 0.2
54 | num_critic_updates_per_step:
55 | value: 2
56 |
57 | # --- REBRAC ---
58 | critic_bc_coef:
59 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1]
60 | actor_bc_coef:
61 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0]
62 | actor_ln:
63 | value: false
64 | critic_ln:
65 | value: true
66 | norm_obs:
67 | value: false
68 |
--------------------------------------------------------------------------------
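actor_bc_coef and critic_bc_coef above are the two behaviour-cloning penalties that define ReBRAC. A sketch following the ReBRAC paper rather than rebrac.py itself:

import jax.numpy as jnp

def rebrac_actor_loss(q_pi, pi_action, data_action, actor_bc_coef):
    bc = jnp.sum((pi_action - data_action) ** 2, axis=-1)
    return (-q_pi + actor_bc_coef * bc).mean()

def rebrac_critic_target(reward, gamma, done, q_next, pi_next_action, next_data_action, critic_bc_coef):
    bc = jnp.sum((pi_next_action - next_data_action) ** 2, axis=-1)
    return reward + gamma * (1.0 - done) * (q_next - critic_bc_coef * bc)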
/configs/algorithms/sac_n.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: SAC-N
8 | program: algorithms/sac_n.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: sac_n
20 | num_updates:
21 | value: 3000000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 0.0003
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- SAC-N ---
50 | num_critics:
51 | values: [5, 10, 20, 50, 100, 200] # 500 and 1000 have high GPU memory usage
52 |
--------------------------------------------------------------------------------
/configs/algorithms/td3_bc.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: TD3+BC
8 | program: algorithms/td3_bc.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: td3_bc
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | batch_size:
43 | value: 256
44 | gamma:
45 | value: 0.99
46 | polyak_step_size:
47 | value: 0.005
48 |
49 | # --- TD3+BC ---
50 | td3_alpha:
51 | values: [1.0, 2.0, 2.5, 3.0, 4.0]
52 | noise_clip:
53 | value: 0.5
54 | policy_noise:
55 | value: 0.2
56 | num_critic_updates_per_step:
57 | value: 2
58 |
--------------------------------------------------------------------------------
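td3_alpha above is the TD3+BC trade-off coefficient. A sketch assuming the formulation from the TD3+BC paper, where the Q term is rescaled by alpha / mean|Q| before the behaviour-cloning term is added:

import jax
import jax.numpy as jnp

def td3_bc_actor_loss(q_pi, pi_action, data_action, td3_alpha):
    lam = td3_alpha / (jax.lax.stop_gradient(jnp.abs(q_pi).mean()) + 1e-8)
    bc = jnp.mean(jnp.sum((pi_action - data_action) ** 2, axis=-1))
    return -lam * q_pi.mean() + bc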
/configs/dynamics.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: Dynamics-Model
8 | program: algorithms/dynamics.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: dynamics
20 | eval_interval:
21 | value: 10000
22 | log:
23 | value: true
24 | wandb_project:
25 | value: unifloral
26 | wandb_team:
27 | value: flair
28 | wandb_group:
29 | value: debug
30 |
31 | # --- Generic optimization ---
32 | lr:
33 | value: 0.001
34 | batch_size:
35 | value: 256
36 |
37 | # --- Dynamics ---
38 | n_layers:
39 | value: 4
40 | layer_size:
41 | value: 200
42 | num_ensemble:
43 | value: 7
44 | num_elites:
45 | value: 5
46 | num_epochs:
47 | value: 400
48 | logvar_diff_coef:
49 | value: 0.01
50 | weight_decay:
51 | value: 2.5e-5
52 | validation_split:
53 | value: 0.2
54 | precompute_term_stats:
55 | value: true
56 |
--------------------------------------------------------------------------------
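num_ensemble, num_elites and validation_split above follow the usual probabilistic-ensemble recipe: train every member, then keep the members with the lowest held-out loss for rollouts. A sketch of the assumed elite selection (dynamics.py may select differently):

import numpy as np

def select_elites(validation_losses, num_elites):
    # validation_losses: shape (num_ensemble,), one held-out loss per ensemble member
    return np.argsort(validation_losses)[:num_elites]

elites = select_elites(np.array([0.9, 0.3, 0.5, 0.4, 0.8, 0.2, 0.7]), num_elites=5)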
/configs/unifloral/bc.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: BC (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | actor_lr:
43 | value: 3e-4
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 256
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: true
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 2
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: false
62 | deterministic:
63 | value: true
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 2
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: false
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | value: 1.0
88 | actor_q_coef:
89 | value: 0.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: false
94 | aggregate_q:
95 | value: first
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 1
108 | critic_bc_coef:
109 | value: 0.0
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.0
114 | noise_clip:
115 | value: 0.0
116 | use_target_actor:
117 | value: false
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
--------------------------------------------------------------------------------
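The unified configs in this directory all drive algorithms/unifloral.py and differ only in these flags. A sketch of how the actor-loss flags are assumed to compose, under which the BC config above (actor_bc_coef=1, actor_q_coef=0, use_awr=false) reduces to plain behaviour cloning; this is a reading of the flag names, not a copy of unifloral.py:

import jax.numpy as jnp

def unified_actor_loss(pi_action, data_action, q_pi, awr_w, actor_bc_coef, actor_q_coef, use_awr):
    bc = jnp.sum((pi_action - data_action) ** 2, axis=-1)
    if use_awr:
        bc = bc * awr_w  # weight the BC term by exponentiated advantages
    return (actor_bc_coef * bc - actor_q_coef * q_pi).mean()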
/configs/unifloral/edac.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: EDAC (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 3_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | actor_lr:
43 | value: 3e-4
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 256
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: false
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 3
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: false
62 | deterministic:
63 | value: false
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: false
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | values: [10, 20, 50]
78 | critic_num_layers:
79 | value: 3
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: false
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | value: 0.0
88 | actor_q_coef:
89 | value: 1.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: false
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 1
108 | critic_bc_coef:
109 | value: 0.0
110 | diversity_coef:
111 | values: [0.0, 1.0, 5.0, 10.0, 100.0, 200.0, 500.0, 1000.0]
112 | policy_noise:
113 | value: 0.0
114 | noise_clip:
115 | value: 0.0
116 | use_target_actor:
117 | value: false
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: true
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 1.0
132 | critic_entropy_coef:
133 | value: 1.0
134 |
--------------------------------------------------------------------------------
/configs/unifloral/iql.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: IQL (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | actor_lr:
43 | value: 3e-4
44 | lr_schedule:
45 | value: cosine
46 | batch_size:
47 | value: 256
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: true
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 2
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: false
62 | deterministic:
63 | value: false
64 | deterministic_eval:
65 | value: true
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: true
70 | log_std_min:
71 | value: -20.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 2
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: false
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | value: 1.0 # Weight of AWR loss
88 | actor_q_coef:
89 | value: 0.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: false
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: true
100 |   awr_temperature:
101 |     values: [0.5, 3.0, 10.0] # IQL beta
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 1
108 | critic_bc_coef:
109 | value: 0.0
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.0
114 | noise_clip:
115 | value: 0.0
116 | use_target_actor:
117 | value: false
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 |     values: [0.5, 0.7, 0.9] # IQL tau
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
--------------------------------------------------------------------------------
/configs/unifloral/mobrac.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: MoBRAC (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-3
42 | actor_lr:
43 | value: 1e-3
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 1024
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: false
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 3
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: true
62 | deterministic:
63 | value: true
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 3
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: true
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0]
88 | actor_q_coef:
89 | value: 1.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: true
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 2
108 | critic_bc_coef:
109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1]
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.2
114 | noise_clip:
115 | value: 0.5
116 | use_target_actor:
117 | value: true
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
135 | # --- World model ---
136 | model_path:
137 |     value: "PLACEHOLDER_MODEL_PATH"
138 | dataset_sample_ratio:
139 | value: 0.05
140 | rollout_interval:
141 | value: 1000
142 | rollout_length:
143 | values: [1, 5]
144 | rollout_batch_size:
145 | value: 50000
146 | model_retain_epochs:
147 | value: 5
148 | step_penalty_coef:
149 |     values: [0.0, 0.25, 0.5, 1.0, 5.0]
150 |
--------------------------------------------------------------------------------
/configs/unifloral/rebrac.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: ReBRAC (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-3
42 | actor_lr:
43 | value: 1e-3
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 1024
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: false
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 3
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: true
62 | deterministic:
63 | value: true
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 3
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: true
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0]
88 | actor_q_coef:
89 | value: 1.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: true
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 2
108 | critic_bc_coef:
109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1]
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.2
114 | noise_clip:
115 | value: 0.5
116 | use_target_actor:
117 | value: true
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
--------------------------------------------------------------------------------
/configs/unifloral/sac_n.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: SAC-N (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 3_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | actor_lr:
43 | value: 3e-4
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 256
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: false
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 3
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: false
62 | deterministic:
63 | value: false
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: false
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | values: [5, 10, 20, 50, 100, 200]
78 | critic_num_layers:
79 | value: 3
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: false
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | value: 0.0
88 | actor_q_coef:
89 | value: 1.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: false
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 1
108 | critic_bc_coef:
109 | value: 0.0
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.0
114 | noise_clip:
115 | value: 0.0
116 | use_target_actor:
117 | value: false
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: true
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 1.0
132 | critic_entropy_coef:
133 | value: 1.0
134 |
--------------------------------------------------------------------------------
/configs/unifloral/td3_awr.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: TD3-AWR
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 1e-3
42 | actor_lr:
43 | value: 1e-3
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 1024
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: false
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 3
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: true
62 | deterministic:
63 | value: true
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 3
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: true
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | values: [0.0005, 0.001, 0.002, 0.003, 0.03, 0.1, 0.3, 1.0]
88 | actor_q_coef:
89 | value: 1.0
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: true
94 | aggregate_q:
95 | value: min
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: true
100 | awr_temperature:
101 | values: [0.5, 3.0, 10.0]
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 2
108 | critic_bc_coef:
109 | values: [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1]
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.2
114 | noise_clip:
115 | value: 0.5
116 | use_target_actor:
117 | value: true
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | values: [0.5, 0.7, 0.9]
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
--------------------------------------------------------------------------------
/configs/unifloral/td3_bc.yaml:
--------------------------------------------------------------------------------
1 | command:
2 | - python3.9
3 | - ${program}
4 | - ${args_no_boolean_flags}
5 | entity: flair
6 | method: random
7 | name: TD3+BC (Unifloral)
8 | program: algorithms/unifloral.py
9 | project: unifloral
10 |
11 | parameters:
12 | # --- Experiment ---
13 | seed:
14 | values: [0, 1, 2, 3, 4]
15 | dataset:
16 | values:
17 | - halfcheetah-medium-v2
18 | algorithm:
19 | value: unified
20 | num_updates:
21 | value: 1_000_000
22 | eval_interval:
23 | value: 2500
24 | eval_workers:
25 | value: 8
26 | eval_final_episodes:
27 | value: 1000
28 |
29 | # --- Logging ---
30 | log:
31 | value: true
32 | wandb_project:
33 | value: unifloral
34 | wandb_team:
35 | value: flair
36 | wandb_group:
37 | value: debug
38 |
39 | # --- Generic optimization ---
40 | lr:
41 | value: 3e-4
42 | actor_lr:
43 | value: 3e-4
44 | lr_schedule:
45 | value: constant
46 | batch_size:
47 | value: 256
48 | gamma:
49 | value: 0.99
50 | polyak_step_size:
51 | value: 0.005
52 | norm_obs:
53 | value: true
54 |
55 | # --- Actor architecture ---
56 | actor_num_layers:
57 | value: 2
58 | actor_layer_width:
59 | value: 256
60 | actor_ln:
61 | value: false
62 | deterministic:
63 | value: true
64 | deterministic_eval:
65 | value: false
66 | use_tanh_mean:
67 | value: true
68 | use_log_std_param:
69 | value: false
70 | log_std_min:
71 | value: -5.0
72 | log_std_max:
73 | value: 2.0
74 |
75 | # --- Critic + value function architecture ---
76 | num_critics:
77 | value: 2
78 | critic_num_layers:
79 | value: 2
80 | critic_layer_width:
81 | value: 256
82 | critic_ln:
83 | value: false
84 |
85 | # --- Actor loss components ---
86 | actor_bc_coef:
87 | value: 1.0
88 | actor_q_coef:
89 | values: [1.0, 2.0, 2.5, 3.0, 4.0]
90 | use_q_target_in_actor:
91 | value: false
92 | normalize_q_loss:
93 | value: true
94 | aggregate_q:
95 | value: first
96 |
97 | # --- AWR (Advantage Weighted Regression) actor ---
98 | use_awr:
99 | value: false
100 | awr_temperature:
101 | value: 1.0
102 | awr_exp_adv_clip:
103 | value: 100.0
104 |
105 | # --- Critic loss components ---
106 | num_critic_updates_per_step:
107 | value: 2
108 | critic_bc_coef:
109 | value: 0.0
110 | diversity_coef:
111 | value: 0.0
112 | policy_noise:
113 | value: 0.2
114 | noise_clip:
115 | value: 0.5
116 | use_target_actor:
117 | value: true
118 |
119 | # --- Value function ---
120 | use_value_target:
121 | value: false
122 | value_expectile:
123 | value: 0.8
124 |
125 | # --- Entropy loss ---
126 | use_entropy_loss:
127 | value: false
128 | ent_coef_init:
129 | value: 1.0
130 | actor_entropy_coef:
131 | value: 0.0
132 | critic_entropy_coef:
133 | value: 0.0
134 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | """Evaluation utilities for the Unifloral project.
2 |
3 | This module provides tools for:
4 | 1. Loading and parsing experiment results
5 | 2. Running bandit-based policy selection
6 | 3. Computing confidence intervals via bootstrapping
7 | """
8 |
9 | from collections import namedtuple
10 | from datetime import datetime
11 | import os
12 | import re
13 | from typing import Dict, Tuple
14 | import warnings
15 |
16 | from functools import partial
17 | import glob
18 | import jax
19 | from jax import numpy as jnp
20 | import numpy as np
21 | import pandas as pd
22 |
23 |
24 | r"""
25 | |\ __
26 | \| /_/
27 | \|
28 | ___|_____
29 | \ /
30 | \ /
31 | \___/ Data loading
32 | """
33 |
34 |
35 | def parse_and_load_npz(filename: str) -> Dict:
36 | """Load data from a result file and parse metadata from filename.
37 |
38 | Args:
39 | filename: Path to the .npz result file
40 |
41 | Returns:
42 | Dictionary containing loaded arrays and metadata
43 | """
44 | # Parse filename to extract algorithm, dataset, and timestamp
45 | pattern = r"(.+)_(.+)_(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})"
46 | match = re.match(pattern, os.path.basename(filename))
47 | if not match:
48 | raise ValueError(f"Could not parse filename: {filename}")
49 |
50 | algorithm, dataset, dt_str = match.groups()
51 | dt = datetime.strptime(dt_str, "%Y-%m-%d_%H-%M-%S")
52 |
53 | data = np.load(filename, allow_pickle=True)
54 | data = {k: v for k, v in data.items()}
55 | data["algorithm"] = algorithm
56 | data["dataset"] = dataset
57 | data["datetime"] = dt
58 | data.update(data.pop("args", np.array({})).item()) # Flatten args
59 | return data
60 |
61 |
62 | def load_results_dataframe(results_dir: str = "final_returns") -> pd.DataFrame:
63 | """Load all result files from a directory into a pandas DataFrame.
64 |
65 | Args:
66 | results_dir: Directory containing .npz result files
67 |
68 | Returns:
69 | DataFrame containing results from all successfully loaded files
70 | """
71 | npz_files = glob.glob(os.path.join(results_dir, "*.npz"))
72 | data_list = []
73 |
74 | for f in npz_files:
75 | try:
76 | data = parse_and_load_npz(f)
77 | data_list.append(data)
78 | except Exception as e:
79 | print(f"Error loading {f}: {e}")
80 | continue
81 |
82 | df = pd.DataFrame(data_list).drop(columns=["Index"], errors="ignore")
83 | if "final_scores" in df.columns:
84 | df["final_scores"] = df["final_scores"].apply(lambda x: x.reshape(-1))
85 | if "final_returns" in df.columns:
86 | df["final_returns"] = df["final_returns"].apply(lambda x: x.reshape(-1))
87 |
88 | df = df.sort_values(by=["algorithm", "dataset", "datetime"])
89 | return df.reset_index(drop=True)
90 |
91 |
92 | r"""
93 | __/)
94 | .-(__(=:
95 | |\ | \)
96 | \ ||
97 | \||
98 | \|
99 | ___|_____
100 | \ /
101 | \ /
102 | \___/ Bandit Evaluation and Bootstrapping
103 | """
104 |
105 | BanditState = namedtuple("BanditState", "rng counts rewards total_pulls")
106 |
107 |
108 | def ucb(
109 | means: jnp.ndarray, counts: jnp.ndarray, total_counts: int, alpha: float
110 | ) -> jnp.ndarray:
111 | """Compute UCB exploration bonus.
112 |
113 | Args:
114 | means: Array of empirical means for each arm
115 | counts: Array of pull counts for each arm
116 | total_counts: Total number of pulls across all arms
117 | alpha: Exploration coefficient
118 |
119 | Returns:
120 | Array of UCB values for each arm
121 | """
122 | exploration = jnp.sqrt(alpha * jnp.log(total_counts) / (counts + 1e-9))
123 | return means + exploration
124 |
125 |
126 | def argmax_with_random_tiebreaking(rng: jnp.ndarray, values: jnp.ndarray) -> int:
127 | """Select maximum value with random tiebreaking.
128 |
129 | Args:
130 | rng: JAX PRNGKey
131 | values: Array of values to select from
132 |
133 | Returns:
134 | Index of selected maximum value
135 | """
136 | mask = values == jnp.max(values)
137 | p = mask / (mask.sum() + 1e-9)
138 | return jax.random.choice(rng, jnp.arange(len(values)), p=p)
139 |
140 |
141 | @partial(jax.jit, static_argnums=(2,))
142 | def run_bandit(
143 | returns_array: jnp.ndarray,
144 | rng: jnp.ndarray,
145 | max_pulls: int,
146 | alpha: float,
147 | policy_idx: jnp.ndarray,
148 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
149 | """Run a single bandit algorithm and report results after each pull.
150 |
151 | Args:
152 | returns_array: Array of returns for each policy and rollout
153 | rng: JAX PRNGKey
154 | max_pulls: Maximum number of pulls to execute
155 | alpha: UCB exploration coefficient
156 | policy_idx: Indices of policies to consider
157 |
158 | Returns:
159 | Tuple of (pulls, estimated_bests)
160 | """
161 | returns_array = returns_array[policy_idx]
162 | num_policies, num_rollouts = returns_array.shape
163 |
164 | init_state = BanditState(
165 | rng=rng,
166 | counts=jnp.zeros(num_policies, dtype=jnp.int32),
167 | rewards=jnp.zeros(num_policies),
168 | total_pulls=1,
169 | )
170 |
171 | def bandit_step(state: BanditState, _):
172 | """Run one bandit step and track performance."""
173 | rng, rng_lever, rng_reward = jax.random.split(state.rng, 3)
174 |
175 | # Select arm using UCB
176 | means = state.rewards / jnp.maximum(state.counts, 1)
177 | ucb_values = ucb(means, state.counts, state.total_pulls, alpha)
178 | arm = argmax_with_random_tiebreaking(rng_lever, ucb_values)
179 |
180 | # Sample a reward for the chosen arm
181 | idx = jax.random.randint(rng_reward, shape=(), minval=0, maxval=num_rollouts)
182 | reward = returns_array[arm, idx]
183 | new_state = BanditState(
184 | rng=rng,
185 | counts=state.counts.at[arm].add(1),
186 | rewards=state.rewards.at[arm].add(reward),
187 | total_pulls=state.total_pulls + 1,
188 | )
189 |
190 | # Calculate best arm based on current state
191 | updated_means = new_state.rewards / jnp.maximum(new_state.counts, 1)
192 | best_arm = jnp.argmax(updated_means)
193 | estimated_best = returns_array[best_arm].mean()
194 |
195 | return new_state, (state.total_pulls, estimated_best)
196 |
197 | _, (pulls, estimated_bests) = jax.lax.scan(
198 | bandit_step, init_state, length=max_pulls
199 | )
200 | return pulls, estimated_bests
201 |
202 |
203 | def run_bandit_trials(
204 | returns_array: jnp.ndarray,
205 | seed: int = 17,
206 | num_subsample: int = 20,
207 | num_repeats: int = 1000,
208 | max_pulls: int = 200,
209 | ucb_alpha: float = 2.0,
210 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
211 | """Run multiple bandit trials and collect results at each step.
212 |
213 | Args:
214 | returns_array: Array of returns for each policy and rollout
215 | seed: Random seed
216 | num_subsample: Number of policies to subsample on each trial
217 | num_repeats: Number of trials to run
218 | max_pulls: Maximum number of pulls per trial
219 | ucb_alpha: UCB exploration coefficient
220 |
221 | Returns:
222 | Tuple of (pulls, estimated_bests)
223 | """
224 | rng = jax.random.PRNGKey(seed)
225 | num_policies = returns_array.shape[0]
226 |
227 |     if num_subsample > num_policies:
228 |         warnings.warn("Not enough policies to subsample, using all policies")
229 |     num_subsample = min(num_subsample, num_policies)
230 |
231 | rng, rng_trials, rng_sample = jax.random.split(rng, 3)
232 | rng_trials = jax.random.split(rng_trials, num_repeats)
233 |
234 | def sample_policies(rng: jnp.ndarray) -> jnp.ndarray:
235 | """Sample a subset of policy indices."""
236 |         if num_subsample >= num_policies:
237 | return jnp.arange(num_policies)
238 | return jax.random.choice(
239 | rng, jnp.arange(num_policies), shape=(num_subsample,), replace=False
240 | )
241 |
242 | # Create a batch of policy index arrays for all trials
243 | rng_sample_keys = jax.random.split(rng_sample, num_repeats)
244 | policy_indices = jax.vmap(sample_policies)(rng_sample_keys)
245 |
246 | # Run bandit trials with policy subsampling
247 | # Pulls are the same for all trials, so we can just return the first one
248 | vmap_run_bandit = jax.vmap(run_bandit, in_axes=(None, 0, None, None, 0))
249 | pulls, estimated_bests = vmap_run_bandit(
250 | returns_array, rng_trials, max_pulls, ucb_alpha, policy_indices
251 | )
252 | return pulls[0], estimated_bests
253 |
254 |
255 | def bootstrap_confidence_interval(
256 | rng: jnp.ndarray,
257 | data: jnp.ndarray,
258 | n_bootstraps: int = 1000,
259 | confidence: float = 0.95,
260 | ) -> Tuple[float, float]:
261 | """Compute bootstrap confidence interval for mean of data.
262 |
263 | Args:
264 | rng: JAX PRNGKey
265 | data: Array of values to bootstrap
266 | n_bootstraps: Number of bootstrap samples
267 | confidence: Confidence level (between 0 and 1)
268 |
269 | Returns:
270 | Tuple of (lower_bound, upper_bound)
271 | """
272 |
273 | @jax.vmap
274 | def bootstrap_mean(rng):
275 | samples = jax.random.choice(rng, data, shape=(data.shape[0],), replace=True)
276 | return samples.mean()
277 |
278 | bootstrap_means = bootstrap_mean(jax.random.split(rng, n_bootstraps))
279 | lower_bound = jnp.percentile(bootstrap_means, 100 * (1 - confidence) / 2)
280 | upper_bound = jnp.percentile(bootstrap_means, 100 * (1 + confidence) / 2)
281 | return lower_bound, upper_bound
282 |
283 |
284 | def bootstrap_bandit_trials(
285 | returns_array: jnp.ndarray,
286 | seed: int = 17,
287 | num_subsample: int = 20,
288 | num_repeats: int = 1000,
289 | max_pulls: int = 200,
290 | ucb_alpha: float = 2.0,
291 | n_bootstraps: int = 1000,
292 | confidence: float = 0.95,
293 | ) -> Dict[str, np.ndarray]:
294 | """Run bandit trials and compute bootstrap confidence intervals.
295 |
296 | Args:
297 |         returns_array: Array of returns, with shape (num_policies, num_rollouts)
298 | seed: Random seed
299 | num_subsample: Number of policies to subsample
300 | num_repeats: Number of bandit trials to run
301 | max_pulls: Maximum number of pulls per trial
302 | ucb_alpha: UCB exploration coefficient
303 | n_bootstraps: Number of bootstrap samples
304 | confidence: Confidence level for intervals
305 |
306 | Returns:
307 | Dictionary with the following keys:
308 | - pulls: Number of pulls at each step
309 | - estimated_bests_mean: Mean of the currently estimated best returns across trials
310 | - estimated_bests_ci_low: Lower confidence bound for estimated best returns
311 | - estimated_bests_ci_high: Upper confidence bound for estimated best returns
312 | """
313 | rng = jax.random.PRNGKey(seed)
314 | rng = jax.random.split(rng, max_pulls)
315 |
316 | pulls, estimated_bests = run_bandit_trials(
317 | returns_array, seed, num_subsample, num_repeats, max_pulls, ucb_alpha
318 | )
319 | vmap_bootstrap = jax.vmap(bootstrap_confidence_interval, in_axes=(0, 1, None, None))
320 | ci_low, ci_high = vmap_bootstrap(rng, estimated_bests, n_bootstraps, confidence)
321 | estimated_bests_mean = estimated_bests.mean(axis=0)
322 |
323 | return {
324 | "pulls": pulls,
325 | "estimated_bests_mean": estimated_bests_mean,
326 | "estimated_bests_ci_low": ci_low,
327 | "estimated_bests_ci_high": ci_high,
328 | }
329 |
--------------------------------------------------------------------------------
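Example usage of evaluation.py, stitching the loader and the bandit evaluation together. Paths and column names follow the docstrings above; treat this as a sketch rather than a tested script (it assumes every run saved the same number of final rollouts):

import numpy as np
from evaluation import load_results_dataframe, bootstrap_bandit_trials

df = load_results_dataframe("final_returns")
subset = df[df["dataset"] == "halfcheetah-medium-v2"]
returns_array = np.stack(subset["final_returns"].to_numpy())  # (num_policies, num_rollouts)

results = bootstrap_bandit_trials(returns_array)
print(results["pulls"][-1], results["estimated_bests_mean"][-1])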
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | black==24.4.2
3 | cython<3
4 | d4rl==1.1
5 | distrax==0.1.3
6 | dm-control==1.0.5
7 | flax==0.6.11
8 | gym==0.23.1
9 | jax[cuda12_pip]==0.4.16
10 | mujoco==2.2.1
11 | mujoco-py==2.1.2.14
12 | numpy==1.22.4
13 | optax==0.1.5
14 | orbax-checkpoint==0.4.4
15 | pre-commit==3.7.1
16 | scipy==1.12.0
17 | tensorflow==2.13.0
18 | tensorstore==0.1.51
19 | tyro==0.7.3
20 | wandb==0.17.3
21 |
--------------------------------------------------------------------------------