├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── images
│   └── OPEN_Animation.gif
├── rl_optimizer
│   ├── base.py
│   ├── configs.py
│   ├── eval.py
│   ├── network.py
│   ├── pretrained
│   │   ├── ant_OPEN.npy
│   │   ├── asterix_OPEN.npy
│   │   ├── breakout_OPEN.npy
│   │   ├── freeway_OPEN.npy
│   │   ├── gridworld_OPEN.npy
│   │   ├── multi_OPEN.npy
│   │   └── spaceinvaders_OPEN.npy
│   ├── train.py
│   └── utils.py
├── run_docker.sh
└── setup
    ├── Dockerfile
    ├── build_docker.sh
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | training
2 | wandb
3 | **/__pycache__/
4 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "rl_optimizer/groove"]
2 | path = rl_optimizer/groove
3 | url = https://github.com/EmptyJackson/groove
4 | [submodule "rl_optimizer/learned_optimization"]
5 | path = rl_optimizer/learned_optimization
6 | url = https://github.com/google/learned_optimization
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OPEN: Learned Optimization for RL in JAX
2 |
3 | ![OPEN Animation](images/OPEN_Animation.gif)
4 |
14 | This is the official implementation of OPEN from *Can Learned Optimization Make Reinforcement Learning Less Difficult?*, NeurIPS 2024 (**Spotlight**) and the AutoRL Workshop @ ICML 2024 (**Spotlight**).
15 |
16 |
17 | OPEN is a framework for learning to optimize (L2O) in reinforcement learning. Here, we provide full JAX code to replicate the experiments in our paper and foster future work in this direction. Our current codebase can be used with environments from gymnax or Brax.
18 |
19 |
20 | # 🖥️ Usage
21 |
22 | All files for running OPEN are stored in `rl_optimizer/`.
23 |
24 | ## 🏋️♀️ Training
25 | Alongside the training code in `rl_optimizer/train.py`, we include configs for [`freeway`, `asterix`, `breakout`, `spaceinvaders`, `ant`, `gridworld`]. We automate parallelisation over multiple GPUs using JAX sharding. The `--larger` flag can be used to increase the size of the network in OPEN. To learn an optimizer in one of these environments, or a combination of them, run:
26 | ```bash
27 | python3 train.py --envs <envs> --num-rollouts <num_rollouts> --popsize <popsize> --noise-level <noise_level> --sigma-decay <sigma_decay> --lr <lr> --lr-decay <lr_decay> --num-generations <num_generations> --save-every-k <k> --wandb-name "<wandb_name>" --wandb-entity "<wandb_entity>" [--larger]
28 | ```
29 |
30 | This will save a checkpoint, and evaluate the performance of the optimizer, every $k$ steps. Please note that `gridworld` cannot be run in tandem with other environments, as it is the only environment to which we apply antithetic task sampling.
31 |
32 | We include our hyperparameters in the paper. An example usage is:
33 | ```bash
34 | python3 train.py --envs breakout --num-rollouts 1 --popsize 64 --noise-level 0.03 --sigma-decay 0.999 --lr 0.03 --lr-decay 0.999 --num-generations 500 --save-every-k 24 --wandb-name "OPEN Breakout"
35 | ```
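
For instance, to meta-train a single optimizer over a combination of environments, multiple names can be passed to `--envs` (an illustrative sketch reusing the hyperparameters above, which are not tuned for this setting):
```bash
python3 train.py --envs breakout spaceinvaders asterix --num-rollouts 1 --popsize 64 --noise-level 0.03 --sigma-decay 0.999 --lr 0.03 --lr-decay 0.999 --num-generations 500 --save-every-k 24 --wandb-name "OPEN Multi"
```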
36 |
37 | ## 🔬 Evaluation
38 |
39 | To evaluate the performance of learned optimizers, run the following command, providing the relevant wandb run IDs to `--exp-name` and the generation number to `--exp-num`. This code is also run intermittently during training.
40 |
41 | For experimental purposes, we provide learned weights for the trained optimizers from our paper for the aforementioned environments in `rl_optimizer/pretrained`. These can be used with the `--pretrained` argument in place of wandb IDs. Use the `--larger` flag if it was used in training, and pass the `--multi` flag to experiment with our pretrained multi-environment optimizer.
42 | ```bash
43 | python3 eval.py --envs <envs> --exp-name <wandb_run_ids> --exp-num <generation> --num-runs 16 --title <title> [--pretrained --multi --larger]
44 | ```
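
For example, evaluating our pretrained `ant` optimizer might look like the following (a sketch; the title is arbitrary):
```bash
python3 eval.py --envs ant --pretrained --num-runs 16 --title "OPEN Ant"
```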
45 |
46 |
47 | # ⬇️ Installation
48 |
49 | We include submodules for [Learned Optimization](https://github.com/google/learned_optimization) and [GROOVE](https://github.com/EmptyJackson/groove). Therefore, when cloning this repo, make sure to use `--recurse-submodules`:
50 | ```bash
51 | git clone --recurse-submodules git@github.com:AlexGoldie/rl-learned-optimization.git
52 | ```
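
If the repository was already cloned without `--recurse-submodules`, the submodules can be fetched afterwards with the standard git command:
```bash
git submodule update --init --recursive
```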
53 |
54 | ## 📝 Requirements
55 |
56 | We include requirements in `setup/requirements.txt`. Dependencies can be installed locally using:
57 | ```bash
58 | pip install -r setup/requirements.txt
59 | ```
60 |
61 | ## 🐋 Docker
62 | We also provide files to help build a Docker image. Since we use wandb for logging checkpoints, you should supply your wandb API key as an argument to `build_docker.sh`.
63 |
64 | ```bash
65 | cd setup
66 | chmod +x build_docker.sh
67 | ./build_docker.sh {WANDB_API_KEY}
68 | cd ..
69 | chmod +x run_docker.sh
70 | ./run_docker.sh {GPU_NAMES}
71 | ```
72 |
73 | For example, the Docker container can be started with access to GPUs `0` and `1` using `./run_docker.sh 0,1`.
74 |
75 |
76 | # 📚 Related Work
77 |
78 | The following projects were used extensively in the making of OPEN:
79 | - 🎓 [Learned Optimization](https://github.com/google/learned_optimization)
80 | - 🦎 [Evosax](https://github.com/RobertTLange/evosax)
81 | - ⚡ [PureJaxRL](https://github.com/luchris429/purejaxrl)
82 | - 🕺 [GROOVE](https://github.com/EmptyJackson/groove)
83 | - 🐜 [Brax](https://github.com/google/brax)
84 | - 💪 [Gymnax](https://github.com/RobertTLange/gymnax)
85 |
86 |
87 | # 🔖 Citation
88 |
89 | If you use OPEN in your work, please cite the following:
90 | ```
91 | @inproceedings{goldie2024can,
92 | author={Alexander D. Goldie and Chris Lu and Matthew Thomas Jackson and Shimon Whiteson and Jakob Nicolaus Foerster},
93 | booktitle={Advances in Neural Information Processing Systems},
94 | title={Can Learned Optimization Make Reinforcement Learning Less Difficult?},
95 | year={2024},
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/images/OPEN_Animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/images/OPEN_Animation.gif
--------------------------------------------------------------------------------
/rl_optimizer/base.py:
--------------------------------------------------------------------------------
1 | # Base optimizer abstractions adapted from the learned_optimization library.
2 | import abc
3 | import collections
4 | from typing import Any, Callable, Sequence, Optional, Tuple
5 |
6 | import chex
7 | import flax
8 |
9 | import jax.numpy as jnp
10 |
11 | MetaParamOpt = collections.namedtuple("MetaParamOpt", ["init", "opt_fn"])
12 |
13 | PRNGKey = jnp.ndarray
14 | Params = Any
15 | MetaParams = Any
16 |
17 |
18 | class LearnedOptimizer(abc.ABC):
19 | """Base class for learned optimizers."""
20 |
21 | @abc.abstractmethod
22 | def init(self, key: PRNGKey) -> MetaParams:
23 | raise NotImplementedError()
24 |
25 | @abc.abstractmethod
26 | def opt_fn(self, theta: MetaParams, is_training: bool = False):
27 | raise NotImplementedError()
28 |
29 | @property
30 | def name(self):
31 | return None
32 |
33 |
34 | ModelState = Any
35 | Params = Any
36 | Gradient = Params
37 | OptState = Any
38 |
39 |
40 | @flax.struct.dataclass
41 | class StatelessState:
42 | params: chex.ArrayTree
43 | state: chex.ArrayTree
44 |
45 |
46 | class Optimizer(abc.ABC):
47 | """Baseclass for the Optimizer interface."""
48 |
49 | def get_params(self, state: OptState) -> Params:
50 | return state.params
51 |
52 | def get_state(self, state: OptState) -> ModelState:
53 | return state.state
54 |
55 | def get_params_state(self, state: OptState) -> Tuple[Params, ModelState]:
56 | return self.get_params(state), self.get_state(state)
57 |
58 | def init(
59 | self,
60 | params: Params,
61 | state: Optional[ModelState] = None,
62 | num_steps: Optional[int] = None,
63 | key: Optional[chex.PRNGKey] = None,
64 | **kwargs,
65 | ) -> OptState:
66 | raise NotImplementedError
67 |
68 | def set_params(self, state: OptState, params: Params) -> OptState:
69 | return state.replace(params=params)
70 |
71 | def update(
72 | self,
73 | opt_state: OptState,
74 | grad: Gradient,
75 | model_state: Optional[ModelState] = None,
76 | key: Optional[chex.PRNGKey] = None,
77 | **kwargs,
78 | ) -> OptState:
79 | raise NotImplementedError()
80 |
81 | @property
82 | def name(self) -> str:
83 | """Name of optimizer.
84 | This property is used when serializing results / baselines. This should
85 | lead with the class name, and follow with all parameters used to create
86 | the object. For example: "<ClassName>_<param1>_<param2>".
87 | """
88 | return "UnnamedOptimizer"
89 |
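90 | # Illustrative usage of this interface (a sketch; the concrete learned optimizer
91 | # in network.py provides the actual init/update implementations):
92 | #   opt_state = opt.init(params, key=key)      # wrap the initial parameters
93 | #   opt_state = opt.update(opt_state, grads)   # apply one optimizer step
94 | #   new_params = opt.get_params(opt_state)     # read back the updated parameters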
--------------------------------------------------------------------------------
/rl_optimizer/configs.py:
--------------------------------------------------------------------------------
1 | # ****HAVE TO INPUT WHETHER CONTINUOUS OR JUST AUTOMATE
2 |
3 | all_configs = {
4 | "asterix": {
5 | "ANNEAL_LR": True,
6 | "NUM_ENVS": 64,
7 | "NUM_STEPS": 128,
8 | "TOTAL_TIMESTEPS": 1e7,
9 | "UPDATE_EPOCHS": 4,
10 | "NUM_MINIBATCHES": 8,
11 | "GAMMA": 0.99,
12 | "GAE_LAMBDA": 0.95,
13 | "CLIP_EPS": 0.2,
14 | "ENT_COEF": 0.01,
15 | "VF_COEF": 0.5,
16 | "MAX_GRAD_NORM": 0.5,
17 | "ENV_NAME": "Asterix-MinAtar",
18 | "HSIZE": 64,
19 | "ACTIVATION": "relu",
20 | "DEBUG": False,
21 | "PPO_TEMP": 15.0353,
22 | "CONTINUOUS": False,
23 | },
24 | "freeway": {
25 | "NUM_ENVS": 64,
26 | "NUM_STEPS": 128,
27 | "TOTAL_TIMESTEPS": 1e7,
28 | "UPDATE_EPOCHS": 4,
29 | "NUM_MINIBATCHES": 8,
30 | "GAMMA": 0.99,
31 | "GAE_LAMBDA": 0.95,
32 | "CLIP_EPS": 0.2,
33 | "ENT_COEF": 0.01,
34 | "VF_COEF": 0.5,
35 | "MAX_GRAD_NORM": 0.5,
36 | "ENV_NAME": "Freeway-MinAtar",
37 | "HSIZE": 64,
38 | "ACTIVATION": "relu",
39 | "DEBUG": False,
40 | "PPO_TEMP": 61.8369,
41 | "CONTINUOUS": False,
42 | },
43 | "breakout": {
44 | "NUM_ENVS": 64,
45 | "NUM_STEPS": 128,
46 | "TOTAL_TIMESTEPS": 5e5,
47 | "UPDATE_EPOCHS": 4,
48 | "NUM_MINIBATCHES": 8,
49 | "GAMMA": 0.99,
50 | "GAE_LAMBDA": 0.95,
51 | "CLIP_EPS": 0.2,
52 | "ENT_COEF": 0.01,
53 | "VF_COEF": 0.5,
54 | "MAX_GRAD_NORM": 0.5,
55 | "ENV_NAME": "Breakout-MinAtar",
56 | "HSIZE": 64,
57 | "ACTIVATION": "relu",
58 | "DEBUG": False,
59 | "PPO_TEMP": 50.0822,
60 | "CONTINUOUS": False,
61 | },
62 | "spaceinvaders": {
63 | "NUM_ENVS": 64,
64 | "NUM_STEPS": 128,
65 | "TOTAL_TIMESTEPS": 1e7,
66 | "UPDATE_EPOCHS": 4,
67 | "NUM_MINIBATCHES": 8,
68 | "GAMMA": 0.99,
69 | "GAE_LAMBDA": 0.95,
70 | "CLIP_EPS": 0.2,
71 | "ENT_COEF": 0.01,
72 | "VF_COEF": 0.5,
73 | "MAX_GRAD_NORM": 0.5,
74 | "ENV_NAME": "SpaceInvaders-MinAtar",
75 | "HSIZE": 64,
76 | "ACTIVATION": "relu",
77 | "DEBUG": False,
78 | "PPO_TEMP": 167.0913,
79 | "CONTINUOUS": False,
80 | },
81 | "ant": {
82 | "NUM_ENVS": 2048,
83 | "NUM_STEPS": 10,
84 | "TOTAL_TIMESTEPS": 5e7,
85 | "UPDATE_EPOCHS": 4,
86 | "NUM_MINIBATCHES": 32,
87 | "GAMMA": 0.99,
88 | "GAE_LAMBDA": 0.95,
89 | "CLIP_EPS": 0.2,
90 | "ENT_COEF": 0.0,
91 | "VF_COEF": 0.5,
92 | "MAX_GRAD_NORM": 0.5,
93 | "ACTIVATION": "tanh",
94 | "HSIZE": 64,
95 | "ENV_NAME": "Brax-ant",
96 | "BACKEND": "positional",
97 | "SYMLOG_OBS": False,
98 | "CLIP_ACTION": True,
99 | "DEBUG": False,
100 | "NORMALIZE": True,
101 | "CONTINUOUS": True,
102 | "PPO_TEMP": 6517.31,
103 | },
104 | "gridworld": {
105 | "LR": 1e-4,
106 | "ANNEAL_LR": True,
107 | "NUM_ENVS": 1024,
108 | "NUM_STEPS": 20,
109 | "TOTAL_TIMESTEPS": 3e7,
110 | "UPDATE_EPOCHS": 2,
111 | "NUM_MINIBATCHES": 16,
112 | "GAMMA": 0.99,
113 | "GAE_LAMBDA": 0.95,
114 | "CLIP_EPS": 0.2,
115 | "ENT_COEF": 0.01,
116 | "VF_COEF": 0.5,
117 | "MAX_GRAD_NORM": 0.5,
118 | "ACTIVATION": "tanh",
119 | "HSIZE": 16,
120 | "ENV_NAME": "gridworld",
121 | "DEBUG": False,
122 | "PPO_TEMP": 1.0,
123 | "CONTINUOUS": False,
124 | },
125 | }
126 |
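127 | # Note (illustrative): train.make_train derives the remaining quantities from these
128 | # configs. For example, with the "breakout" settings above:
129 | #   NUM_UPDATES    = TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS      = 5e5 // 128 // 64 = 61
130 | #   MINIBATCH_SIZE = NUM_ENVS * NUM_STEPS // NUM_MINIBATCHES       = 64 * 128 // 8    = 1024
131 | #   TOTAL_UPDATES  = NUM_UPDATES * UPDATE_EPOCHS * NUM_MINIBATCHES = 61 * 4 * 8       = 1952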
--------------------------------------------------------------------------------
/rl_optimizer/eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 | import os
7 | import argparse
8 | import seaborn as sns
9 | import pandas as pd
10 | import wandb
11 | from scipy.interpolate import interp1d
12 | import time
13 |
14 | import jax
15 | import jax.numpy as jnp
16 | from jax.sharding import Mesh
17 | from jax.sharding import PartitionSpec as P, NamedSharding
18 |
19 | from configs import all_configs as all_configs
20 | from network import GRU_Opt as optim
21 | from utils import ParameterReshaper
22 |
23 | api = wandb.Api()
24 |
25 |
26 | labs = [""]
27 |
28 |
29 | def set_size(width=487.8225, fraction=1, subplots=(1, 1)):
30 | """Set figure dimensions to avoid scaling in LaTeX.
31 |
32 | Parameters
33 | ----------
34 | width: float
35 | Document textwidth or columnwidth in pts
36 | fraction: float, optional
37 | Fraction of the width which you wish the figure to occupy
38 |
39 | Returns
40 | -------
41 | fig_dim: tuple
42 | Dimensions of figure in inches
43 | """
44 | # Width of figure (in pts)
45 | fig_width_pt = width * fraction
46 |
47 | # Convert from pt to inches
48 | inches_per_pt = 1 / 72.27
49 |
50 | golden_ratio = 1.2
51 |
52 | # Figure width in inches
53 | fig_width_in = fig_width_pt * inches_per_pt
54 | # Figure height in inches (respects the requested subplots layout)
55 | fig_height_in = fig_width_in / golden_ratio * (subplots[0] / subplots[1])
56 |
57 | fig_dim = (fig_width_in, fig_height_in)
58 |
59 | return fig_dim
60 |
61 |
62 | episode_lengths = {
63 | "asterix": 1000,
64 | "pendulum": 200,
65 | "acrobot": 500,
66 | "cartpole": 500,
67 | "spaceinvaders": 1000,
68 | "freeway": 2500,
69 | "breakout": 1000,
70 | "ant": 1000,
71 | "gridworld": 0,
72 | }
73 |
74 |
75 | def make_plot(
76 | exp_names=None,
77 | exp_nums=None,
78 | envs=None,
79 | num_runs=5,
80 | pmap=False,
81 | title=None,
82 | larger=False,
83 | pretrained=False,
84 | multi=False,
85 | training=False,
86 | grids=False,
87 | params=None,
88 | mesh=None,
89 | ):
90 |
91 | devices = jax.devices()
92 | sharding_p = NamedSharding(
93 | mesh,
94 | P(
95 | None,
96 | ),
97 | )
98 | sharding_rng = NamedSharding(
99 | mesh,
100 | P(
101 | "dim",
102 | ),
103 | )
104 |
105 | if training:
106 | p = "training"
107 | if not os.path.exists(p):
108 | os.mkdir(p)
109 | p = f"{p}/{exp_names}"
110 | if not os.path.exists(p):
111 | os.mkdir(p)
112 | p = f"{p}/{exp_nums}"
113 | elif pretrained:
114 | p = "pretrained"
115 | else:
116 | p = "visualization"
117 | if not os.path.exists(p):
118 | os.mkdir(p)
119 | p = f"{p}/{title}"
120 |
121 | if not os.path.exists(p):
122 | os.mkdir(p)
123 |
124 | returns_list = dict()
125 | param_means = dict()
126 | tau_dormancies = dict()
127 | abs_param_means = dict()
128 | runtimes = dict()
129 |
130 | from train import (
131 | make_train as meta_make_train,
132 | )
133 |
134 | for i, env in enumerate(envs):
135 |
136 | if grids:
137 | env = "gridworld"
138 |
139 | returns_list.update({envs[i]: []})
140 | param_means.update({envs[i]: []})
141 | tau_dormancies.update({envs[i]: []})
142 | abs_param_means.update({envs[i]: []})
143 | runtimes.update({envs[i]: []})
144 |
145 | if grids:
146 | if not os.path.exists(f"{p}/{envs[i]}"):
147 | os.mkdir(f"{p}/{envs[i]}")
148 | else:
149 | if not os.path.exists(f"{p}/{env}"):
150 | os.mkdir(f"{p}/{env}")
151 |
152 | if pretrained:
153 | if multi:
154 | exp_name = f"pretrained/multi_OPEN.npy"
155 | larger = True
156 | else:
157 | exp_name = f"pretrained/{env}_OPEN.npy"
158 | params = jnp.array(jnp.load(exp_name, allow_pickle=True))
159 |
160 | else:
161 | if training:
162 | params = params
163 | exp_name = exp_names
164 | exp_num = exp_nums
165 | else:
166 | if grids:
167 | exp_name = exp_names[0]
168 | exp_num = exp_nums[0]
169 | else:
170 | exp_name = exp_names[i]
171 | exp_num = exp_nums[i]
172 |
173 | run_path = f"OPEN/{exp_name}"
174 | run = api.run(run_path)
175 | restored = wandb.restore(
176 | f"curr_param_{exp_num}.npy",
177 | run_path=run_path,
178 | root=p,
179 | replace=True,
180 | )
181 | print(f"name: {restored.name}, env = {envs[i]}")
182 | params = jnp.array(jnp.load(restored.name, allow_pickle=True))
183 |
184 | # Need to reshape saved params as they are saved as np arrays
185 | if not training:
186 | if larger:
187 | hidden_size = 32
188 | gru_features = 16
189 | else:
190 | hidden_size = 16
191 | gru_features = 8
192 | pholder = optim(hidden_size=hidden_size, gru_features=gru_features).init(
193 | jax.random.PRNGKey(0)
194 | )
195 | param_reshaper = ParameterReshaper(pholder)
196 | params = param_reshaper.reshape_single(params)
197 |
198 | all_configs[f"{env}"]["larger"] = larger
199 |
200 | make_train = meta_make_train
201 |
202 | rng = jax.random.PRNGKey(42)
203 | all_configs[f"{env}"]["VISUALISE"] = True
204 |
205 | start = time.time()
206 | rngs = jax.random.split(rng, num_runs)
207 | if env == "gridworld":
208 | train_fn = jax.jit(
209 | jax.vmap(
210 | make_train(all_configs[f"{env}"]),
211 | in_axes=(None, 0, None),
212 | ),
213 | static_argnames=["grid_type"],
214 | )
215 |
216 | if pmap:
217 | params = jax.device_put(params, sharding_p)
218 | rngs = jax.device_put(rngs, sharding_rng)
219 |
220 | out, metrics = train_fn(params, rngs, i)
221 |
222 | else:
223 | train_fn = jax.jit(
224 | jax.vmap(
225 | make_train(all_configs[f"{env}"]),
226 | in_axes=(None, 0),
227 | )
228 | )
229 |
230 | if pmap:
231 | params = jax.device_put(params, sharding_p)
232 | rngs = jax.device_put(rngs, sharding_rng)
233 |
234 | out, metrics = train_fn(params, rngs)
235 | out, metrics = jax.device_get(out), jax.device_get(metrics)
236 |
237 | fitness = metrics["returned_episode_returns"][..., -1, -1, :].mean()
238 | end = time.time()
239 | print(f"runtime = {end - start}")
240 | runtimes[envs[i]].append(end - start)
241 | print(f"{envs[i]} learned fitness: {fitness}")
242 |
243 | returns = (
244 | metrics["returned_episode_returns"].mean(-1).mean(-1).reshape(num_runs, -1)
245 | )
246 |
247 | if training:
248 | wandb.log({f"eval/{env}/fitness_at_mean": fitness}, step=exp_num)
249 |
250 | config_for_index = all_configs[f"{env}"]
251 | index_from = episode_lengths[env]
252 | index = int(np.ceil(index_from / config_for_index["NUM_STEPS"]))
253 |
254 | if grids:
255 |
256 | returns_list[envs[i]].append(returns[:, index:])
257 |
258 | return_df = pd.DataFrame(returns[:, index:])
259 | return_df.to_csv(f"{p}/{envs[i]}/returns.csv")
260 |
261 | tau_dormancies[envs[i]].append(
262 | metrics["dormancy"].mean(-1).reshape(num_runs, -1)
263 | )
264 | else:
265 | returns_list[env].append(returns[:, index:])
266 |
267 | return_df = pd.DataFrame(returns[:, index:])
268 | return_df.to_csv(f"{p}/{env}/returns.csv")
269 |
270 | tau_dormancies[env].append(
271 | metrics["dormancy"].mean(-1).reshape(num_runs, -1)
272 | )
273 |
274 | def plot_all(values, conf, labels, xlabel, ylabel, title, training=False):
275 |
276 | for j, env in enumerate(values.keys()):
277 | fig, ax = plt.subplots(1, 1, figsize=set_size())
278 | if grids:
279 | x_mult = (
280 | conf[f"gridworld"]["NUM_STEPS"] * conf[f"gridworld"]["NUM_ENVS"]
281 | )
282 | else:
283 | x_mult = conf[f"{env}"]["NUM_STEPS"] * conf[f"{env}"]["NUM_ENVS"]
284 |
285 | legend = []
286 | for i, value in enumerate(values[env]):
287 | val = values[env][i]
288 | val_df = pd.DataFrame({f"vals_{i}": val[i] for i in range(len(val))})
289 |
290 | val_ewm = val_df.ewm(span=200, axis=0).mean().to_numpy().T
291 |
292 | mean = val_ewm.mean(0)
293 |
294 | xs = jnp.arange(len(mean)) * x_mult
295 |
296 | std = jnp.std(val_ewm, axis=0) / jnp.sqrt(val_ewm.shape[0])
297 |
298 | results_max = mean + std
299 | results_min = mean - std
300 |
301 | (leg,) = ax.plot(xs, mean, label=labels[i], linewidth=0.4)
302 | legend.append(leg)
303 | ax.fill_between(x=xs, y1=results_min, y2=results_max, alpha=0.5)
304 |
305 | if j == 0:
306 |
307 | ax.legend(
308 | legend,
309 | labels,
310 | loc="lower right",
311 | ncols=1,
312 | )
313 |
314 | if len(values.keys()) > 1:
315 | ax.set_title(env, fontsize=8)
316 | ax.set_xlabel(xlabel)
317 | ax.tick_params(axis="x", which="major", pad=-3)
318 | ax.tick_params(axis="y", which="major", pad=-3)
319 |
320 | else:
321 | ax.set_title(env, fontsize=8)
322 | ax.set_xlabel(xlabel)
323 |
324 | ax.set_ylabel(ylabel)
325 |
326 | fig.savefig(
327 | f"{p}/{env}/{title}_{ylabel}_{env}.pdf",
328 | format="pdf",
329 | bbox_inches="tight",
330 | )
331 |
332 | if training:
333 | fig.savefig(
334 | f"{p}/{env}/{title}_{ylabel}_{env}.png",
335 | format="png",
336 | bbox_inches="tight",
337 | )
338 | wandb.log(
339 | {
340 | f"eval_figs/{env}/{ylabel}": wandb.Image(
341 | f"{p}/{env}/{title}_{ylabel}_{env}.png"
342 | )
343 | },
344 | step=exp_num,
345 | )
346 |
347 | plot_all(
348 | returns_list,
349 | all_configs,
350 | ["OPEN"],
351 | xlabel="Frames",
352 | ylabel=f"Return",
353 | title=title,
354 | training=training,
355 | )
356 |
357 | plot_all(
358 | tau_dormancies,
359 | all_configs,
360 | ["OPEN"],
361 | xlabel="Updates",
362 | ylabel=f"Dormancy",
363 | title=title,
364 | training=training,
365 | )
366 |
367 | for env in envs:
368 | print(f"env: {env}, run_time : {runtimes[env]}")
369 |
370 |
371 | if __name__ == "__main__":
372 | parser = argparse.ArgumentParser()
373 | parser.add_argument("--exp-name", nargs="+", type=str, default=None)
374 | parser.add_argument("--exp-num", nargs="+", type=str, default=None)
375 | parser.add_argument("--envs", nargs="+", required=False, default=None)
376 | parser.add_argument("--num-runs", type=int, default=3)
377 | parser.add_argument(
378 | "--pmap", default=jax.local_device_count() > 1, action="store_true"
379 | )
380 | parser.add_argument("--title", type=str)
381 | parser.add_argument("--larger", default=False, action="store_true")
382 | parser.add_argument("--pretrained", default=False, action="store_true")
383 | parser.add_argument("--multi", default=False, action="store_true")
384 |
385 | args = parser.parse_args()
386 |
387 | sns.set()
388 | plt.style.use("seaborn-v0_8-colorblind")
389 |
390 | tex_fonts = {
391 | "axes.labelsize": 6,
392 | "font.size": 8,
393 | "legend.fontsize": 5,
394 | "xtick.labelsize": 6,
395 | "ytick.labelsize": 6,
396 | }
397 |
398 | matplotlib.rcParams.update(tex_fonts)
399 | matplotlib.rcParams["axes.formatter.limits"] = [-3, 3]
400 | color_palette = sns.color_palette("colorblind", n_colors=5)
401 | plt.rcParams["axes.prop_cycle"] = plt.cycler(color=color_palette)
402 | devices = jax.local_devices()
403 | mesh = Mesh(devices, axis_names=("dim",))
404 |
405 | if args.envs == ["gridworld"]:
406 | args.envs = [
407 | "sixteen_rooms",
408 | "labyrinth",
409 | "rand_sparse",
410 | "rand_dense",
411 | "rand_long",
412 | "standard_maze",
413 | "rand_all",
414 | ]
415 | grids = True
416 | else:
417 | grids = False
418 |
419 | make_plot(
420 | exp_names=args.exp_name,
421 | exp_nums=args.exp_num,
422 | envs=args.envs,
423 | num_runs=args.num_runs,
424 | pmap=args.pmap,
425 | title=args.title,
426 | larger=args.larger,
427 | pretrained=args.pretrained,
428 | multi=args.multi,
429 | grids=grids,
430 | mesh=mesh,
431 | )
432 |
--------------------------------------------------------------------------------
/rl_optimizer/network.py:
--------------------------------------------------------------------------------
1 | """Full OPEN optimizer, incorporating all input features and learnable stochasticity"""
2 |
3 | from typing import Any, Optional
4 |
5 | import flax
6 | import flax.linen as nn
7 | import gin
8 | import jax
9 | from jax import lax
10 | import jax.numpy as jnp
11 | from optax import adam
12 | import optax
13 |
14 | import sys
15 |
16 | import base as opt_base
17 |
18 | from learned_optimization.learned_optimization import tree_utils
19 | from learned_optimization.learned_optimization.learned_optimizers import (
20 | common,
21 | )
22 |
23 |
24 | PRNGKey = jnp.ndarray
25 |
26 |
27 | def _second_moment_normalizer(x, axis, eps=1e-5):
28 | return x * lax.rsqrt(eps + jnp.mean(jnp.square(x), axis=axis, keepdims=True))
29 |
30 |
31 | def iter_proportion(iterations, total_its=100000):
32 | f32 = jnp.float32
33 |
34 | return iterations / f32(total_its)
35 |
36 |
37 | @flax.struct.dataclass
38 | class GRUOptState:
39 | params: Any
40 | rolling_features: common.MomAccumulator
41 | iteration: jnp.ndarray
42 | state: Any
43 | carry: Any
44 |
45 |
46 | @gin.configurable
47 | class GRU_Opt(opt_base.LearnedOptimizer):
48 | def __init__(
49 | self,
50 | exp_mult=0.001,
51 | step_mult=0.001,
52 | hidden_size=16,
53 | gru_features=8,
54 | ):
55 |
56 | super().__init__()
57 | self._step_mult = step_mult
58 | self._exp_mult = exp_mult
59 |
60 | self.gru_features = gru_features
61 |
62 | self._gru = nn.GRUCell(features=self.gru_features)
63 |
64 | self._mod = nn.Sequential(
65 | [
66 | nn.Dense(hidden_size),
67 | nn.LayerNorm(),
68 | nn.relu,
69 | nn.Dense(hidden_size),
70 | nn.LayerNorm(),
71 | nn.relu,
72 | nn.Dense(3),
73 | ]
74 | )
75 |
76 | def init(self, key: PRNGKey) -> opt_base.MetaParams:
77 | # There are 19 features used as input. For now, hard code this.
78 | key = jax.random.split(key, 5)
79 |
80 | proxy_carry = self._gru.initialize_carry(key[4], (1,))
81 |
82 | return {
83 | "params": self._mod.init(key[0], jnp.zeros([self.gru_features])),
84 | "gru_params": self._gru.init(key[2], proxy_carry, jnp.zeros([19])),
85 | }
86 |
87 | def opt_fn(
88 | self, theta: opt_base.MetaParams, is_training: bool = False
89 | ) -> opt_base.Optimizer:
90 | # ALL MOMENTUM TIMESCALES
91 | decays = jnp.asarray([0.1, 0.5, 0.9, 0.99, 0.999, 0.9999])
92 |
93 | mod = self._mod
94 | gru = self._gru
95 | exp_mult = self._exp_mult
96 | step_mult = self._step_mult
97 |
98 | theta_mlp = theta["params"]
99 | theta_gru = theta["gru_params"]
100 |
101 | class _Opt(opt_base.Optimizer):
102 | def init(
103 | self,
104 | params: opt_base.Params,
105 | model_state: Any = None,
106 | num_steps: Optional[int] = None,
107 | key: Optional[PRNGKey] = None,
108 | ) -> GRUOptState:
109 | """Initialize opt state."""
110 |
111 | param_tree = jax.tree_util.tree_structure(params)
112 |
113 | keys = jax.random.split(key, param_tree.num_leaves)
114 | keys = jax.tree_util.tree_unflatten(param_tree, keys)
115 |
116 | carry = jax.tree_util.tree_map(
117 | lambda p, k: gru.initialize_carry(k, jnp.expand_dims(p, -1).shape),
118 | params,
119 | keys,
120 | )
121 |
122 | return GRUOptState(
123 | params=params,
124 | state=model_state,
125 | rolling_features=common.vec_rolling_mom(decays).init(params),
126 | iteration=jnp.asarray(0, dtype=jnp.int32),
127 | carry=carry,
128 | )
129 |
130 | def update(
131 | self,
132 | opt_state_actor: GRUOptState,
133 | crit_opt_state: GRUOptState,
134 | grad: Any,
135 | activations: float,
136 | key: Optional[PRNGKey] = None,
137 | training_prop=0,
138 | batch_prop=0,
139 | config=None,
140 | layer_props=None,
141 | model_state: Any = None,
142 | mask=None,
143 | ) -> GRUOptState:
144 |
145 | next_rolling_features_actor = common.vec_rolling_mom(decays).update(
146 | opt_state_actor.rolling_features, grad["actor"]
147 | )
148 |
149 | next_rolling_features_critic = common.vec_rolling_mom(decays).update(
150 | crit_opt_state.rolling_features, grad["critic"]
151 | )
152 |
153 | rolling_features = {
154 | "actor": next_rolling_features_actor.m,
155 | "critic": next_rolling_features_critic.m,
156 | }
157 |
158 | training_step_feature = training_prop
159 | batch_feature = batch_prop
160 | eps1 = 1e-13
161 | eps2 = 1e-8
162 | t = opt_state_actor.iteration + 1
163 |
164 | def _update_tensor(p, g, mom, k, dorm, carry, layer_prop, mask):
165 |
166 | # this doesn't work with scalar parameters, so let's reshape.
167 | if not p.shape:
168 | p = jnp.expand_dims(p, 0)
169 |
170 | # use gradient conditioning (Optim4RL)
171 | gsign = jnp.expand_dims(jnp.sign(g), 0)
172 | glog = jnp.expand_dims(jnp.log(jnp.abs(g) + eps1), 0)
173 |
174 | mom = jnp.expand_dims(mom, 0)
175 | did_reshape = True
176 | else:
177 | gsign = jnp.sign(g)
178 | glog = jnp.log(jnp.abs(g) + eps1)
179 | did_reshape = False
180 |
181 | inps = []
182 | inp_g = []
183 |
184 | batch_gsign = jnp.expand_dims(gsign, axis=-1)
185 | batch_glog = jnp.expand_dims(glog, axis=-1)
186 |
187 | # feature consisting of raw parameter values
188 | batch_p = jnp.expand_dims(p, axis=-1)
189 | inps.append(batch_p)
190 |
191 | # feature consisting of all momentum values
192 | momsign = jnp.sign(mom)
193 | momlog = jnp.log(jnp.abs(mom) + eps1)
194 | inps.append(momsign)
195 | inps.append(momlog)
196 |
197 | inp_stack = jnp.concatenate(inps, axis=-1)
198 |
199 | axis = list(range(len(p.shape)))
200 |
201 | inp_stack_g = jnp.concatenate(
202 | [inp_stack, batch_gsign, batch_glog], axis=-1
203 | )
204 |
205 | inp_stack_g = _second_moment_normalizer(inp_stack_g, axis=axis)
206 |
207 | # once normalized, add features that are constant across tensor.
208 | # namely the training proportion, layer proportion, batch proportion and dormancy
209 | def stack_tensors(feature, input):
210 |
211 | stacked = jnp.reshape(
212 | feature, [1] * len(axis) + list(feature.shape[-1:])
213 | )
214 | stacked = jnp.tile(stacked, list(p.shape) + [1])
215 | return jnp.concatenate([input, stacked], axis=-1)
216 |
217 | inp = jnp.tile(
218 | jnp.reshape(
219 | training_step_feature,
220 | [1] * len(axis) + list(training_step_feature.shape[-1:]),
221 | ),
222 | list(p.shape) + [1],
223 | )
224 |
225 | stacked_batch_prop = jnp.tile(
226 | jnp.reshape(
227 | batch_feature,
228 | [1] * len(axis) + list(batch_feature.shape[-1:]),
229 | ),
230 | list(p.shape) + [1],
231 | )
232 |
233 | layer_prop = jnp.expand_dims(layer_prop, 0)
234 |
235 | stacked_layer_prop = jnp.tile(
236 | jnp.reshape(
237 | layer_prop, [1] * len(axis) + list(layer_prop.shape[-1:])
238 | ),
239 | list(p.shape) + [1],
240 | )
241 |
242 | inp = jnp.concatenate([inp, stacked_layer_prop], axis=-1)
243 |
244 | inp = jnp.concatenate([inp, stacked_batch_prop], axis=-1)
245 |
246 | batch_dorm = jnp.expand_dims(dorm, axis=-1)
247 |
248 | if p.shape != dorm.shape:
249 | batch_dorm = jnp.tile(
250 | batch_dorm, [p.shape[0]] + len(axis) * [1]
251 | )
252 |
253 | inp = jnp.concatenate([inp, batch_dorm], axis=-1)
254 |
255 | inp_g = jnp.concatenate([inp_stack_g, inp], axis=-1)
256 |
257 | gru_new_carry, gru_out = gru.apply(theta_gru, carry, inp_g)
258 |
259 | # apply the per parameter MLP.
260 | output = mod.apply(theta_mlp, gru_out)
261 |
262 | update_ = (
263 | output[..., 0] * step_mult * jnp.exp(output[..., 1] * exp_mult)
264 | )
265 |
266 | # Add the stochasticity *only* to the actor (using the mask)
267 | update = (
268 | update_
269 | + output[..., 2]
270 | * mask
271 | * jax.random.normal(k, shape=update_.shape)
272 | * step_mult
273 | )
274 |
275 | update = update.reshape(p.shape)
276 |
277 | return (update, gru_new_carry)
278 |
279 | full_params = {
280 | "actor": opt_state_actor.params,
281 | "critic": crit_opt_state.params,
282 | }
283 | param_tree = jax.tree_util.tree_structure(full_params)
284 |
285 | keys = jax.random.split(key, param_tree.num_leaves)
286 | keys = jax.tree_util.tree_unflatten(param_tree, keys)
287 |
288 | activations = jax.tree_util.tree_flatten(activations)[0]
289 |
290 | activations = jax.tree_util.tree_unflatten(param_tree, activations)
291 |
292 | def calc_dormancy(tensor_activations):
293 | tensor_activations = tensor_activations + 1e-11
294 | total_activations = jnp.abs(tensor_activations).sum(axis=-1)
295 | total_activations = jnp.tile(
296 | jnp.expand_dims(total_activations, -1),
297 | tensor_activations.shape[-1],
298 | )
299 | dormancy = (
300 | tensor_activations
301 | / total_activations
302 | * tensor_activations.shape[-1]
303 | )
304 | return dormancy
305 |
306 | dormancies = jax.tree_util.tree_map(calc_dormancy, activations)
307 |
308 | full_carry = {
309 | "actor": opt_state_actor.carry,
310 | "critic": crit_opt_state.carry,
311 | }
312 |
313 | updates_carry = jax.tree_util.tree_map(
314 | _update_tensor,
315 | full_params,
316 | grad,
317 | rolling_features,
318 | keys,
319 | dormancies,
320 | full_carry,
321 | layer_props,
322 | mask,
323 | )
324 |
325 | updates_carry_leaves = jax.tree_util.tree_leaves(updates_carry)
326 | updates = [
327 | updates_carry_leaves[i]
328 | for i in range(0, len(updates_carry_leaves), 2)
329 | ]
330 | new_carry = [
331 | updates_carry_leaves[i + 1]
332 | for i in range(0, len(updates_carry_leaves), 2)
333 | ]
334 |
335 | updates = jax.tree_util.tree_unflatten(param_tree, updates)
336 | new_carry = jax.tree_util.tree_unflatten(param_tree, new_carry)
337 |
338 | # Zero-centre the update: subtract the global mean so the updates sum to zero
339 | updates_flat = jax.flatten_util.ravel_pytree(updates)[0]
340 | update_mean = updates_flat.mean()
341 | update_mean = jax.tree_util.tree_unflatten(
342 | param_tree, jnp.tile(update_mean, param_tree.num_leaves)
343 | )
344 |
345 | updates = jax.tree_util.tree_map(
346 | lambda x, mu: x - mu, updates, update_mean
347 | )
348 |
349 | def param_update(p, update):
350 |
351 | new_param = p - update
352 |
353 | return new_param
354 |
355 | next_params = jax.tree_util.tree_map(param_update, full_params, updates)
356 |
357 | # For simplicity, maintain separate opt states for the actor and the critic
358 | next_opt_state_actor = GRUOptState(
359 | params=tree_utils.match_type(
360 | next_params["actor"], opt_state_actor.params
361 | ),
362 | rolling_features=tree_utils.match_type(
363 | next_rolling_features_actor, opt_state_actor.rolling_features
364 | ),
365 | iteration=opt_state_actor.iteration + 1,
366 | state=model_state,
367 | carry=new_carry["actor"],
368 | )
369 |
370 | next_opt_state_critic = GRUOptState(
371 | params=tree_utils.match_type(
372 | next_params["critic"], crit_opt_state.params
373 | ),
374 | rolling_features=tree_utils.match_type(
375 | next_rolling_features_critic, crit_opt_state.rolling_features
376 | ),
377 | iteration=opt_state_actor.iteration + 1,
378 | state=model_state,
379 | carry=new_carry["critic"],
380 | )
381 |
382 | # Parameter statistics for logging. The absolute means are always needed;
383 | # the signed means are only used when visualising.
384 | param_flat_actor = jax.flatten_util.ravel_pytree(next_params["actor"])[0]
385 | param_abs_mean_actor = jnp.mean(jnp.abs(param_flat_actor))
386 |
387 | param_flat_critic = jax.flatten_util.ravel_pytree(next_params["critic"])[0]
388 | param_abs_mean_critic = jnp.mean(jnp.abs(param_flat_critic))
389 |
390 | if config["VISUALISE"]:
391 | param_mean_actor = jnp.mean(param_flat_actor)
392 | param_mean_critic = jnp.mean(param_flat_critic)
393 | return (
394 | next_opt_state_actor,
395 | next_opt_state_critic,
396 | (param_mean_actor, param_abs_mean_actor),
397 | (param_mean_critic, param_abs_mean_critic),
398 | )
399 | else:
400 | return (
401 | next_opt_state_actor,
402 | next_opt_state_critic,
403 | (param_abs_mean_actor),
404 | (param_abs_mean_critic),
405 | )
410 |
411 | return _Opt()
412 |
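413 | # Illustrative usage (mirroring train.py): the meta-parameters theta returned by
414 | # GRU_Opt.init are turned into a concrete optimizer with opt_fn, which keeps
415 | # separate states for the actor and critic networks:
416 | #   meta_opt = GRU_Opt(hidden_size=16, gru_features=8)
417 | #   theta = meta_opt.init(jax.random.PRNGKey(0))
418 | #   opt = meta_opt.opt_fn(theta)
419 | #   actor_state = opt.init(actor_params, key=rng_act)
420 | #   critic_state = opt.init(critic_params, key=rng_crit)
421 | #   # then, per minibatch: opt.update(actor_state, critic_state, grads, activations, ...)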
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/ant_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/ant_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/asterix_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/asterix_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/breakout_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/breakout_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/freeway_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/freeway_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/gridworld_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/gridworld_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/multi_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/multi_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/pretrained/spaceinvaders_OPEN.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/spaceinvaders_OPEN.npy
--------------------------------------------------------------------------------
/rl_optimizer/train.py:
--------------------------------------------------------------------------------
1 | """(inner and outer) training loop for OPEN"""
2 |
3 | import numpy as np
4 | from typing import Sequence, NamedTuple, Any
5 | import os
6 | import os.path as osp
7 | from datetime import datetime
8 | from tqdm import tqdm
9 | import wandb
10 | import time
11 | from functools import partial
12 | import argparse
13 |
14 | from network import GRU_Opt
15 | from configs import all_configs
16 | from eval import make_plot
17 | from utils import (
18 | GymnaxGymWrapper,
19 | GymnaxLogWrapper,
20 | FlatWrapper,
21 | BraxGymnaxWrapper,
22 | ClipAction,
23 | TransformObservation,
24 | NormalizeObservation,
25 | NormalizeReward,
26 | VecEnv,
27 | )
28 |
29 | import jax
30 | import jax.numpy as jnp
31 | import flax.linen as nn
32 | import optax
33 | from flax.linen.initializers import constant, orthogonal
34 | from flax.training.train_state import TrainState
35 | import flax
36 | import distrax
37 | import gymnax
38 | from brax.envs.wrappers.gym import GymWrapper
39 | from brax import envs
40 | import evosax
41 | from evosax.algorithms.distribution_based import Open_ES
42 | from evosax.core.fitness_shaping import (
43 | centered_rank_fitness_shaping_fn,
44 | identity_fitness_shaping_fn,
45 | )
46 | from gymnax.environments import spaces
47 | from optax import adam
48 | from jax.sharding import Mesh
49 | from jax.sharding import PartitionSpec as P, NamedSharding
50 |
51 | import sys
52 |
53 | # GROOVE's internal imports expect its repo root on sys.path, so add it temporarily
54 | groove_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "groove"))
55 | sys.path.insert(0, groove_path)
56 |
57 | # Import the modules you need
58 | # The 'groove' module will be found in the original path
59 | # The 'environments' module will be found via the path we just added
60 | import groove.environments.gridworld.configs as grid_conf
61 | from groove.environments.gridworld import gridworld as grid
62 | from groove.environments.gridworld.configs import ENV_MODE_KWARGS
63 |
64 | sys.path.remove(groove_path)
65 |
66 |
67 | class Actor(nn.Module):
68 | action_dim: Sequence[int]
69 | config: dict
70 |
71 | @nn.compact
72 | def __call__(self, x):
73 | hsize = self.config["HSIZE"]
74 | if self.config["ACTIVATION"] == "relu":
75 | activation = nn.relu
76 | else:
77 | activation = nn.tanh
78 |
79 | actor_mean = nn.Dense(
80 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
81 | )(x)
82 | actor_mean = activation(actor_mean)
83 | actor_mean_activation_1 = {
84 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0),
85 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0),
86 | }
87 |
88 | actor_mean = nn.Dense(
89 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
90 | )(actor_mean)
91 | actor_mean = activation(actor_mean)
92 | actor_mean_activation_2 = {
93 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0),
94 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0),
95 | }
96 | actor_mean = nn.Dense(
97 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
98 | )(actor_mean)
99 | actor_mean_activation_3 = {
100 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0),
101 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0),
102 | }
103 |
104 | if self.config["CONTINUOUS"]:
105 | actor_logtstd = self.param(
106 | "log_std", nn.initializers.zeros, (self.action_dim,)
107 | )
108 | actor_mean_activation_4 = jnp.expand_dims(
109 | jnp.mean(jnp.exp(actor_logtstd), axis=0), axis=0
110 | )
111 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
112 | else:
113 | pi = distrax.Categorical(logits=actor_mean)
114 |
115 | if self.config["CONTINUOUS"]:
116 | activations = (
117 | actor_mean_activation_1,
118 | actor_mean_activation_2,
119 | actor_mean_activation_3,
120 | actor_mean_activation_4,
121 | )
122 | else:
123 | activations = (
124 | actor_mean_activation_1,
125 | actor_mean_activation_2,
126 | actor_mean_activation_3,
127 | )
128 |
129 | return pi, activations
130 |
131 |
132 | class Critic(nn.Module):
133 | config: dict
134 |
135 | @nn.compact
136 | def __call__(self, x):
137 | hsize = self.config["HSIZE"]
138 | if self.config["ACTIVATION"] == "relu":
139 | activation = nn.relu
140 | else:
141 | activation = nn.tanh
142 |
143 | critic = nn.Dense(
144 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
145 | )(x)
146 | critic = activation(critic)
147 | critic_mean_activation_1 = {
148 | "kernel": jnp.mean(jnp.abs(critic), axis=0),
149 | "bias": jnp.mean(jnp.abs(critic), axis=0),
150 | }
151 | critic = nn.Dense(
152 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
153 | )(critic)
154 | critic = activation(critic)
155 | critic_mean_activation_2 = {
156 | "kernel": jnp.mean(jnp.abs(critic), axis=0),
157 | "bias": jnp.mean(jnp.abs(critic), axis=0),
158 | }
159 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
160 | critic
161 | )
162 | critic_mean_activation_3 = {
163 | "kernel": jnp.mean(jnp.abs(critic), axis=0),
164 | "bias": jnp.mean(jnp.abs(critic), axis=0),
165 | }
166 |
167 | activations = (
168 | critic_mean_activation_1,
169 | critic_mean_activation_2,
170 | critic_mean_activation_3,
171 | )
172 | return jnp.squeeze(critic, axis=-1), activations
173 |
174 |
175 | class Transition(NamedTuple):
176 | done: jnp.ndarray
177 | action: jnp.ndarray
178 | value: jnp.ndarray
179 | reward: jnp.ndarray
180 | log_prob: jnp.ndarray
181 | obs: jnp.ndarray
182 | info: jnp.ndarray
183 |
184 |
185 | def symlog(x):
186 | return jnp.sign(x) * jnp.log(jnp.abs(x) + 1)
187 |
188 |
189 | gridtypes = [
190 | "sixteen_rooms",
191 | "labyrinth",
192 | "rand_sparse",
193 | "rand_dense",
194 | "rand_long",
195 | "standard_maze",
196 | "rand_all",
197 | ]
198 |
199 |
200 | def make_train(config):
201 |
202 | config["NUM_UPDATES"] = (
203 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
204 | )
205 | config["MINIBATCH_SIZE"] = (
206 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
207 | )
208 | config["TOTAL_UPDATES"] = (
209 | config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
210 | )
211 |
212 | @partial(jax.jit, static_argnames=["grid_type"])
213 | def train(meta_params, rng, grid_type=6):
214 |
215 | if "Brax-" in config["ENV_NAME"]:
216 | name = config["ENV_NAME"].split("Brax-")[1]
217 | env, env_params = BraxGymnaxWrapper(env_name=name), None
218 | if config.get("CLIP_ACTION"):
219 | env = ClipAction(env)
220 | env = GymnaxLogWrapper(env)
221 | if config.get("SYMLOG_OBS"):
222 | env = TransformObservation(env, transform_obs=symlog)
223 |
224 | env = VecEnv(env)
225 | if config.get("NORMALIZE"):
226 | env = NormalizeObservation(env)
227 | env = NormalizeReward(env, config["GAMMA"])
228 | actor = Actor(env.action_space(env_params).shape[0], config=config)
229 | critic = Critic(config=config)
230 | init_x = jnp.zeros(env.observation_space(env_params).shape)
231 | else:
232 | # INIT ENV
233 | if config["ENV_NAME"] == "gridworld":
234 | env = grid.GridWorld(**ENV_MODE_KWARGS[gridtypes[grid_type]])
235 | rng, rng_ = jax.random.split(rng)
236 | env_params = grid_conf.reset_env_params(rng_, gridtypes[grid_type])
237 | else:
238 | env, env_params = gymnax.make(config["ENV_NAME"])
239 | env = GymnaxGymWrapper(env, env_params, config)
240 | env = FlatWrapper(env)
241 |
242 | env = GymnaxLogWrapper(env)
243 | env = VecEnv(env)
244 | actor = Actor(env.action_space, config=config)
245 | critic = Critic(config=config)
246 | init_x = jnp.zeros(env.observation_space)
247 |
248 | # INIT NETWORK
249 |
250 | rng, _rng = jax.random.split(rng)
251 | actor_params = actor.init(_rng, init_x)
252 | critic_params = critic.init(_rng, init_x)
253 |
254 | if config["larger"]:
255 | meta_opt = GRU_Opt(hidden_size=32, gru_features=16)
256 | else:
257 | meta_opt = GRU_Opt(hidden_size=16, gru_features=8)
258 | opt = meta_opt.opt_fn(meta_params)
259 | clip_opt = optax.clip_by_global_norm(config["MAX_GRAD_NORM"])
260 |
261 | rng, rng_act, rng_crit = jax.random.split(rng, 3)
262 | train_state_actor = opt.init(actor_params, key=rng_act)
263 | train_state_critic = opt.init(critic_params, key=rng_crit)
264 |
265 | act_param_tree = jax.tree_util.tree_structure(train_state_actor.params)
266 | act_layer_props = []
267 | num_act_layers = len(train_state_actor.params["params"])
268 | for i, layer in enumerate(train_state_actor.params["params"]):
269 | layer_prop = i / (num_act_layers - 1)
270 | if type(train_state_actor.params["params"][layer]) == dict:
271 | act_layer_props.extend(
272 | [layer_prop] * len(train_state_actor.params["params"][layer])
273 | )
274 | else:
275 | act_layer_props.extend([layer_prop])
276 |
277 | act_layer_props = jax.tree_util.tree_unflatten(act_param_tree, act_layer_props)
278 |
279 | crit_param_tree = jax.tree_util.tree_structure(train_state_critic.params)
280 | crit_layer_props = []
281 | num_crit_layers = len(train_state_critic.params["params"])
282 | for i, layer in enumerate(train_state_critic.params["params"]):
283 | layer_prop = i / (num_crit_layers - 1)
284 | if type(train_state_critic.params["params"][layer]) == dict:
285 | crit_layer_props.extend(
286 | [layer_prop] * len(train_state_critic.params["params"][layer])
287 | )
288 | else:
289 | crit_layer_props.extend([layer_prop])
290 |
291 | crit_layer_props = jax.tree_util.tree_unflatten(
292 | crit_param_tree, crit_layer_props
293 | )
294 |
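        # Illustrative example (not part of the original logic): for a three-layer
        # critic the loop above yields proportions [0.0, 0.0, 0.5, 0.5, 1.0, 1.0]
        # (one entry per kernel/bias leaf), which tree_unflatten maps back onto the
        # parameter pytree as
        # {"params": {"Dense_0": {"kernel": 0.0, "bias": 0.0},
        #             "Dense_1": {"kernel": 0.5, "bias": 0.5},
        #             "Dense_2": {"kernel": 1.0, "bias": 1.0}}}.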
295 | # INIT ENV
296 | all_rng = jax.random.split(_rng, config["NUM_ENVS"] + 1)
297 | rng, _rng = all_rng[0], all_rng[1:]
298 | obsv, env_state = env.reset(_rng, env_params)
299 |
300 | # TRAIN LOOP
301 | def _update_step(runner_state, unused):
302 |
303 | # COLLECT TRAJECTORIES
304 | def _env_step(runner_state, unused):
305 | (
306 | train_state_actor,
307 | train_state_critic,
308 | env_state,
309 | last_obs,
310 | last_done,
311 | rng,
312 | last_param_abs,
313 | ) = runner_state
314 | rng, _rng = jax.random.split(rng)
315 | # SELECT ACTION
316 | pi, _ = actor.apply(train_state_actor.params, last_obs)
317 | value, _ = critic.apply(train_state_critic.params, last_obs)
318 | action = pi.sample(seed=_rng)
319 |
320 | log_prob = pi.log_prob(action)
321 | # STEP ENV
322 | rng, _rng = jax.random.split(rng)
323 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
324 | obsv, env_state, reward, done, info = env.step(
325 | rng_step, env_state, action, env_params
326 | )
327 | transition = Transition(
328 | done, action, value, reward, log_prob, last_obs, info
329 | )
330 | runner_state = (
331 | train_state_actor,
332 | train_state_critic,
333 | env_state,
334 | obsv,
335 | done,
336 | rng,
337 | 0.0,
338 | )
339 |
340 | return runner_state, transition
341 |
342 | runner_state, traj_batch = jax.lax.scan(
343 | _env_step, runner_state, None, config["NUM_STEPS"]
344 | )
345 |
346 | # CALCULATE ADVANTAGE
347 | (
348 | train_state_actor,
349 | train_state_critic,
350 | env_state,
351 | last_obs,
352 | last_done,
353 | rng,
354 | last_param_abs,
355 | ) = runner_state
356 | last_val, _ = critic.apply(train_state_critic.params, last_obs)
357 | last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val)
358 |
359 | def _calculate_gae(traj_batch, last_val):
360 | def _get_advantages(gae_and_next_value, transition):
361 | gae, next_value = gae_and_next_value
362 | done, value, reward = (
363 | transition.done,
364 | transition.value,
365 | transition.reward,
366 | )
367 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value
368 | gae = (
369 | delta
370 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
371 | )
372 | return (gae, value), gae
373 |
374 | _, advantages = jax.lax.scan(
375 | _get_advantages,
376 | (jnp.zeros_like(last_val), last_val),
377 | traj_batch,
378 | reverse=True,
379 | unroll=16,
380 | )
381 | return advantages, advantages + traj_batch.value
382 |
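            # The scan above implements the standard GAE recursion:
            # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) and
            # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, with the
            # value targets recovered as A_t + V(s_t).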
383 | advantages, targets = _calculate_gae(traj_batch, last_val)
384 |
385 | # UPDATE NETWORK
386 | def _update_epoch(update_state, unused):
387 | def _update_minbatch(train_state_key, batch_info):
388 | train_state_actor, train_state_critic, key = train_state_key
389 | key, key_ = jax.random.split(key)
390 |
391 | traj_batch, advantages, targets = batch_info
392 |
393 | def _loss_fn(actor_params, critic_params, traj_batch, gae, targets):
394 | # RERUN NETWORK
395 | pi, actor_activations = actor.apply(
396 | actor_params, traj_batch.obs
397 | )
398 | value, critic_activations = critic.apply(
399 | critic_params, traj_batch.obs
400 | )
401 |
402 | log_prob = pi.log_prob(traj_batch.action)
403 |
404 | # CALCULATE VALUE LOSS
405 | value_pred_clipped = traj_batch.value + (
406 | value - traj_batch.value
407 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
408 | value_losses = jnp.square(value - targets)
409 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
410 | value_loss = (
411 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
412 | )
413 |
414 | # CALCULATE ACTOR LOSS
415 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
416 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
417 | loss_actor1 = ratio * gae
418 | loss_actor2 = (
419 | jnp.clip(
420 | ratio,
421 | 1.0 - config["CLIP_EPS"],
422 | 1.0 + config["CLIP_EPS"],
423 | )
424 | * gae
425 | )
426 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
427 | loss_actor = loss_actor.mean()
428 | entropy = pi.entropy().mean()
429 |
430 | total_loss = (
431 | loss_actor
432 | + config["VF_COEF"] * value_loss
433 | - config["ENT_COEF"] * entropy
434 | )
435 |
436 | return total_loss, (actor_activations, critic_activations)
437 |
438 | training_prop = (
439 | train_state_actor.iteration
440 | // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])
441 | ) / (config["NUM_UPDATES"] - 1)
442 | batch_prop = (
443 | (train_state_actor.iteration // config["NUM_MINIBATCHES"])
444 | % config["UPDATE_EPOCHS"]
445 | ) / (config["UPDATE_EPOCHS"] - 1)
446 |
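                    # Illustrative note: `iteration` counts minibatch updates, so
                    # with e.g. NUM_MINIBATCHES=8 and UPDATE_EPOCHS=4, training_prop
                    # runs from 0 to 1 over the NUM_UPDATES environment updates of
                    # the run, while batch_prop runs from 0 to 1 across the epochs
                    # within a single update; both are passed to opt.update below.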
447 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, argnums=[0, 1])
448 | (total_loss, (actor_activations, critic_activations)), (
449 | actor_grads,
450 | critic_grads,
451 | ) = grad_fn(
452 | train_state_actor.params,
453 | train_state_critic.params,
454 | traj_batch,
455 | advantages,
456 | targets,
457 | )
458 |
459 | key_actor, key_critic = jax.random.split(key_)
460 | actor_grads, _ = clip_opt.update(actor_grads, None)
461 | critic_grads, _ = clip_opt.update(critic_grads, None)
462 | actor_mask = {"kernel": 1, "bias": 1}
463 | critic_mask = {"kernel": 0, "bias": 0}
464 |
465 | # FOR NOW, HARD CODE MASK
466 | if config["CONTINUOUS"]:
467 | mask = {
468 | "actor": {
469 | "params": {
470 | "Dense_0": actor_mask,
471 | "Dense_1": actor_mask,
472 | "Dense_2": actor_mask,
473 | "log_std": 1,
474 | }
475 | },
476 | "critic": {
477 | "params": {
478 | "Dense_0": critic_mask,
479 | "Dense_1": critic_mask,
480 | "Dense_2": critic_mask,
481 | }
482 | },
483 | }
484 |
485 | else:
486 |
487 | mask = {
488 | "actor": {
489 | "params": {
490 | "Dense_0": actor_mask,
491 | "Dense_1": actor_mask,
492 | "Dense_2": actor_mask,
493 | }
494 | },
495 | "critic": {
496 | "params": {
497 | "Dense_0": critic_mask,
498 | "Dense_1": critic_mask,
499 | "Dense_2": critic_mask,
500 | }
501 | },
502 | }
503 |
504 | activations = {
505 | "actor": actor_activations,
506 | "critic": critic_activations,
507 | }
508 | grads = {"actor": actor_grads, "critic": critic_grads}
509 | layer_props = {"actor": act_layer_props, "critic": crit_layer_props}
510 |
511 | # APPLY OPTIMIZER
512 | (
513 | train_state_actor,
514 | train_state_critic,
515 | actor_updates,
516 | critic_updates,
517 | ) = opt.update(
518 | train_state_actor,
519 | train_state_critic,
520 | grads,
521 | activations,
522 | key=key_actor,
523 | training_prop=training_prop,
524 | config=config,
525 | batch_prop=batch_prop,
526 | layer_props=layer_props,
527 | mask=mask,
528 | )
529 |
530 | train_state_key_ = (train_state_actor, train_state_critic, key)
531 | return train_state_key_, (
532 | total_loss,
533 | actor_updates,
534 | critic_updates,
535 | actor_activations,
536 | critic_activations,
537 | )
538 |
539 | (
540 | train_state_actor,
541 | train_state_critic,
542 | traj_batch,
543 | advantages,
544 | targets,
545 | rng,
546 | ) = update_state
547 | rng, _rng = jax.random.split(rng)
548 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
549 | assert (
550 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
551 | ), "batch size must be equal to number of steps * number of envs"
552 | permutation = jax.random.permutation(_rng, batch_size)
553 | batch = (traj_batch, advantages, targets)
554 | batch = jax.tree_util.tree_map(
555 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
556 | )
557 | shuffled_batch = jax.tree_util.tree_map(
558 | lambda x: jnp.take(x, permutation, axis=0), batch
559 | )
560 | minibatches = jax.tree_util.tree_map(
561 | lambda x: jnp.reshape(
562 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
563 | ),
564 | shuffled_batch,
565 | )
566 |
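                # Illustrative shapes (example values): with NUM_STEPS=128,
                # NUM_ENVS=64 and NUM_MINIBATCHES=8 the trajectory leaves go from
                # (128, 64, ...) to a shuffled (8192, ...) batch, and are then split
                # into minibatches of shape (8, 1024, ...) for the scan below.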
567 | train_state_key = (train_state_actor, train_state_critic, rng)
568 | train_state_key, total_loss_updates = jax.lax.scan(
569 | _update_minbatch, train_state_key, minibatches
570 | )
571 | train_state_actor, train_state_critic, rng = train_state_key
572 | update_state = (
573 | train_state_actor,
574 | train_state_critic,
575 | traj_batch,
576 | advantages,
577 | targets,
578 | rng,
579 | )
580 |
581 | return update_state, total_loss_updates
582 |
583 | update_state = (
584 | train_state_actor,
585 | train_state_critic,
586 | traj_batch,
587 | advantages,
588 | targets,
589 | rng,
590 | )
591 | update_state, loss_update = jax.lax.scan(
592 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
593 | )
594 |
595 | train_state_actor = update_state[0]
596 | train_state_critic = update_state[1]
597 |
598 | (
599 | loss_info,
600 | actor_updates,
601 | critic_updates,
602 | actor_activations,
603 | critic_activations,
604 | ) = loss_update
605 | if config["VISUALISE"]:
606 |
607 | metric = traj_batch.info
608 |
609 | else:
610 | metric = dict()
611 |
612 | metric.update(
613 | {
614 | "returned_episode_returns": traj_batch.info[
615 | "returned_episode_returns"
616 | ][-1].mean()
617 | }
618 | )
619 |
620 | rng = update_state[-1]
621 |
622 | actor_size = len(jax.flatten_util.ravel_pytree(runner_state[0].params)[0])
623 | critic_size = len(jax.flatten_util.ravel_pytree(runner_state[1].params)[0])
624 | if config["VISUALISE"]:
625 | actor_abs = jax.flatten_util.ravel_pytree(actor_updates[1])[0]
626 | critic_abs = jax.flatten_util.ravel_pytree(critic_updates[1])[0]
627 | else:
628 | actor_abs = jax.flatten_util.ravel_pytree(actor_updates)[0][-1]
629 | critic_abs = jax.flatten_util.ravel_pytree(critic_updates)[0][-1]
630 | abs_param_mean = (actor_size * actor_abs + critic_size * critic_abs) / (
631 | actor_size + critic_size
632 | )
633 | runner_state = (
634 | train_state_actor,
635 | train_state_critic,
636 | env_state,
637 | last_obs,
638 | last_done,
639 | rng,
640 | abs_param_mean.mean(),
641 | )
642 | if config["VISUALISE"]:
643 |
644 | # CALCULATE DORMANCY FOR TRACKING
645 | def calc_dormancy(tensor_activations, tau=1e-7):
646 | tensor_activations = tensor_activations + 1e-11
647 | total_activations = jnp.abs(tensor_activations).sum(axis=-1)
648 | total_activations = jnp.tile(
649 | jnp.expand_dims(total_activations, -1),
650 | tensor_activations.shape[-1],
651 | )
652 | dormancy = (
653 | tensor_activations
654 | / total_activations
655 | * tensor_activations.shape[-1]
656 | )
657 | tau_dormancy = dormancy < tau
658 | tau_dormancy = tau_dormancy.sum(axis=(-1))
659 | return tau_dormancy
660 |
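                # Worked example (values are illustrative): for an activation row
                # [0., 1., 1., 2.] the normalised scores above are [0., 1., 1., 2.]
                # (each activation divided by the row total of 4 and rescaled by the
                # 4 units), so with tau=1e-7 exactly one unit in that row counts as
                # dormant.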
661 | full_activations = (actor_activations, critic_activations)
662 | dormancies = jax.tree_util.tree_map(calc_dormancy, full_activations)
663 | dormancies = jax.tree_util.tree_flatten(dormancies)[0]
664 |
665 | prop_dormancies = jnp.stack(dormancies).sum(axis=0)
666 | prop_dormancies = (
667 | prop_dormancies
668 | / len(jax.flatten_util.ravel_pytree(full_activations)[0])
669 | * prop_dormancies.shape[0]
670 | * prop_dormancies.shape[1]
671 | )
672 | param_mean = (
673 | actor_size * jax.flatten_util.ravel_pytree(actor_updates[0])[0]
674 | + critic_size * jax.flatten_util.ravel_pytree(critic_updates[0])[0]
675 | ) / (actor_size + critic_size)
676 | log_updates = {
677 | "param_mean": param_mean,
678 | "param_abs_mean": abs_param_mean,
679 | }
680 | metric.update({"dormancy": prop_dormancies})
681 | metric.update(log_updates)
682 |
683 | return runner_state, metric
684 |
685 | rng, _rng = jax.random.split(rng)
686 | runner_state = (
687 | train_state_actor,
688 | train_state_critic,
689 | env_state,
690 | obsv,
691 | jnp.zeros((config["NUM_ENVS"]), dtype=bool),
692 | _rng,
693 | 0.0,
694 | )
695 |
696 | runner_state, metric = jax.lax.scan(
697 | _update_step, runner_state, None, config["NUM_UPDATES"]
698 | )
699 |
700 | return runner_state, metric
701 |
702 | return train
703 |
704 |
705 | if __name__ == "__main__":
706 |
707 | parser = argparse.ArgumentParser()
708 | parser.add_argument("--num-generations", type=int, default=512)
709 | parser.add_argument("--envs", nargs="+", required=True)
710 | parser.add_argument("--lr", type=float, default=0.03)
711 | parser.add_argument("--popsize", type=int, default=64)
712 | parser.add_argument("--num-rollouts", type=int, default=1)
713 | parser.add_argument("--save-every-k", type=int, default=24)
714 | parser.add_argument("--noise-level", type=float, default=0.03)
715 | parser.add_argument(
716 | "--pmap", default=jax.local_device_count() > 1, action="store_true"
717 | )
718 | parser.add_argument("--wandb-name", type=str, default="OPEN")
719 | parser.add_argument("--wandb-entity", type=str, default=None)
720 | parser.add_argument("--sigma-decay", type=float, default=0.999)
721 | parser.add_argument("--lr-decay", type=float, default=0.999)
722 | parser.add_argument("--larger", default=False, action="store_true")
723 |
724 | args = parser.parse_args()
725 |
726 | evo_config = {
727 | "ENV_NAME": args.envs,
728 | "POPULATION_SIZE": args.popsize,
729 | "NUM_GENERATIONS": args.num_generations,
730 | "NUM_ROLLOUTS": args.num_rollouts,
731 | "SAVE_EVERY_K": args.save_every_k,
732 | "NOISE_LEVEL": args.noise_level,
733 | "PMAP": args.pmap,
734 | "LR": args.lr,
735 | "num_GPUs": jax.local_device_count(),
736 | }
737 |
738 | all_configs = {k: all_configs[k] for k in evo_config["ENV_NAME"]}
739 |
740 | save_loc = "training"
741 | if not os.path.exists(save_loc):
742 | os.mkdir(save_loc)
743 | save_dir = f"{save_loc}/{str(datetime.now()).replace(' ', '_')}_optimizer"
744 | os.mkdir(f"{save_dir}")
745 |
746 | popsize = args.popsize
747 | num_generations = args.num_generations
748 | num_rollouts = args.num_rollouts
749 | save_every_k_gens = args.save_every_k
750 |
751 | wandb.init(
752 | project="OPEN",
753 | config=evo_config,
754 | name=args.wandb_name,
755 | entity=args.wandb_entity,
756 | )
757 |
758 | if args.larger:
759 | meta_opt = GRU_Opt(hidden_size=32, gru_features=16)
760 | else:
761 | meta_opt = GRU_Opt(hidden_size=16, gru_features=8)
762 | params = meta_opt.init(jax.random.PRNGKey(0))
763 |
764 | params = jax.tree.map(lambda x: jnp.zeros_like(x), params)
765 |
766 | devices = jax.devices()
767 | mesh = Mesh(devices, axis_names=("dim",))
768 | sharding_p = NamedSharding(
769 | mesh,
770 | P(
771 | "dim",
772 | ),
773 | )
774 | sharding_rng = NamedSharding(
775 | mesh,
776 | P(
777 | "dim",
778 | ),
779 | )
780 |
781 | def make_rollout(train_fn):
782 | def single_rollout(rng_input, meta_params):
783 | params, metrics = train_fn(meta_params, rng_input)
784 |
785 | fitness = metrics["returned_episode_returns"][-1]
786 | return fitness
787 |
788 | vmap_rollout = jax.vmap(single_rollout, in_axes=(0, None))
789 | rollout = jax.jit(jax.vmap(vmap_rollout, in_axes=(0, 0)))
790 |
791 | # if evo_config["PMAP"]:
792 | # rollout = jax.pmap(rollout)
793 |
794 | return rollout
795 |
796 | for k in all_configs.keys():
797 | all_configs[k]["NUM_UPDATES"] = (
798 | all_configs[k]["TOTAL_TIMESTEPS"]
799 | // all_configs[k]["NUM_STEPS"]
800 | // all_configs[k]["NUM_ENVS"]
801 | )
802 | all_configs[k]["larger"] = args.larger
803 |
804 | rollouts = {k: make_rollout(make_train(v)) for k, v in all_configs.items()}
805 |
806 | rng = jax.random.PRNGKey(42)
807 |
808 | if args.envs == ["gridworld"]:
809 | fitness_shaping_fn = identity_fitness_shaping_fn
810 | else:
811 | fitness_shaping_fn = centered_rank_fitness_shaping_fn
812 |
813 | # In practice, set the decay rate and transition steps to control how much, and over how long, you want to decay.
814 | # Here we set transition_steps=1 and all other hyperparameters to match the behaviour of the original evosax.
815 | strategy = Open_ES(
816 | population_size=args.popsize,
817 | optimizer=optax.adam(
818 | optax.schedules.exponential_decay(
819 | evo_config["LR"],
820 | decay_rate=args.lr_decay,
821 | end_value=0,
822 | transition_steps=1,
823 | ),
824 | b1=0.99,
825 | ),
826 | std_schedule=optax.schedules.exponential_decay(
827 | evo_config["NOISE_LEVEL"],
828 | decay_rate=args.sigma_decay,
829 | end_value=0,
830 | transition_steps=1,
831 | ),
832 | solution=params,
833 | fitness_shaping_fn=fitness_shaping_fn,
834 | )
835 |
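# Illustrative check of the schedule above (values assumed, not from a config):
# with transition_steps=1 the value at generation g is init * decay_rate ** g,
# matching the "evo/lr" quantity logged in the generation loop below, e.g.
#     optax.schedules.exponential_decay(0.03, transition_steps=1,
#                                       decay_rate=0.999, end_value=0)(100)
#     # roughly 0.03 * 0.999 ** 100, i.e. about 0.0271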
836 | es_params = strategy.default_params
837 |
838 | state = strategy.init(key=rng, params=es_params, mean=params)
839 |
840 | most_neg = {env: 0 for env in evo_config["ENV_NAME"]}
841 |
842 | fit_history = []
843 | for gen in tqdm(range(num_generations)):
844 |
845 | rng, rng_ask, rng_eval = jax.random.split(rng, 3)
846 | x, state = jax.jit(strategy.ask)(rng_ask, state, es_params)
847 | x = jax.device_put(x, sharding_p)
848 |
849 | # Set up antithetic task sampling by repeating each optimizer.
850 | if args.envs == ["gridworld"]:
851 | x, state = jax.jit(strategy.ask)(rng_ask, state, es_params)
852 | new_orders = [
853 | [i, int(args.popsize / 2) + i] for i in range(int(args.popsize / 2))
854 | ]
855 | new_orders = jnp.array([x for y in new_orders for x in y])
856 | x = x[new_orders]
857 |
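    # Illustrative example: for popsize=8 the reordering above gives
    # new_orders = [0, 4, 1, 5, 2, 6, 3, 7], so each antithetic pair of
    # perturbations becomes adjacent and later shares a single evaluation rng.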
858 | fit_info = {}
859 |
860 | wandb.log(
861 | {
862 | "evo/std": state.std,
863 | "evo/lr": evo_config["LR"] * (args.lr_decay**gen),
864 | },
865 | step=gen,
866 | )
867 |
868 | all_fitness = []
869 |
870 | for env in args.envs:
871 |
872 | rng, rng_eval = jax.random.split(rng)
873 |
874 | all_configs[env]["VISUALISE"] = False
875 | rollout = rollouts[env]
876 |
877 | # Antithetic task sampling for gridworld - antithetic perturbations are evaluated on the same rng
878 | if args.envs == ["gridworld"]:
879 | batch_rng = jax.random.split(rng_eval, args.popsize // 2)
880 | batch_rng = jnp.repeat(batch_rng, 2, axis=0)
881 |
882 | else:
883 | batch_rng = jax.random.split(rng_eval, num_rollouts)
884 | batch_rng = jnp.tile(batch_rng, (args.popsize, 1, 1))
885 |
886 | if args.pmap:
887 | batch_rng_pmap = jax.device_put(batch_rng, sharding_rng)
888 | fitness = rollout(batch_rng_pmap, x)
889 | fitness = jax.device_get(fitness)
890 | fitness = fitness.reshape(-1, evo_config["NUM_ROLLOUTS"]).mean(axis=1)
891 |
892 | else:
893 | batch_rng = jnp.reshape(batch_rng, (-1, num_rollouts, 2))
894 | fitness = rollout(batch_rng, x)
895 | fitness = fitness.mean(axis=1)
896 |
897 | fitness = jnp.nan_to_num(fitness, nan=-100000)
898 |
899 | print(f"fitness: {fitness}")
900 | print(f"mean fitness_{env} = {jnp.mean(fitness):.3f}")
901 | print(f"fitness_spread at gen {gen} is {fitness.max()-fitness.min()}")
902 |
903 | fit_history.append(fitness.mean())
904 |
905 | fitness_var = jnp.var(fitness)
906 | fitness_spread = fitness.max() - fitness.min()
907 |
908 | # PPO_TEMP is set to the return obtained by PPO with Adam, and is used to normalise across environments
909 | fit_norm = fitness / all_configs[env]["PPO_TEMP"]
910 |
911 | mean_norm = fit_norm.mean()
912 |
913 | wandb.log(
914 | {
915 | f"training/{env}/avg_fitness": fitness.mean(),
916 | f"training/{env}/fitness_histo": wandb.Histogram(
917 | fitness, num_bins=16
918 | ),
919 | f"training/{env}/fitness_spread": fitness_spread,
920 | f"training/{env}/fitness_variance": fitness_var,
921 | f"training/{env}/normalised_fitness": mean_norm,
922 | f"training/{env}/normalised_histo": wandb.Histogram(
923 | fit_norm, num_bins=16
924 | ),
925 | f"training/{env}/best_fitness": jnp.max(fitness),
926 | f"training/{env}/worst_fitness": jnp.min(fitness),
927 | },
928 | step=gen,
929 | )
930 |
931 | all_fitness.append(fit_norm)
932 |
933 | fitnesses = jnp.stack(all_fitness, axis=0)
934 | fitnesses_mean = jnp.mean(fitnesses, axis=0)
935 |
936 | wandb.log(
937 | {
938 | "training/average_over_env/normalised_fitness": fitnesses_mean.mean(),
939 | "training/average_over_env/normalised_histo": wandb.Histogram(
940 | fitnesses_mean, num_bins=16
941 | ),
942 | },
943 | step=gen,
944 | )
945 |
946 | param_sum = state.mean.sum()
947 | param_abs_sum = jnp.abs(state.mean).sum()
948 | param_abs_mean = jnp.abs(state.mean).mean()
949 |
950 | wandb.log(
951 | {
952 | "param_stats/param_sum": state.mean.sum(),
953 | "param_stats/param_abs_sum": jnp.abs(state.mean).sum(),
954 | "param_stats/param_abs_mean": jnp.abs(state.mean).mean(),
955 | "param_stats/param_mean": state.mean.mean(),
956 | },
957 | step=gen,
958 | )
959 |
960 | # normalization for antithetic task sampling
961 | if args.envs == ["gridworld"]:
962 | first_greater = jnp.greater(fitness[::2], fitness[1::2])
963 | rank_fitness = jnp.zeros_like(fitness)
964 | rank_fitness = rank_fitness.at[::2].set(-1 * first_greater.astype(float))
965 | fitnesses_mean = rank_fitness.at[1::2].set(
966 | first_greater.astype(float) - 1.0
967 | )
968 |
969 | # use standard key as update is deterministic
970 | # multiply fitness by -1 since the new evosax is only set to minimise!
971 | x = jax.device_get(x)
972 | state, metrics = jax.jit(strategy.tell)(
973 | jax.random.PRNGKey(0), x, -1 * fitnesses_mean, state, es_params
974 | )
975 | wandb.log(
976 | {"training/average_over_env/best_fitness": -1 * metrics["best_fitness"]}, step=gen
977 | )
978 |
979 | if gen % save_every_k_gens == 0:
980 | print("SAVING & EVALUATING!")
981 | jnp.save(osp.join(save_dir, f"curr_param_{gen}.npy"), state.mean)
982 | np.save(osp.join(save_dir, f"fit_history.npy"), np.array(fit_history))
983 |
984 | wandb.save(
985 | osp.join(save_dir, f"curr_param_{gen}.npy"),
986 | base_path=save_dir,
987 | )
988 | if args.envs == ["gridworld"]:
989 | plot_envs = gridtypes
990 | grids = True
991 | else:
992 | plot_envs = args.envs
993 | grids = False
994 |
995 | time.sleep(1)
996 | make_plot(
997 | exp_names=wandb.run.id,
998 | exp_nums=gen,
999 | envs=plot_envs,
1000 | num_runs=8,
1001 | pmap=args.pmap,
1002 | larger=args.larger,
1003 | training=True,
1004 | grids=grids,
1005 | params=strategy.get_mean(state),
1006 | mesh=mesh,
1007 | )
1008 |
--------------------------------------------------------------------------------
/rl_optimizer/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | from gymnax.environments import spaces, environment
5 | from typing import NamedTuple, Optional, Tuple, Union
6 | from brax import envs
7 | from flax import struct
8 | from functools import partial
9 | from gymnax.wrappers.purerl import GymnaxWrapper
10 | import chex
11 | from jax import vjp, flatten_util
12 | from jax.tree_util import tree_flatten
13 |
14 |
15 | class GymnaxGymWrapper:
16 | def __init__(self, env, env_params, config):
17 | self.env = env
18 | self.env_params = env_params
19 | if config["CONTINUOUS"]:
20 | self.action_space = env.action_space(self.env_params).shape[0]
21 | else:
22 | self.action_space = env.action_space(self.env_params).n
23 |
24 | self.observation_space = env.observation_space(self.env_params).shape
25 |
26 | def reset(self, rng, params=None):
27 | rng, _rng = jax.random.split(rng)
28 | obs, env_state = self.env.reset(_rng, self.env_params)
29 | state = (env_state, rng)
30 | return obs, state
31 |
32 | def step(self, key, state, action, params=None):
33 | env_state, rng = state
34 | rng, _rng = jax.random.split(rng)
35 | obs, env_state, reward, done, info = self.env.step(
36 | _rng, env_state, action, self.env_params
37 | )
38 | state = (env_state, rng)
39 | return obs, state, reward, done, info
40 |
41 |
42 | class MetaGymnaxGymWrapper:
43 | def __init__(self, env, env_param_generator):
44 | self.env = env
45 | self.env_param_generator = env_param_generator
46 | self.observation_space = env.observation_space(env.default_params).shape
47 | self.action_space = env.action_space(env.default_params).n
48 |
49 | def reset(self, rng):
50 | rng, _rng = jax.random.split(rng)
51 | env_params = self.env_param_generator(_rng)
52 | rng, _rng = jax.random.split(rng)
53 | obs, env_state = self.env.reset(_rng, env_params)
54 | state = (env_state, env_params, rng)
55 | return obs, state
56 |
57 | def step(self, state, action):
58 | env_state, env_params, rng = state
59 | rng, _rng = jax.random.split(rng)
60 | obs, env_state, reward, done, info = self.env.step(
61 | _rng, env_state, action, env_params
62 | )
63 | state = (env_state, env_params, rng)
64 | return obs, state, reward, done, info
65 |
66 |
67 | class FlatWrapper:
68 | def __init__(self, env):
69 | self.env = env
70 | self.observation_space = np.prod(env.observation_space)
71 | self.action_space = env.action_space
72 |
73 | def reset(self, rng, params=None):
74 | obs, env_state = self.env.reset(rng, params)
75 | obs = jnp.reshape(obs, (self.observation_space,))
76 | return obs, env_state
77 |
78 | def step(self, key, state, action, params=None):
79 | obs, state, reward, done, info = self.env.step(key, state, action, params)
80 | obs = jnp.reshape(obs, (self.observation_space,))
81 | return obs, state, reward, done, info
82 |
83 |
84 | class EpisodeStats(NamedTuple):
85 | episode_returns: float
86 | episode_lengths: int
87 | returned_episode_returns: float
88 | returned_episode_lengths: int
89 |
90 |
91 | class GymnaxLogWrapper(GymnaxWrapper):
92 | def __init__(self, env):
93 | self.env = env
94 | self.observation_space = env.observation_space
95 | self.action_space = env.action_space
96 |
97 | def reset(self, rng, params=None):
98 | obs, env_state = self.env.reset(rng, params)
99 | state = (env_state, EpisodeStats(0, 0, 0, 0))
100 | return obs, state
101 |
102 | def step(self, key, state, action, params=None):
103 | # def step(self, state, action):
104 | env_state, episode_stats = state
105 | obs, env_state, reward, done, info = self.env.step(
106 | key, env_state, action, params
107 | )
108 | new_episode_return = episode_stats.episode_returns + reward
109 | new_episode_length = episode_stats.episode_lengths + 1
110 | new_episode_stats = EpisodeStats(
111 | episode_returns=new_episode_return * (1 - done),
112 | episode_lengths=new_episode_length * (1 - done),
113 | returned_episode_returns=episode_stats.returned_episode_returns * (1 - done)
114 | + new_episode_return * done,
115 | returned_episode_lengths=episode_stats.returned_episode_lengths * (1 - done)
116 | + new_episode_length * done,
117 | )
118 | state = (env_state, new_episode_stats)
119 | info = {}
120 | info["returned_episode_returns"] = new_episode_stats.returned_episode_returns
121 | info["returned_episode_lengths"] = new_episode_stats.returned_episode_lengths
122 | return obs, state, reward, done, info
123 |
124 |
125 | class EvalStats(NamedTuple):
126 | first_returned_episode_returns: float
127 | ever_done: bool
128 |
129 |
130 | class GymnaxLogEvalWrapper:
131 | def __init__(self, env):
132 | self.env = env
133 | self.observation_space = env.observation_space
134 | self.action_space = env.action_space
135 |
136 | def reset(self, rng, params=None):
137 | obs, env_state = self.env.reset(rng, params)
138 | state = (env_state, EvalStats(0, False))
139 | return obs, state
140 |
141 | def step(self, key, state, action, params=None):
142 | env_state, episode_stats = state
143 | obs, env_state, reward, done, info = self.env.step(
144 | key, env_state, action, params
145 | )
146 | ever_done = jnp.logical_or(episode_stats.ever_done, done)
147 | episode_return = episode_stats.first_returned_episode_returns + reward * (
148 | 1 - ever_done
149 | )
150 | new_episode_stats = EvalStats(
151 | first_returned_episode_returns=episode_return,
152 | ever_done=ever_done,
153 | )
154 | state = (env_state, new_episode_stats)
155 | info = {}
156 | info["first_returned_episode_returns"] = (
157 | new_episode_stats.first_returned_episode_returns
158 | )
159 | info["ever_done"] = new_episode_stats.ever_done
160 | return obs, state, reward, done, info
161 |
162 |
163 | class AutoResetEnvWrapper(GymnaxWrapper):
164 | """Provides standard auto-reset functionality, matching the default Gymnax behaviour."""
165 |
166 | def __init__(self, env: environment.Environment):
167 | super().__init__(env)
168 |
169 | @partial(jax.jit, static_argnums=(0, 2))
170 | def reset(
171 | self, key, params: Optional[environment.EnvParams] = None
172 | ) -> Tuple[chex.Array, environment.EnvState]:
173 | return self._env.reset(key, params)
174 |
175 | @partial(jax.jit, static_argnums=(0, 4))
176 | def step(self, rng, state, action, params=None):
177 |
178 | rng, _rng = jax.random.split(rng)
179 | obs_st, state_st, reward, done, info = self._env.step(
180 | _rng, state, action, params
181 | )
182 |
183 | rng, _rng = jax.random.split(rng)
184 | obs_re, state_re = self._env.reset(_rng, params)
185 |
186 | # Auto-reset environment based on termination
187 | def auto_reset(done, state_re, state_st, obs_re, obs_st):
188 | state = jax.tree_util.tree_map(
189 | lambda x, y: jax.lax.select(done, x, y), state_re, state_st
190 | )
191 | obs = jax.lax.select(done, obs_re, obs_st)
192 |
193 | return obs, state
194 |
195 | obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
196 |
197 | return obs, state, reward, done, info
198 |
199 |
200 | class BraxGymnaxWrapper:
201 | def __init__(self, env_name, backend="positional"):
202 | # def __init__(self, env_name, backend="generalized"):
203 |
204 | # ****** BACKEND CURRENTLY NOT IMPLEMENTED
205 |
206 | env = envs.get_environment(env_name=env_name, backend=backend)
207 | env = envs.wrappers.training.EpisodeWrapper(
208 | env, episode_length=1000, action_repeat=1
209 | )
210 | env = envs.wrappers.training.AutoResetWrapper(env)
211 | self._env = env
212 | self.action_size = env.action_size
213 | self.observation_size = (env.observation_size,)
214 |
215 | def reset(self, key, params=None):
216 | # print(f'bye: {key}')
217 | state = self._env.reset(key)
218 | return state.obs, state
219 |
220 | def step(self, key, state, action, params):
221 | next_state = self._env.step(state, action)
222 | return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}
223 |
224 | def observation_space(self, params):
225 | return spaces.Box(
226 | low=-jnp.inf,
227 | high=jnp.inf,
228 | shape=(self._env.observation_size,),
229 | )
230 |
231 | def action_space(self, params):
232 | return spaces.Box(
233 | low=-1.0,
234 | high=1.0,
235 | shape=(self._env.action_size,),
236 | )
237 |
238 |
239 | class GymnaxWrapper(object):
240 | """Base class for Gymnax wrappers."""
241 |
242 | def __init__(self, env):
243 | self._env = env
244 |
245 | # provide proxy access to regular attributes of wrapped object
246 | def __getattr__(self, name):
247 | return getattr(self._env, name)
248 |
249 |
250 | class ClipAction(GymnaxWrapper):
251 | def __init__(self, env):
252 | super().__init__(env)
253 | # self.action_lim = config["ACTION_LIM"]
254 |
255 | def step(self, key, state, action, params=None):
256 | """TODO: FIX"""
257 | # old_action = action
258 | action = jnp.clip(action, -1, 1)
259 | # jax.debug.print('{old} -> {new}', old=old_action, new=action)
260 | # action = jnp.clip(action, -1.0, 1.0)
261 | # jax.debug.print('{action}', action=action)
262 | return self._env.step(key, state, action, params)
263 |
264 |
265 | class TransformObservation(GymnaxWrapper):
266 | def __init__(self, env, transform_obs):
267 | super().__init__(env)
268 | self.transform_obs = transform_obs
269 |
270 | def reset(self, key, params=None):
271 | obs, state = self._env.reset(key, params)
272 | return self.transform_obs(obs), state
273 |
274 | def step(self, key, state, action, params=None):
275 | obs, state, reward, done, info = self._env.step(key, state, action, params)
276 | return self.transform_obs(obs), state, reward, done, info
277 |
278 |
279 | class TransformReward(GymnaxWrapper):
280 | def __init__(self, env, transform_reward):
281 | super().__init__(env)
282 | self.transform_reward = transform_reward
283 |
284 | def step(self, key, state, action, params=None):
285 | obs, state, reward, done, info = self._env.step(key, state, action, params)
286 | return obs, state, self.transform_reward(reward), done, info
287 |
288 |
289 | class VecEnv(GymnaxWrapper):
290 | def __init__(self, env):
291 | super().__init__(env)
292 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None))
293 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
294 |
295 |
296 | @struct.dataclass
297 | class NormalizeObsEnvState:
298 | mean: jnp.ndarray
299 | var: jnp.ndarray
300 | count: float
301 | env_state: environment.EnvState
302 |
303 |
304 | class NormalizeObservation(GymnaxWrapper):
305 | def __init__(self, env):
306 | super().__init__(env)
307 |
308 | def reset(self, key, params=None):
309 | obs, state = self._env.reset(key, params)
310 | state = NormalizeObsEnvState(
311 | mean=jnp.zeros_like(obs),
312 | var=jnp.ones_like(obs),
313 | count=1e-4,
314 | env_state=state,
315 | )
316 | batch_mean = jnp.mean(obs, axis=0)
317 | batch_var = jnp.var(obs, axis=0)
318 | batch_count = obs.shape[0]
319 |
320 | delta = batch_mean - state.mean
321 | tot_count = state.count + batch_count
322 |
323 | new_mean = state.mean + delta * batch_count / tot_count
324 | m_a = state.var * state.count
325 | m_b = batch_var * batch_count
326 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
327 | new_var = M2 / tot_count
328 | new_count = tot_count
329 |
330 | state = NormalizeObsEnvState(
331 | mean=new_mean,
332 | var=new_var,
333 | count=new_count,
334 | env_state=state.env_state,
335 | )
336 |
337 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state
338 |
339 | def step(self, key, state, action, params=None):
340 | obs, env_state, reward, done, info = self._env.step(
341 | key, state.env_state, action, params
342 | )
343 |
344 | batch_mean = jnp.mean(obs, axis=0)
345 | batch_var = jnp.var(obs, axis=0)
346 | batch_count = obs.shape[0]
347 |
348 | delta = batch_mean - state.mean
349 | tot_count = state.count + batch_count
350 |
351 | new_mean = state.mean + delta * batch_count / tot_count
352 | m_a = state.var * state.count
353 | m_b = batch_var * batch_count
354 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
355 | new_var = M2 / tot_count
356 | new_count = tot_count
357 |
358 | state = NormalizeObsEnvState(
359 | mean=new_mean,
360 | var=new_var,
361 | count=new_count,
362 | env_state=env_state,
363 | )
364 | return (
365 | (obs - state.mean) / jnp.sqrt(state.var + 1e-8),
366 | state,
367 | reward,
368 | done,
369 | info,
370 | )
371 |
372 |
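# Sketch for illustration only (not used by the wrappers): the mean/variance
# update in NormalizeObservation, and in NormalizeReward below, is the
# parallel-variance merge, so streaming batches give the same statistics as
# pooling all the data.
def _merge_stats_example(mean, var, count, batch):
    # Same merge formula as above, reproduced on a single 1D batch.
    b_mean, b_var, b_count = jnp.mean(batch), jnp.var(batch), batch.shape[0]
    delta, tot = b_mean - mean, count + b_count
    new_mean = mean + delta * b_count / tot
    M2 = var * count + b_var * b_count + jnp.square(delta) * count * b_count / tot
    return new_mean, M2 / tot, tot


# e.g. merging the stats of [1., 2., 3.] with the batch [10., 11., 12.] gives a
# variance of about 20.92, identical to jnp.var over the pooled six values.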
373 | @struct.dataclass
374 | class NormalizeRewEnvState:
375 | mean: jnp.ndarray
376 | var: jnp.ndarray
377 | count: float
378 | return_val: float
379 | env_state: environment.EnvState
380 |
381 |
382 | class NormalizeReward(GymnaxWrapper):
383 | def __init__(self, env, gamma):
384 | super().__init__(env)
385 | self.gamma = gamma
386 |
387 | def reset(self, key, params=None):
388 | obs, state = self._env.reset(key, params)
389 | batch_count = obs.shape[0]
390 | state = NormalizeRewEnvState(
391 | mean=0.0,
392 | var=1.0,
393 | count=1e-4,
394 | return_val=jnp.zeros((batch_count,)),
395 | env_state=state,
396 | )
397 | return obs, state
398 |
399 | def step(self, key, state, action, params=None):
400 | obs, env_state, reward, done, info = self._env.step(
401 | key, state.env_state, action, params
402 | )
403 | return_val = state.return_val * self.gamma * (1 - done) + reward
404 |
405 | batch_mean = jnp.mean(return_val, axis=0)
406 | batch_var = jnp.var(return_val, axis=0)
407 | batch_count = obs.shape[0]
408 |
409 | delta = batch_mean - state.mean
410 | tot_count = state.count + batch_count
411 |
412 | new_mean = state.mean + delta * batch_count / tot_count
413 | m_a = state.var * state.count
414 | m_b = batch_var * batch_count
415 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
416 | new_var = M2 / tot_count
417 | new_count = tot_count
418 |
419 | state = NormalizeRewEnvState(
420 | mean=new_mean,
421 | var=new_var,
422 | count=new_count,
423 | return_val=return_val,
424 | env_state=env_state,
425 | )
426 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info
427 |
428 |
429 | # Use the ParameterReshaper from old evosax
430 | def ravel_pytree(pytree):
431 | leaves, _ = tree_flatten(pytree)
432 | flat, _ = vjp(ravel_list, *leaves)
433 | return flat
434 |
435 |
436 | def ravel_list(*lst):
437 | return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
438 |
439 |
440 | class ParameterReshaper(object):
441 | def __init__(
442 | self,
443 | placeholder_params: Union[chex.ArrayTree, chex.Array],
444 | n_devices: Optional[int] = None,
445 | verbose: bool = True,
446 | ):
447 | """Reshape flat parameter vectors into generation eval shape."""
448 | # Get network shape to reshape
449 | self.placeholder_params = placeholder_params
450 |
451 | # Set total parameters depending on type of placeholder params
452 | flat, self.unravel_pytree = flatten_util.ravel_pytree(placeholder_params)
453 | self.total_params = flat.shape[0]
454 | self.reshape_single = jax.jit(self.unravel_pytree)
455 |
456 | if n_devices is None:
457 | self.n_devices = jax.local_device_count()
458 | else:
459 | self.n_devices = n_devices
460 | if self.n_devices > 1 and verbose:
461 | print(
462 | f"ParameterReshaper: {self.n_devices} devices detected. Please"
463 | " make sure that the ES population size divides evenly across"
464 | " the number of devices to pmap/parallelize over."
465 | )
466 |
467 | if verbose:
468 | print(
469 | f"ParameterReshaper: {self.total_params} parameters detected"
470 | " for optimization."
471 | )
472 |
473 | def reshape(self, x: chex.Array) -> chex.ArrayTree:
474 | """Perform reshaping for a 2D matrix (pop_members, params)."""
475 | vmap_shape = jax.vmap(self.reshape_single)
476 | if self.n_devices > 1:
477 | x = self.split_params_for_pmap(x)
478 | map_shape = jax.pmap(vmap_shape)
479 | else:
480 | map_shape = vmap_shape
481 | return map_shape(x)
482 |
483 | def multi_reshape(self, x: chex.Array) -> chex.ArrayTree:
484 | """Reshape parameters that already lie on different devices."""
485 | # No reshaping required!
486 | vmap_shape = jax.vmap(self.reshape_single)
487 | return jax.pmap(vmap_shape)(x)
488 |
489 | def flatten(self, x: chex.ArrayTree) -> chex.Array:
490 | """Reshaping pytree parameters into flat array."""
491 | vmap_flat = jax.vmap(ravel_pytree)
492 | if self.n_devices > 1:
493 | # Flatten pmapped parameter trees before applying the vmapped flattening
494 | def map_flat(x):
495 | x_re = jax.tree_util.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), x)
496 | return vmap_flat(x_re)
497 |
498 | else:
499 | map_flat = vmap_flat
500 | flat = map_flat(x)
501 | return flat
502 |
503 | def multi_flatten(self, x: chex.Array) -> chex.ArrayTree:
504 | """Flatten parameters remaining on different devices."""
505 | # No reshaping required!
506 | vmap_flat = jax.vmap(ravel_pytree)
507 | return jax.pmap(vmap_flat)(x)
508 |
509 | def split_params_for_pmap(self, param: chex.Array) -> chex.Array:
510 | """Helper reshapes param (bs, #params) into (#dev, bs/#dev, #params)."""
511 | return jnp.stack(jnp.split(param, self.n_devices))
512 |
513 | @property
514 | def vmap_dict(self) -> chex.ArrayTree:
515 | """Get a dictionary specifying axes to vmap over."""
516 | vmap_dict = jax.tree_util.tree_map(lambda x: 0, self.placeholder_params)
517 | return vmap_dict
518 |
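# Illustrative usage (placeholder shapes are made up, single device assumed):
#     placeholder = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
#     reshaper = ParameterReshaper(placeholder, n_devices=1, verbose=False)
#     flat_pop = jnp.zeros((8, reshaper.total_params))  # (pop_members, 9)
#     tree_pop = reshaper.reshape(flat_pop)             # pytree with leading axis 8
#     flat_back = reshaper.flatten(tree_pop)            # back to shape (8, 9)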
--------------------------------------------------------------------------------
/run_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if GPU names are provided
4 | if [ -z "$1" ]; then
5 | echo "Please provide GPU names (e.g., 0,1,2)"
6 | exit 1
7 | fi
8 |
9 | # GPU names passed as the first argument
10 | GPU_NAMES=$1
11 |
12 | # Run the Docker command with the specified GPUs. shm-size is needed for sharding, as Docker defaults to a very small shared-memory allocation. Feel free to change this if it causes issues.
13 | docker run -it --rm --gpus '"device='$GPU_NAMES'"' -v $(pwd):/rl_optimizer -w /rl_optimizer/rl_optimizer --shm-size=5g open
--------------------------------------------------------------------------------
/setup/Dockerfile:
--------------------------------------------------------------------------------
1 |
2 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
3 |
4 | ARG DEBIAN_FRONTEND=noninteractive
5 |
6 | RUN apt update && apt install python3-pip -y
7 | RUN pip3 install --upgrade pip
8 | RUN apt-get install tmux -y
9 | RUN apt-get install vim -y
10 | RUN apt install libglfw3-dev -y
11 | RUN apt install libglfw3 -y
12 | RUN apt-get update && apt-get install -y git
13 | RUN pip3 install --upgrade pip setuptools wheel
14 |
15 | COPY requirements.txt /tmp/requirements.txt
16 | # Need to use specific cuda versions for jax
17 | ARG USE_CUDA=true
18 | RUN if [ "$USE_CUDA" = true ] ; \
19 | then pip install "jax[cuda12]>=0.4.25, <0.6.0" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
20 | fi
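# Note (illustrative): building with --build-arg USE_CUDA=false should skip the
# CUDA-specific jax install above and fall back to the CPU-only jax/jaxlib pulled
# in via requirements.txt.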
21 | RUN pip3 install -r /tmp/requirements.txt
22 |
23 | ARG WANDB_API
24 | ARG USERNAME
25 | ARG USER_UID
26 | ARG USER_GID
27 |
28 | RUN groupadd --gid $USER_GID $USERNAME \
29 | && useradd --uid $USER_UID --gid $USER_GID --create-home $USERNAME \
30 | && chown -R $USER_UID:$USER_GID /home/$USERNAME
31 |
32 | USER $USERNAME
33 |
34 | ENV WANDB_API_KEY=$WANDB_API
35 |
36 | WORKDIR rl_optimizer/
37 |
38 | CMD ["/bin/bash"]
39 |
--------------------------------------------------------------------------------
/setup/build_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set default image name
4 | IMAGE_NAME="open"
5 |
6 | # Get current user info
7 | USER_NAME=$(id -un)
8 | USER_ID=$(id -u)
9 | GROUP_ID=$(id -g)
10 |
11 | # Check if WANDB_API key is provided
12 | if [ -z "$1" ]; then
13 | echo "Please provide WANDB_API_KEY key as argument"
14 | exit 1
15 | fi
16 | WANDB_API=$1
17 |
18 | # Build the Docker image with current user info
19 | docker build \
20 | --build-arg USERNAME="${USER_NAME}" \
21 | --build-arg USER_UID="${USER_ID}" \
22 | --build-arg USER_GID="${GROUP_ID}" \
23 | --build-arg WANDB_API="${WANDB_API}" \
24 | -t ${IMAGE_NAME} \
25 | .
--------------------------------------------------------------------------------
/setup/requirements.txt:
--------------------------------------------------------------------------------
1 | brax>=0.9.0
2 | distrax @ git+https://github.com/google-deepmind/distrax # distrax release doesn't support jax > 0.4.13
3 | evosax @ git+https://github.com/RobertTLange/evosax.git
4 | flax
5 | gymnasium
6 | gymnax==0.0.6 # Later gymnax versions have bugs that prevent running certain envs
7 | jaxlib
8 | numpy
9 | optax @ git+https://github.com/google-deepmind/optax.git
10 | tqdm
11 | wandb
12 | gin-config
13 | seaborn
--------------------------------------------------------------------------------