├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── images └── OPEN_Animation.gif ├── rl_optimizer ├── base.py ├── configs.py ├── eval.py ├── network.py ├── pretrained │ ├── ant_OPEN.npy │ ├── asterix_OPEN.npy │ ├── breakout_OPEN.npy │ ├── freeway_OPEN.npy │ ├── gridworld_OPEN.npy │ ├── multi_OPEN.npy │ └── spaceinvaders_OPEN.npy ├── train.py └── utils.py ├── run_docker.sh └── setup ├── Dockerfile ├── build_docker.sh └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | training 2 | wandb 3 | **/__pycache__/ 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "rl_optimizer/groove"] 2 | path = rl_optimizer/groove 3 | url = https://github.com/EmptyJackson/groove 4 | [submodule "rl_optimizer/learned_optimization"] 5 | path = rl_optimizer/learned_optimization 6 | url = https://github.com/google/learned_optimization 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# OPEN: Learned Optimization for RL in JAX

![animated](images/OPEN_Animation.gif)
13 | 14 | This is the official implementation of OPEN from *Can Learned Optimization Make Reinforcement Learning Less Difficult?*, NeurIPS 2024 (**Spotlight**) and the AutoRL Workshop @ ICML 2024 (**Spotlight**). 15 | 16 | 17 | OPEN is a framework for learning to optimize (L2O) in reinforcement learning. Here, we provide full JAX code to replicate the experiments in our paper and foster future work in this direction. Our current codebase can be used with environments from gymnax or Brax. 18 | 19 | 20 | # 🖥️ Usage 21 | 22 | All files for running OPEN are stored in ``. 23 | 24 | ## 🏋️‍♀️ Training 25 | Alongside training code in `rl_optimizer/train.py`, we include configs for [`freeway`, `asterix`, `breakout`, `spaceinvaders`, `ant`, `gridworld`]. We automate parallelisation over multiple GPUs using JAX sharding. The flag `<--larger>` can be used to increase the size of the network in OPEN. To learn an optimizer in one or a combination of these environments, run: 26 | ```bash 27 | python3 train.py --envs --num-rollouts --popsize --noise-level --sigma-decay --lr --lr-decay --num-generations --save-every-k --wandb-name "" --wandb-entity "" [--larger] 28 | ``` 29 | 30 | This will save a checkpoint, and evaluate the performance of the optimizer, every $k$ steps. Please note that `gridworld` cannot be run in tandem with other environments, as it is the only environment to which we apply antithetic task sampling. 31 | 32 | We include our hyperparameters in the paper. An example usage is: 33 | ```bash 34 | python3 train.py --envs breakout --num-rollouts 1 --popsize 64 --noise-level 0.03 --sigma-decay 0.999 --lr 0.03 --lr-decay 0.999 --num-generations 500 --save-every-k 24 --wandb-name "OPEN Breakout" 35 | ``` 36 | 37 | ## 🔬 Evaluation 38 | 39 | To evaluate the performance of learned optimizers, run the following command, providing the relevant wandb run IDs to `<--exp-name>` and the generation number to `<--exp-num>`. This code is also run intermittently during training. 40 | 41 | For experimental purposes, we provide learned weights for the trained optimizers from our paper for the aforementioned environments in `rl_optimizer/pretrained`. These can be used with the argument `<--pretrained>` in place of wandb IDs. Use the `<--larger>` flag if it was used in training, and to experiment with our pretrained `multi` optimizer pass the `<--multi>` flag. 42 | ```bash 43 | python3 rl_optimizer.eval --envs --exp-name --exp-num --num-runs 16 --title [--pretrained --multi --larger] 44 | ``` 45 | 46 | 47 | # ⬇️ Installation 48 | 49 | We include submodules for [Learned Optimization](https://github.com/google/learned_optimization) and [GROOVE](https://github.com/EmptyJackson/groove). Therefore, when cloning this repo, make sure to use `--recurse-submodules`: 50 | ```bash 51 | git clone --recurse-submodules git@github.com:AlexGoldie/rl-learned-optimization.git 52 | ``` 53 | 54 | ## 📝 Requirements 55 | 56 | We include requirements in `setup/requirements.txt`. Dependencies can be installed locally using: 57 | ```bash 58 | pip install -r setup/requirements.txt 59 | ``` 60 | 61 | ## 🐋 Docker 62 | We also provide files to help build a Docker image. Since we use wandb for logging checkpoints, you should supply your wandb API key as an argument to `build_docker.sh`. 63 | 64 | ```bash 65 | cd setup 66 | chmod +x build_docker.sh 67 | ./build_docker.sh {WANDB_API_KEY} 68 | cd ..
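# back at the repo root: make the run script executable and start the container on the GPUs you want to expose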
69 | chmod +x run_docker.sh 70 | ./run_docker.sh {GPU_NAMES} 71 | ``` 72 | 73 | For example, starting the docker container with access to GPUs `0` and `1` can be done as `./run_docker.sh 0,1` 74 | 75 | 76 | # 📚 Related Work 77 | 78 | The following projects were used extensively in the making of OPEN: 79 | - 🎓 [Learned Optimization](https://github.com/google/learned_optimization) 80 | - 🦎 [Evosax](https://github.com/RobertTLange/evosax) 81 | - ⚡ [PureJaxRL](https://github.com/luchris429/purejaxrl) 82 | - 🕺 [GROOVE](https://github.com/EmptyJackson/groove) 83 | - 🐜 [Brax](https://github.com/google/brax) 84 | - 💪 [Gymnax](https://github.com/RobertTLange/gymnax) 85 | 86 | 87 | # 🔖 Citation 88 | 89 | If you use OPEN in your work, please cite the following: 90 | ``` 91 | @inproceedings{goldie2024can, 92 | author={Alexander D. Goldie and Chris Lu and Matthew Thomas Jackson and Shimon Whiteson and Jakob Nicolaus Foerster}, 93 | booktitle={Advances in Neural Information Processing Systems}, 94 | title={Can Learned Optimization Make Reinforcement Learning Less Difficult?}, 95 | year={2024}, 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /images/OPEN_Animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/images/OPEN_Animation.gif -------------------------------------------------------------------------------- /rl_optimizer/base.py: -------------------------------------------------------------------------------- 1 | # A MISHMASH OF BASE OBJECTS FROM LEARNED_OPTIMIZERS 2 | import abc 3 | import collections 4 | from typing import Any, Callable, Sequence, Optional, Tuple 5 | 6 | import chex 7 | import flax 8 | 9 | import jax.numpy as jnp 10 | 11 | MetaParamOpt = collections.namedtuple("MetaParamOpt", ["init", "opt_fn"]) 12 | 13 | PRNGKey = jnp.ndarray 14 | Params = Any 15 | MetaParams = Any 16 | 17 | 18 | class LearnedOptimizer(abc.ABC): 19 | """Base class for learned optimizers.""" 20 | 21 | @abc.abstractmethod 22 | def init(self, key: PRNGKey) -> MetaParams: 23 | raise NotImplementedError() 24 | 25 | @abc.abstractmethod 26 | def opt_fn(self, theta: MetaParams, is_training: bool = False): 27 | raise NotImplementedError() 28 | 29 | @property 30 | def name(self): 31 | return None 32 | 33 | 34 | ModelState = Any 35 | Params = Any 36 | Gradient = Params 37 | OptState = Any 38 | 39 | 40 | @flax.struct.dataclass 41 | class StatelessState: 42 | params: chex.ArrayTree 43 | state: chex.ArrayTree 44 | 45 | 46 | class Optimizer(abc.ABC): 47 | """Baseclass for the Optimizer interface.""" 48 | 49 | def get_params(self, state: OptState) -> Params: 50 | return state.params 51 | 52 | def get_state(self, state: OptState) -> ModelState: 53 | return state.state 54 | 55 | def get_params_state(self, state: OptState) -> Tuple[Params, ModelState]: 56 | return self.get_params(state), self.get_state(state) 57 | 58 | def init( 59 | self, 60 | params: Params, 61 | state: Optional[ModelState] = None, 62 | num_steps: Optional[int] = None, 63 | key: Optional[chex.PRNGKey] = None, 64 | **kwargs, 65 | ) -> OptState: 66 | raise NotImplementedError 67 | 68 | def set_params(self, state: OptState, params: Params) -> OptState: 69 | return state.replace(params=params) 70 | 71 | def update( 72 | self, 73 | opt_state: OptState, 74 | grad: Gradient, 75 | model_state: Optional[ModelState] = None, 76 | key: Optional[chex.PRNGKey] = None, 77 | 
**kwargs, 78 | ) -> OptState: 79 | raise NotImplementedError() 80 | 81 | @property 82 | def name(self) -> str: 83 | """Name of optimizer. 84 | This property is used when serializing results / baselines. This should 85 | lead with the class name, and follow with all parameters used to create 86 | the object. For example: "__" 87 | """ 88 | return "UnnamedOptimizer" 89 | -------------------------------------------------------------------------------- /rl_optimizer/configs.py: -------------------------------------------------------------------------------- 1 | # ****HAVE TO INPUT WHETHER CONTINUOUS OR JUST AUTOMATE 2 | 3 | all_configs = { 4 | "asterix": { 5 | "ANNEAL_LR": True, 6 | "NUM_ENVS": 64, 7 | "NUM_STEPS": 128, 8 | "TOTAL_TIMESTEPS": 1e7, 9 | "UPDATE_EPOCHS": 4, 10 | "NUM_MINIBATCHES": 8, 11 | "GAMMA": 0.99, 12 | "GAE_LAMBDA": 0.95, 13 | "CLIP_EPS": 0.2, 14 | "ENT_COEF": 0.01, 15 | "VF_COEF": 0.5, 16 | "MAX_GRAD_NORM": 0.5, 17 | "ENV_NAME": "Asterix-MinAtar", 18 | "HSIZE": 64, 19 | "ACTIVATION": "relu", 20 | "DEBUG": False, 21 | "PPO_TEMP": 15.0353, 22 | "CONTINUOUS": False, 23 | }, 24 | "freeway": { 25 | "NUM_ENVS": 64, 26 | "NUM_STEPS": 128, 27 | "TOTAL_TIMESTEPS": 1e7, 28 | "UPDATE_EPOCHS": 4, 29 | "NUM_MINIBATCHES": 8, 30 | "GAMMA": 0.99, 31 | "GAE_LAMBDA": 0.95, 32 | "CLIP_EPS": 0.2, 33 | "ENT_COEF": 0.01, 34 | "VF_COEF": 0.5, 35 | "MAX_GRAD_NORM": 0.5, 36 | "ENV_NAME": "Freeway-MinAtar", 37 | "HSIZE": 64, 38 | "ACTIVATION": "relu", 39 | "DEBUG": False, 40 | "PPO_TEMP": 61.8369, 41 | "CONTINUOUS": False, 42 | }, 43 | "breakout": { 44 | "NUM_ENVS": 64, 45 | "NUM_STEPS": 128, 46 | "TOTAL_TIMESTEPS": 5e5, 47 | "UPDATE_EPOCHS": 4, 48 | "NUM_MINIBATCHES": 8, 49 | "GAMMA": 0.99, 50 | "GAE_LAMBDA": 0.95, 51 | "CLIP_EPS": 0.2, 52 | "ENT_COEF": 0.01, 53 | "VF_COEF": 0.5, 54 | "MAX_GRAD_NORM": 0.5, 55 | "ENV_NAME": "Breakout-MinAtar", 56 | "HSIZE": 64, 57 | "ACTIVATION": "relu", 58 | "DEBUG": False, 59 | "PPO_TEMP": 50.0822, 60 | "CONTINUOUS": False, 61 | }, 62 | "spaceinvaders": { 63 | "NUM_ENVS": 64, 64 | "NUM_STEPS": 128, 65 | "TOTAL_TIMESTEPS": 1e7, 66 | "UPDATE_EPOCHS": 4, 67 | "NUM_MINIBATCHES": 8, 68 | "GAMMA": 0.99, 69 | "GAE_LAMBDA": 0.95, 70 | "CLIP_EPS": 0.2, 71 | "ENT_COEF": 0.01, 72 | "VF_COEF": 0.5, 73 | "MAX_GRAD_NORM": 0.5, 74 | "ENV_NAME": "SpaceInvaders-MinAtar", 75 | "HSIZE": 64, 76 | "ACTIVATION": "relu", 77 | "DEBUG": False, 78 | "PPO_TEMP": 167.0913, 79 | "CONTINUOUS": False, 80 | }, 81 | "ant": { 82 | "NUM_ENVS": 2048, 83 | "NUM_STEPS": 10, 84 | "TOTAL_TIMESTEPS": 5e7, 85 | "UPDATE_EPOCHS": 4, 86 | "NUM_MINIBATCHES": 32, 87 | "GAMMA": 0.99, 88 | "GAE_LAMBDA": 0.95, 89 | "CLIP_EPS": 0.2, 90 | "ENT_COEF": 0.0, 91 | "VF_COEF": 0.5, 92 | "MAX_GRAD_NORM": 0.5, 93 | "ACTIVATION": "tanh", 94 | "HSIZE": 64, 95 | "ENV_NAME": "Brax-ant", 96 | "BACKEND": "positional", 97 | "SYMLOG_OBS": False, 98 | "CLIP_ACTION": True, 99 | "DEBUG": False, 100 | "NORMALIZE": True, 101 | "CONTINUOUS": True, 102 | "PPO_TEMP": 6517.31, 103 | }, 104 | "gridworld": { 105 | "LR": 1e-4, 106 | "ANNEAL_LR": True, 107 | "NUM_ENVS": 1024, 108 | "NUM_STEPS": 20, 109 | "TOTAL_TIMESTEPS": 3e7, 110 | "UPDATE_EPOCHS": 2, 111 | "NUM_MINIBATCHES": 16, 112 | "GAMMA": 0.99, 113 | "GAE_LAMBDA": 0.95, 114 | "CLIP_EPS": 0.2, 115 | "ENT_COEF": 0.01, 116 | "VF_COEF": 0.5, 117 | "MAX_GRAD_NORM": 0.5, 118 | "ACTIVATION": "tanh", 119 | "HSIZE": 16, 120 | "ENV_NAME": "gridworld", 121 | "DEBUG": False, 122 | "PPO_TEMP": 1.0, 123 | "CONTINUOUS": False, 124 | }, 125 | } 126 | 
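For reference, the entries above are consumed by `make_train` in `rl_optimizer/train.py`, which derives the PPO update schedule from them. The following is a minimal sketch of that derivation (not part of the repository; assumes it is run from inside `rl_optimizer/`), using the `breakout` config shown above:

```python
# Sketch: how make_train (rl_optimizer/train.py) derives its schedule from a config entry.
from configs import all_configs

config = dict(all_configs["breakout"])  # copy so the module-level dict is not mutated
config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
config["MINIBATCH_SIZE"] = config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
config["TOTAL_UPDATES"] = config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
# For breakout: 61 PPO updates, minibatches of 1024 transitions, 1952 inner optimizer steps in total.
```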
-------------------------------------------------------------------------------- /rl_optimizer/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import os 7 | import argparse 8 | import seaborn as sns 9 | import pandas as pd 10 | import wandb 11 | from scipy.interpolate import interp1d 12 | import time 13 | 14 | import jax 15 | import jax.numpy as jnp 16 | from jax.sharding import Mesh 17 | from jax.sharding import PartitionSpec as P, NamedSharding 18 | 19 | from configs import all_configs as all_configs 20 | from network import GRU_Opt as optim 21 | from utils import ParameterReshaper 22 | 23 | api = wandb.Api() 24 | 25 | 26 | labs = [""] 27 | 28 | 29 | def set_size(width=487.8225, fraction=1, subplots=(1, 1)): 30 | """Set figure dimensions to avoid scaling in LaTeX. 31 | 32 | Parameters 33 | ---------- 34 | width: float 35 | Document textwidth or columnwidth in pts 36 | fraction: float, optional 37 | Fraction of the width which you wish the figure to occupy 38 | 39 | Returns 40 | ------- 41 | fig_dim: tuple 42 | Dimensions of figure in inches 43 | """ 44 | # Width of figure (in pts) 45 | fig_width_pt = width * fraction 46 | 47 | # Convert from pt to inches 48 | inches_per_pt = 1 / 72.27 49 | 50 | golden_ratio = 1.2 51 | 52 | # Figure width in inches 53 | fig_width_in = fig_width_pt * inches_per_pt 54 | subplots = (1, 1) 55 | fig_height_in = fig_width_in / golden_ratio * (subplots[0] / subplots[1]) 56 | # Figure height in inches 57 | fig_dim = (fig_width_in, fig_height_in) 58 | 59 | return fig_dim 60 | 61 | 62 | episode_lengths = { 63 | "asterix": 1000, 64 | "pendulum": 200, 65 | "acrobot": 500, 66 | "cartpole": 500, 67 | "spaceinvaders": 1000, 68 | "freeway": 2500, 69 | "breakout": 1000, 70 | "ant": 1000, 71 | "gridworld": 0, 72 | } 73 | 74 | 75 | def make_plot( 76 | exp_names=None, 77 | exp_nums=None, 78 | envs=None, 79 | num_runs=5, 80 | pmap=False, 81 | title=None, 82 | larger=False, 83 | pretrained=False, 84 | multi=False, 85 | training=False, 86 | grids=False, 87 | params=None, 88 | mesh=None, 89 | ): 90 | 91 | devices = jax.devices() 92 | sharding_p = NamedSharding( 93 | mesh, 94 | P( 95 | None, 96 | ), 97 | ) 98 | sharding_rng = NamedSharding( 99 | mesh, 100 | P( 101 | "dim", 102 | ), 103 | ) 104 | 105 | if training: 106 | p = "training" 107 | if not os.path.exists(p): 108 | os.mkdir(p) 109 | p = f"{p}/{exp_names}" 110 | if not os.path.exists(p): 111 | os.mkdir(p) 112 | p = f"{p}/{exp_nums}" 113 | elif pretrained: 114 | p = "pretrained" 115 | else: 116 | p = "visualization" 117 | if not os.path.exists(p): 118 | os.mkdir(p) 119 | p = f"{p}/{title}" 120 | 121 | if not os.path.exists(p): 122 | os.mkdir(p) 123 | 124 | returns_list = dict() 125 | param_means = dict() 126 | tau_dormancies = dict() 127 | abs_param_means = dict() 128 | runtimes = dict() 129 | 130 | from train import ( 131 | make_train as meta_make_train, 132 | ) 133 | 134 | for i, env in enumerate(envs): 135 | 136 | if grids: 137 | env = "gridworld" 138 | 139 | returns_list.update({envs[i]: []}) 140 | param_means.update({envs[i]: []}) 141 | tau_dormancies.update({envs[i]: []}) 142 | abs_param_means.update({envs[i]: []}) 143 | runtimes.update({envs[i]: []}) 144 | 145 | if grids: 146 | if not os.path.exists(f"{p}/{envs[i]}"): 147 | os.mkdir(f"{p}/{envs[i]}") 148 | else: 149 | if not os.path.exists(f"{p}/{env}"): 150 | os.mkdir(f"{p}/{env}") 151 | 152 | if pretrained: 153 | if multi: 
154 | exp_name = f"pretrained/multi_OPEN.npy" 155 | larger = True 156 | else: 157 | exp_name = f"pretrained/{env}_OPEN.npy" 158 | params = jnp.array(jnp.load(exp_name, allow_pickle=True)) 159 | 160 | else: 161 | if training: 162 | params = params 163 | exp_name = exp_names 164 | exp_num = exp_nums 165 | else: 166 | if grids: 167 | exp_name = exp_names[0] 168 | exp_num = exp_nums[0] 169 | else: 170 | exp_name = exp_names[i] 171 | exp_num = exp_nums[i] 172 | 173 | run_path = f"OPEN/{exp_name}" 174 | run = api.run(run_path) 175 | restored = wandb.restore( 176 | f"curr_param_{exp_num}.npy", 177 | run_path=run_path, 178 | root=p, 179 | replace=True, 180 | ) 181 | print(f"name: {restored.name}, env = {envs[i]}") 182 | params = jnp.array(jnp.load(restored.name, allow_pickle=True)) 183 | 184 | # Need to reshape saved params as they are saved as np arrays 185 | if not training: 186 | if larger: 187 | hidden_size = 32 188 | gru_features = 16 189 | else: 190 | hidden_size = 16 191 | gru_features = 8 192 | pholder = optim(hidden_size=hidden_size, gru_features=gru_features).init( 193 | jax.random.PRNGKey(0) 194 | ) 195 | param_reshaper = ParameterReshaper(pholder) 196 | params = param_reshaper.reshape_single(params) 197 | 198 | all_configs[f"{env}"]["larger"] = larger 199 | 200 | make_train = meta_make_train 201 | 202 | rng = jax.random.PRNGKey(42) 203 | all_configs[f"{env}"]["VISUALISE"] = True 204 | 205 | start = time.time() 206 | rngs = jax.random.split(rng, num_runs) 207 | if env == "gridworld": 208 | asdf = jax.jit( 209 | jax.vmap( 210 | make_train(all_configs[f"{env}"]), 211 | in_axes=(None, 0, None), 212 | ), 213 | static_argnames=["grid_type"], 214 | ) 215 | 216 | if pmap: 217 | params = jax.device_put(params, sharding_p) 218 | rngs = jax.device_put(rngs, sharding_rng) 219 | 220 | out, metrics = asdf(params, rngs, i) 221 | 222 | else: 223 | asdf = jax.jit( 224 | jax.vmap( 225 | make_train(all_configs[f"{env}"]), 226 | in_axes=(None, 0), 227 | ) 228 | ) 229 | 230 | if pmap: 231 | params = jax.device_put(params, sharding_p) 232 | rngs = jax.device_put(rngs, sharding_rng) 233 | 234 | out, metrics = asdf(params, rngs) 235 | out, metrics = jax.device_get(out), jax.device_get(metrics) 236 | 237 | fitness = metrics["returned_episode_returns"][..., -1, -1, :].mean() 238 | end = time.time() 239 | print(f"runtime = {end - start}") 240 | runtimes[envs[i]].append(end - start) 241 | print(f"{envs[i]} learned fitness: {fitness}") 242 | 243 | returns = ( 244 | metrics["returned_episode_returns"].mean(-1).mean(-1).reshape(num_runs, -1) 245 | ) 246 | 247 | if training: 248 | wandb.log({f"eval/{env}/fitness_at_mean": fitness}, step=exp_num) 249 | 250 | config_for_index = all_configs[f"{env}"] 251 | index_from = episode_lengths[env] 252 | index = int(np.ceil(index_from / config_for_index["NUM_STEPS"])) 253 | 254 | if grids: 255 | 256 | returns_list[envs[i]].append(returns[:, index:]) 257 | 258 | return_df = pd.DataFrame(returns[:, index:]) 259 | return_df.to_csv(f"{p}/{envs[i]}/returns.csv") 260 | 261 | tau_dormancies[envs[i]].append( 262 | metrics["dormancy"].mean(-1).reshape(num_runs, -1) 263 | ) 264 | else: 265 | returns_list[env].append(returns[:, index:]) 266 | 267 | return_df = pd.DataFrame(returns[:, index:]) 268 | return_df.to_csv(f"{p}/{env}/returns.csv") 269 | 270 | tau_dormancies[env].append( 271 | metrics["dormancy"].mean(-1).reshape(num_runs, -1) 272 | ) 273 | 274 | def plot_all(values, conf, labels, xlabel, ylabel, title, training=False): 275 | 276 | for j, env in enumerate(values.keys()): 277 | 
fig, ax = plt.subplots(1, 1, figsize=set_size()) 278 | if grids: 279 | x_mult = ( 280 | conf[f"gridworld"]["NUM_STEPS"] * conf[f"gridworld"]["NUM_ENVS"] 281 | ) 282 | else: 283 | x_mult = conf[f"{env}"]["NUM_STEPS"] * conf[f"{env}"]["NUM_ENVS"] 284 | 285 | legend = [] 286 | for i, value in enumerate(values[env]): 287 | val = values[env][i] 288 | val_df = pd.DataFrame({f"vals_{i}": val[i] for i in range(len(val))}) 289 | 290 | val_ewm = val_df.ewm(span=200, axis=0).mean().to_numpy().T 291 | 292 | mean = val_ewm.mean(0) 293 | 294 | xs = jnp.arange(len(mean)) * x_mult 295 | 296 | std = jnp.std(val_ewm, axis=0) / jnp.sqrt(val_ewm.shape[0]) 297 | 298 | results_max = mean + std 299 | results_min = mean - std 300 | 301 | (leg,) = ax.plot(xs, mean, label=labels[i], linewidth=0.4) 302 | legend.append(leg) 303 | ax.fill_between(x=xs, y1=results_min, y2=results_max, alpha=0.5) 304 | 305 | if j == 0: 306 | 307 | ax.legend( 308 | legend, 309 | labels, 310 | loc="lower right", 311 | ncols=1, 312 | ) 313 | 314 | if len(values.keys()) > 1: 315 | ax.set_title(env, fontsize=8) 316 | ax.set_xlabel(xlabel) 317 | ax.tick_params(axis="x", which="major", pad=-3) 318 | ax.tick_params(axis="y", which="major", pad=-3) 319 | 320 | else: 321 | ax.set_title(env, fontsize=8) 322 | ax.set_xlabel(xlabel) 323 | 324 | ax.set_ylabel(ylabel) 325 | 326 | fig.savefig( 327 | f"{p}/{env}/{title}_{ylabel}_{env}.pdf", 328 | format="pdf", 329 | bbox_inches="tight", 330 | ) 331 | 332 | if training: 333 | fig.savefig( 334 | f"{p}/{env}/{title}_{ylabel}_{env}.png", 335 | format="png", 336 | bbox_inches="tight", 337 | ) 338 | wandb.log( 339 | { 340 | f"eval_figs/{env}/{ylabel}": wandb.Image( 341 | f"{p}/{env}/{title}_{ylabel}_{env}.png" 342 | ) 343 | }, 344 | step=exp_num, 345 | ) 346 | 347 | plot_all( 348 | returns_list, 349 | all_configs, 350 | ["OPEN"], 351 | xlabel="Frames", 352 | ylabel=f"Return", 353 | title=title, 354 | training=training, 355 | ) 356 | 357 | plot_all( 358 | tau_dormancies, 359 | all_configs, 360 | ["OPEN"], 361 | xlabel="Updates", 362 | ylabel=f"Dormancy", 363 | title=title, 364 | training=training, 365 | ) 366 | 367 | for env in envs: 368 | print(f"env: {env}, run_time : {runtimes[env]}") 369 | 370 | 371 | if __name__ == "__main__": 372 | parser = argparse.ArgumentParser() 373 | parser.add_argument("--exp-name", nargs="+", type=str, default=None) 374 | parser.add_argument("--exp-num", nargs="+", type=str, default=None) 375 | parser.add_argument("--envs", nargs="+", required=False, default=None) 376 | parser.add_argument("--num-runs", type=int, default=3) 377 | parser.add_argument( 378 | "--pmap", default=jax.local_device_count() > 1, action="store_true" 379 | ) 380 | parser.add_argument("--title", type=str) 381 | parser.add_argument("--larger", default=False, action="store_true") 382 | parser.add_argument("--pretrained", default=False, action="store_true") 383 | parser.add_argument("--multi", default=False, action="store_true") 384 | 385 | args = parser.parse_args() 386 | 387 | sns.set() 388 | plt.style.use("seaborn-v0_8-colorblind") 389 | 390 | tex_fonts = { 391 | "axes.labelsize": 6, 392 | "font.size": 8, 393 | "legend.fontsize": 5, 394 | "xtick.labelsize": 6, 395 | "ytick.labelsize": 6, 396 | } 397 | 398 | matplotlib.rcParams.update(tex_fonts) 399 | matplotlib.rcParams["axes.formatter.limits"] = [-3, 3] 400 | color_palette = sns.color_palette("colorblind", n_colors=5) 401 | plt.rcParams["axes.prop_cycle"] = plt.cycler(color=color_palette) 402 | devices = jax.local_devices() 403 | mesh = Mesh(devices, 
axis_names=("dim",)) 404 | 405 | if args.envs == ["gridworld"]: 406 | args.envs = [ 407 | "sixteen_rooms", 408 | "labyrinth", 409 | "rand_sparse", 410 | "rand_dense", 411 | "rand_long", 412 | "standard_maze", 413 | "rand_all", 414 | ] 415 | grids = True 416 | else: 417 | grids = False 418 | 419 | make_plot( 420 | exp_names=args.exp_name, 421 | exp_nums=args.exp_num, 422 | envs=args.envs, 423 | num_runs=args.num_runs, 424 | pmap=args.pmap, 425 | title=args.title, 426 | larger=args.larger, 427 | pretrained=args.pretrained, 428 | multi=args.multi, 429 | grids=grids, 430 | mesh=mesh, 431 | ) 432 | -------------------------------------------------------------------------------- /rl_optimizer/network.py: -------------------------------------------------------------------------------- 1 | """Full OPEN optimizer, incorporating all input features and learnable stochasticity""" 2 | 3 | from typing import Any, Optional 4 | 5 | import flax 6 | import flax.linen as nn 7 | import gin 8 | import jax 9 | from jax import lax 10 | import jax.numpy as jnp 11 | from optax import adam 12 | import optax 13 | 14 | import sys 15 | 16 | import base as opt_base 17 | 18 | from learned_optimization.learned_optimization import tree_utils 19 | from learned_optimization.learned_optimization.learned_optimizers import ( 20 | common, 21 | ) 22 | 23 | 24 | PRNGKey = jnp.ndarray 25 | 26 | 27 | def _second_moment_normalizer(x, axis, eps=1e-5): 28 | return x * lax.rsqrt(eps + jnp.mean(jnp.square(x), axis=axis, keepdims=True)) 29 | 30 | 31 | def iter_proportion(iterations, total_its=100000): 32 | f32 = jnp.float32 33 | 34 | return iterations / f32(total_its) 35 | 36 | 37 | @flax.struct.dataclass 38 | class GRUOptState: 39 | params: Any 40 | rolling_features: common.MomAccumulator 41 | iteration: jnp.ndarray 42 | state: Any 43 | carry: Any 44 | 45 | 46 | @gin.configurable 47 | class GRU_Opt(opt_base.LearnedOptimizer): 48 | def __init__( 49 | self, 50 | exp_mult=0.001, 51 | step_mult=0.001, 52 | hidden_size=16, 53 | gru_features=8, 54 | ): 55 | 56 | super().__init__() 57 | self._step_mult = step_mult 58 | self._exp_mult = exp_mult 59 | 60 | self.gru_features = gru_features 61 | 62 | self._gru = nn.GRUCell(features=self.gru_features) 63 | 64 | self._mod = nn.Sequential( 65 | [ 66 | nn.Dense(hidden_size), 67 | nn.LayerNorm(), 68 | nn.relu, 69 | nn.Dense(hidden_size), 70 | nn.LayerNorm(), 71 | nn.relu, 72 | nn.Dense(3), 73 | ] 74 | ) 75 | 76 | def init(self, key: PRNGKey) -> opt_base.MetaParams: 77 | # There are 19 features used as input. For now, hard code this. 
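        # Concretely (built per parameter in _update_tensor below): the raw parameter value (1),
        # sign and log-magnitude of the momentum at the 6 decay timescales (12), sign and
        # log-magnitude of the gradient (2), the training, layer and batch proportions (3),
        # and the dormancy of the associated activation (1) = 19 features.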
78 | key = jax.random.split(key, 5) 79 | 80 | proxy_carry = self._gru.initialize_carry(key[4], (1,)) 81 | 82 | return { 83 | "params": self._mod.init(key[0], jnp.zeros([self.gru_features])), 84 | "gru_params": self._gru.init(key[2], proxy_carry, jnp.zeros([19])), 85 | } 86 | 87 | def opt_fn( 88 | self, theta: opt_base.MetaParams, is_training: bool = False 89 | ) -> opt_base.Optimizer: 90 | # ALL MOMENTUM TIMESCALES 91 | decays = jnp.asarray([0.1, 0.5, 0.9, 0.99, 0.999, 0.9999]) 92 | 93 | mod = self._mod 94 | gru = self._gru 95 | exp_mult = self._exp_mult 96 | step_mult = self._step_mult 97 | 98 | theta_mlp = theta["params"] 99 | theta_gru = theta["gru_params"] 100 | 101 | class _Opt(opt_base.Optimizer): 102 | def init( 103 | self, 104 | params: opt_base.Params, 105 | model_state: Any = None, 106 | num_steps: Optional[int] = None, 107 | key: Optional[PRNGKey] = None, 108 | ) -> GRUOptState: 109 | """Initialize opt state.""" 110 | 111 | param_tree = jax.tree_util.tree_structure(params) 112 | 113 | keys = jax.random.split(key, param_tree.num_leaves) 114 | keys = jax.tree_util.tree_unflatten(param_tree, keys) 115 | 116 | carry = jax.tree_util.tree_map( 117 | lambda p, k: gru.initialize_carry(k, jnp.expand_dims(p, -1).shape), 118 | params, 119 | keys, 120 | ) 121 | 122 | return GRUOptState( 123 | params=params, 124 | state=model_state, 125 | rolling_features=common.vec_rolling_mom(decays).init(params), 126 | iteration=jnp.asarray(0, dtype=jnp.int32), 127 | carry=carry, 128 | ) 129 | 130 | def update( 131 | self, 132 | opt_state_actor: GRUOptState, 133 | crit_opt_state: GRUOptState, 134 | grad: Any, 135 | activations: float, 136 | key: Optional[PRNGKey] = None, 137 | training_prop=0, 138 | batch_prop=0, 139 | config=None, 140 | layer_props=None, 141 | model_state: Any = None, 142 | mask=None, 143 | ) -> GRUOptState: 144 | 145 | next_rolling_features_actor = common.vec_rolling_mom(decays).update( 146 | opt_state_actor.rolling_features, grad["actor"] 147 | ) 148 | 149 | next_rolling_features_critic = common.vec_rolling_mom(decays).update( 150 | crit_opt_state.rolling_features, grad["critic"] 151 | ) 152 | 153 | rolling_features = { 154 | "actor": next_rolling_features_actor.m, 155 | "critic": next_rolling_features_critic.m, 156 | } 157 | 158 | training_step_feature = training_prop 159 | batch_feature = batch_prop 160 | eps1 = 1e-13 161 | eps2 = 1e-8 162 | t = opt_state_actor.iteration + 1 163 | 164 | def _update_tensor(p, g, mom, k, dorm, carry, layer_prop, mask): 165 | 166 | # this doesn't work with scalar parameters, so let's reshape. 
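                # (rank-0 leaves are promoted to shape (1,) so the per-parameter feature stacking below has an axis to concatenate over)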
167 | if not p.shape: 168 | p = jnp.expand_dims(p, 0) 169 | 170 | # use gradient conditioning (Optim4RL) 171 | gsign = jnp.expand_dims(jnp.sign(g), 0) 172 | glog = jnp.expand_dims(jnp.log(jnp.abs(g) + eps1), 0) 173 | 174 | mom = jnp.expand_dims(mom, 0) 175 | did_reshape = True 176 | else: 177 | gsign = jnp.sign(g) 178 | glog = jnp.log(jnp.abs(g) + eps1) 179 | did_reshape = False 180 | 181 | inps = [] 182 | inp_g = [] 183 | 184 | batch_gsign = jnp.expand_dims(gsign, axis=-1) 185 | batch_glog = jnp.expand_dims(glog, axis=-1) 186 | 187 | # feature consisting of raw parameter values 188 | batch_p = jnp.expand_dims(p, axis=-1) 189 | inps.append(batch_p) 190 | 191 | # feature consisting of all momentum values 192 | momsign = jnp.sign(mom) 193 | momlog = jnp.log(jnp.abs(mom) + eps1) 194 | inps.append(momsign) 195 | inps.append(momlog) 196 | 197 | inp_stack = jnp.concatenate(inps, axis=-1) 198 | 199 | axis = list(range(len(p.shape))) 200 | 201 | inp_stack_g = jnp.concatenate( 202 | [inp_stack, batch_gsign, batch_glog], axis=-1 203 | ) 204 | 205 | inp_stack_g = _second_moment_normalizer(inp_stack_g, axis=axis) 206 | 207 | # once normalized, add features that are constant across tensor. 208 | # namly the training proportion, batch proportion, parameter value and dormancy 209 | def stack_tensors(feature, input): 210 | 211 | stacked = jnp.reshape( 212 | feature, [1] * len(axis) + list(feature.shape[-1:]) 213 | ) 214 | stacked = jnp.tile(stacked, list(p.shape) + [1]) 215 | return jnp.concatenate([input, stacked], axis=-1) 216 | 217 | inp = jnp.tile( 218 | jnp.reshape( 219 | training_step_feature, 220 | [1] * len(axis) + list(training_step_feature.shape[-1:]), 221 | ), 222 | list(p.shape) + [1], 223 | ) 224 | 225 | stacked_batch_prop = jnp.tile( 226 | jnp.reshape( 227 | batch_feature, 228 | [1] * len(axis) + list(batch_feature.shape[-1:]), 229 | ), 230 | list(p.shape) + [1], 231 | ) 232 | 233 | layer_prop = jnp.expand_dims(layer_prop, 0) 234 | 235 | stacked_layer_prop = jnp.tile( 236 | jnp.reshape( 237 | layer_prop, [1] * len(axis) + list(layer_prop.shape[-1:]) 238 | ), 239 | list(p.shape) + [1], 240 | ) 241 | 242 | inp = jnp.concatenate([inp, stacked_layer_prop], axis=-1) 243 | 244 | inp = jnp.concatenate([inp, stacked_batch_prop], axis=-1) 245 | 246 | batch_dorm = jnp.expand_dims(dorm, axis=-1) 247 | 248 | if p.shape != dorm.shape: 249 | batch_dorm = jnp.tile( 250 | batch_dorm, [p.shape[0]] + len(axis) * [1] 251 | ) 252 | 253 | inp = jnp.concatenate([inp, batch_dorm], axis=-1) 254 | 255 | inp_g = jnp.concatenate([inp_stack_g, inp], axis=-1) 256 | 257 | gru_new_carry, gru_out = gru.apply(theta_gru, carry, inp_g) 258 | 259 | # apply the per parameter MLP. 
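                # Its three outputs are combined below: output[..., 0] sets the magnitude/direction of the
                # update, output[..., 1] rescales it through an exponential factor, and output[..., 2] drives
                # the learned stochasticity, which the mask restricts to the actor.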
260 | output = mod.apply(theta_mlp, gru_out) 261 | 262 | update_ = ( 263 | output[..., 0] * step_mult * jnp.exp(output[..., 1] * exp_mult) 264 | ) 265 | 266 | # Add the stochasticity *only* to the actor (using the mask) 267 | update = ( 268 | update_ 269 | + output[..., 2] 270 | * mask 271 | * jax.random.normal(k, shape=update_.shape) 272 | * step_mult 273 | ) 274 | 275 | update = update.reshape(p.shape) 276 | 277 | return (update, gru_new_carry) 278 | 279 | full_params = { 280 | "actor": opt_state_actor.params, 281 | "critic": crit_opt_state.params, 282 | } 283 | param_tree = jax.tree_util.tree_structure(full_params) 284 | 285 | keys = jax.random.split(key, param_tree.num_leaves) 286 | keys = jax.tree_util.tree_unflatten(param_tree, keys) 287 | 288 | activations = jax.tree_util.tree_flatten(activations)[0] 289 | 290 | activations = jax.tree_util.tree_unflatten(param_tree, activations) 291 | 292 | def calc_dormancy(tensor_activations): 293 | tensor_activations = tensor_activations + 1e-11 294 | total_activations = jnp.abs(tensor_activations).sum(axis=-1) 295 | total_activations = jnp.tile( 296 | jnp.expand_dims(total_activations, -1), 297 | tensor_activations.shape[-1], 298 | ) 299 | dormancy = ( 300 | tensor_activations 301 | / total_activations 302 | * tensor_activations.shape[-1] 303 | ) 304 | return dormancy 305 | 306 | dormancies = jax.tree_util.tree_map(calc_dormancy, activations) 307 | 308 | full_carry = { 309 | "actor": opt_state_actor.carry, 310 | "critic": crit_opt_state.carry, 311 | } 312 | 313 | updates_carry = jax.tree_util.tree_map( 314 | _update_tensor, 315 | full_params, 316 | grad, 317 | rolling_features, 318 | keys, 319 | dormancies, 320 | full_carry, 321 | layer_props, 322 | mask, 323 | ) 324 | 325 | updates_carry_leaves = jax.tree_util.tree_leaves(updates_carry) 326 | updates = [ 327 | updates_carry_leaves[i] 328 | for i in range(0, len(updates_carry_leaves), 2) 329 | ] 330 | new_carry = [ 331 | updates_carry_leaves[i + 1] 332 | for i in range(0, len(updates_carry_leaves), 2) 333 | ] 334 | 335 | updates = jax.tree_util.tree_unflatten(param_tree, updates) 336 | new_carry = jax.tree_util.tree_unflatten(param_tree, new_carry) 337 | 338 | # Make update globally 0 339 | updates_flat = jax.flatten_util.ravel_pytree(updates)[0] 340 | update_mean = updates_flat.mean() 341 | update_mean = jax.tree_util.tree_unflatten( 342 | param_tree, jnp.tile(update_mean, param_tree.num_leaves) 343 | ) 344 | 345 | updates = jax.tree_util.tree_map( 346 | lambda x, mu: x - mu, updates, update_mean 347 | ) 348 | 349 | def param_update(p, update): 350 | 351 | new_param = p - update 352 | 353 | return new_param 354 | 355 | next_params = jax.tree_util.tree_map(param_update, full_params, updates) 356 | 357 | # For simplicity, maitain different opt states between the actor and the critic 358 | next_opt_state_actor = GRUOptState( 359 | params=tree_utils.match_type( 360 | next_params["actor"], opt_state_actor.params 361 | ), 362 | rolling_features=tree_utils.match_type( 363 | next_rolling_features_actor, opt_state_actor.rolling_features 364 | ), 365 | iteration=opt_state_actor.iteration + 1, 366 | state=model_state, 367 | carry=new_carry["actor"], 368 | ) 369 | 370 | next_opt_state_critic = GRUOptState( 371 | params=tree_utils.match_type( 372 | next_params["critic"], crit_opt_state.params 373 | ), 374 | rolling_features=tree_utils.match_type( 375 | next_rolling_features_critic, crit_opt_state.rolling_features 376 | ), 377 | iteration=opt_state_actor.iteration + 1, 378 | state=model_state, 379 | 
carry=new_carry["critic"], 380 | ) 381 | 382 | param_flat_actor = jax.flatten_util.ravel_pytree(next_params["actor"])[ 383 | 0 384 | ] 385 | if config["VISUALISE"]: 386 | param_mean_actor = jnp.mean(jnp.array(param_flat_actor)) 387 | param_abs_mean_actor = jnp.mean(jnp.abs(jnp.array(param_flat_actor))) 388 | 389 | param_flat_critic = jax.flatten_util.ravel_pytree( 390 | next_params["critic"] 391 | )[0] 392 | if config["VISUALISE"]: 393 | param_mean_critic = jnp.mean(jnp.array(param_flat_critic)) 394 | param_abs_mean_critic = jnp.mean(jnp.abs(jnp.array(param_flat_critic))) 395 | 396 | if config["VISUALISE"]: 397 | return ( 398 | next_opt_state_actor, 399 | next_opt_state_critic, 400 | (param_mean_actor, param_abs_mean_actor), 401 | (param_mean_critic, param_abs_mean_critic), 402 | ) 403 | else: 404 | return ( 405 | next_opt_state_actor, 406 | next_opt_state_critic, 407 | (param_abs_mean_actor), 408 | (param_abs_mean_critic), 409 | ) 410 | 411 | return _Opt() 412 | -------------------------------------------------------------------------------- /rl_optimizer/pretrained/ant_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/ant_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/asterix_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/asterix_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/breakout_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/breakout_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/freeway_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/freeway_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/gridworld_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/gridworld_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/multi_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/multi_OPEN.npy -------------------------------------------------------------------------------- /rl_optimizer/pretrained/spaceinvaders_OPEN.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGoldie/rl-learned-optimization/80985af1d3fb149945bc0ab7bfcc588462ed04b1/rl_optimizer/pretrained/spaceinvaders_OPEN.npy 
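For reference, a minimal sketch (adapted from `eval.py`; the path assumes it is run from inside `rl_optimizer/`) of how these pretrained `.npy` files are loaded back into the optimizer's parameter pytree:

```python
# Sketch (mirrors eval.py): load a pretrained OPEN optimizer and rebuild its parameter pytree.
import jax
import jax.numpy as jnp
from network import GRU_Opt
from utils import ParameterReshaper

flat = jnp.array(jnp.load("pretrained/breakout_OPEN.npy", allow_pickle=True))
# Default network sizes; optimizers trained with --larger use hidden_size=32, gru_features=16.
meta_opt = GRU_Opt(hidden_size=16, gru_features=8)
pholder = meta_opt.init(jax.random.PRNGKey(0))            # placeholder pytree with the right structure
meta_params = ParameterReshaper(pholder).reshape_single(flat)
opt = meta_opt.opt_fn(meta_params)                         # Optimizer exposing .init(params, key=...) / .update(...)
```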
-------------------------------------------------------------------------------- /rl_optimizer/train.py: -------------------------------------------------------------------------------- 1 | """(inner and outer) training loop for OPEN""" 2 | 3 | import numpy as np 4 | from typing import Sequence, NamedTuple, Any 5 | import os 6 | import os.path as osp 7 | from datetime import datetime 8 | from tqdm import tqdm 9 | import wandb 10 | import time 11 | from functools import partial 12 | import argparse 13 | 14 | from network import GRU_Opt 15 | from configs import all_configs 16 | from eval import make_plot 17 | from utils import ( 18 | GymnaxGymWrapper, 19 | GymnaxLogWrapper, 20 | FlatWrapper, 21 | BraxGymnaxWrapper, 22 | ClipAction, 23 | TransformObservation, 24 | NormalizeObservation, 25 | NormalizeReward, 26 | VecEnv, 27 | ) 28 | 29 | import jax 30 | import jax.numpy as jnp 31 | import flax.linen as nn 32 | import optax 33 | from flax.linen.initializers import constant, orthogonal 34 | from flax.training.train_state import TrainState 35 | import flax 36 | import distrax 37 | import gymnax 38 | from brax.envs.wrappers.gym import GymWrapper 39 | from brax import envs 40 | import evosax 41 | from evosax.algorithms.distribution_based import Open_ES 42 | from evosax.core.fitness_shaping import ( 43 | centered_rank_fitness_shaping_fn, 44 | identity_fitness_shaping_fn, 45 | ) 46 | from gymnax.environments import spaces 47 | from optax import adam 48 | from jax.sharding import Mesh 49 | from jax.sharding import PartitionSpec as P, NamedSharding 50 | 51 | import sys 52 | 53 | # GROOVE imports cause an issue 54 | groove_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "groove")) 55 | sys.path.insert(0, groove_path) 56 | 57 | # Import the modules you need 58 | # The 'groove' module will be found in the original path 59 | # The 'environments' module will be found via the path we just added 60 | import groove.environments.gridworld.configs as grid_conf 61 | from groove.environments.gridworld import gridworld as grid 62 | from groove.environments.gridworld.configs import ENV_MODE_KWARGS 63 | 64 | sys.path.remove(groove_path) 65 | 66 | 67 | class Actor(nn.Module): 68 | action_dim: Sequence[int] 69 | config: dict 70 | 71 | @nn.compact 72 | def __call__(self, x): 73 | hsize = self.config["HSIZE"] 74 | if self.config["ACTIVATION"] == "relu": 75 | activation = nn.relu 76 | else: 77 | activation = nn.tanh 78 | 79 | actor_mean = nn.Dense( 80 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 81 | )(x) 82 | actor_mean = activation(actor_mean) 83 | actor_mean_activation_1 = { 84 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0), 85 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0), 86 | } 87 | 88 | actor_mean = nn.Dense( 89 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 90 | )(actor_mean) 91 | actor_mean = activation(actor_mean) 92 | actor_mean_activation_2 = { 93 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0), 94 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0), 95 | } 96 | actor_mean = nn.Dense( 97 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 98 | )(actor_mean) 99 | actor_mean_activation_3 = { 100 | "kernel": jnp.mean(jnp.abs(actor_mean), axis=0), 101 | "bias": jnp.mean(jnp.abs(actor_mean), axis=0), 102 | } 103 | 104 | if self.config["CONTINUOUS"]: 105 | actor_logtstd = self.param( 106 | "log_std", nn.initializers.zeros, (self.action_dim,) 107 | ) 108 | actor_mean_activation_4 = jnp.expand_dims( 109 | 
jnp.mean(jnp.exp(actor_logtstd), axis=0), axis=0 110 | ) 111 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) 112 | else: 113 | pi = distrax.Categorical(logits=actor_mean) 114 | 115 | if self.config["CONTINUOUS"]: 116 | activations = ( 117 | actor_mean_activation_1, 118 | actor_mean_activation_2, 119 | actor_mean_activation_3, 120 | actor_mean_activation_4, 121 | ) 122 | else: 123 | activations = ( 124 | actor_mean_activation_1, 125 | actor_mean_activation_2, 126 | actor_mean_activation_3, 127 | ) 128 | 129 | return pi, activations 130 | 131 | 132 | class Critic(nn.Module): 133 | config: dict 134 | 135 | @nn.compact 136 | def __call__(self, x): 137 | hsize = self.config["HSIZE"] 138 | if self.config["ACTIVATION"] == "relu": 139 | activation = nn.relu 140 | else: 141 | activation = nn.tanh 142 | 143 | critic = nn.Dense( 144 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 145 | )(x) 146 | critic = activation(critic) 147 | critic_mean_activation_1 = { 148 | "kernel": jnp.mean(jnp.abs(critic), axis=0), 149 | "bias": jnp.mean(jnp.abs(critic), axis=0), 150 | } 151 | critic = nn.Dense( 152 | hsize, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 153 | )(critic) 154 | critic = activation(critic) 155 | critic_mean_activation_2 = { 156 | "kernel": jnp.mean(jnp.abs(critic), axis=0), 157 | "bias": jnp.mean(jnp.abs(critic), axis=0), 158 | } 159 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 160 | critic 161 | ) 162 | critic_mean_activation_3 = { 163 | "kernel": jnp.mean(jnp.abs(critic), axis=0), 164 | "bias": jnp.mean(jnp.abs(critic), axis=0), 165 | } 166 | 167 | activations = ( 168 | critic_mean_activation_1, 169 | critic_mean_activation_2, 170 | critic_mean_activation_3, 171 | ) 172 | return jnp.squeeze(critic, axis=-1), activations 173 | 174 | 175 | class Transition(NamedTuple): 176 | done: jnp.ndarray 177 | action: jnp.ndarray 178 | value: jnp.ndarray 179 | reward: jnp.ndarray 180 | log_prob: jnp.ndarray 181 | obs: jnp.ndarray 182 | info: jnp.ndarray 183 | 184 | 185 | def symlog(x): 186 | return jnp.sign(x) * jnp.log(jnp.abs(x) + 1) 187 | 188 | 189 | gridtypes = [ 190 | "sixteen_rooms", 191 | "labyrinth", 192 | "rand_sparse", 193 | "rand_dense", 194 | "rand_long", 195 | "standard_maze", 196 | "rand_all", 197 | ] 198 | 199 | 200 | def make_train(config): 201 | 202 | config["NUM_UPDATES"] = ( 203 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 204 | ) 205 | config["MINIBATCH_SIZE"] = ( 206 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 207 | ) 208 | config["TOTAL_UPDATES"] = ( 209 | config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"] 210 | ) 211 | 212 | @partial(jax.jit, static_argnames=["grid_type"]) 213 | def train(meta_params, rng, grid_type=6): 214 | 215 | if "Brax-" in config["ENV_NAME"]: 216 | name = config["ENV_NAME"].split("Brax-")[1] 217 | env, env_params = BraxGymnaxWrapper(env_name=name), None 218 | if config.get("CLIP_ACTION"): 219 | env = ClipAction(env) 220 | env = GymnaxLogWrapper(env) 221 | if config.get("SYMLOG_OBS"): 222 | env = TransformObservation(env, transform_obs=symlog) 223 | 224 | env = VecEnv(env) 225 | if config.get("NORMALIZE"): 226 | env = NormalizeObservation(env) 227 | env = NormalizeReward(env, config["GAMMA"]) 228 | actor = Actor(env.action_space(env_params).shape[0], config=config) 229 | critic = Critic(config=config) 230 | init_x = jnp.zeros(env.observation_space(env_params).shape) 231 | else: 232 | # INIT 
ENV 233 | if config["ENV_NAME"] == "gridworld": 234 | env = grid.GridWorld(**ENV_MODE_KWARGS[gridtypes[grid_type]]) 235 | rng, rng_ = jax.random.split(rng) 236 | env_params = grid_conf.reset_env_params(rng_, gridtypes[grid_type]) 237 | else: 238 | env, env_params = gymnax.make(config["ENV_NAME"]) 239 | env = GymnaxGymWrapper(env, env_params, config) 240 | env = FlatWrapper(env) 241 | 242 | env = GymnaxLogWrapper(env) 243 | env = VecEnv(env) 244 | actor = Actor(env.action_space, config=config) 245 | critic = Critic(config=config) 246 | init_x = jnp.zeros(env.observation_space) 247 | 248 | # INIT NETWORK 249 | 250 | rng, _rng = jax.random.split(rng) 251 | actor_params = actor.init(_rng, init_x) 252 | critic_params = critic.init(_rng, init_x) 253 | 254 | if config["larger"]: 255 | meta_opt = GRU_Opt(hidden_size=32, gru_features=16) 256 | else: 257 | meta_opt = GRU_Opt(hidden_size=16, gru_features=8) 258 | opt = meta_opt.opt_fn(meta_params) 259 | clip_opt = optax.clip_by_global_norm(config["MAX_GRAD_NORM"]) 260 | 261 | rng, rng_act, rng_crit = jax.random.split(rng, 3) 262 | train_state_actor = opt.init(actor_params, key=rng_act) 263 | train_state_critic = opt.init(critic_params, key=rng_crit) 264 | 265 | act_param_tree = jax.tree_util.tree_structure(train_state_actor.params) 266 | act_layer_props = [] 267 | num_act_layers = len(train_state_actor.params["params"]) 268 | for i, layer in enumerate(train_state_actor.params["params"]): 269 | layer_prop = i / (num_act_layers - 1) 270 | if type(train_state_actor.params["params"][layer]) == dict: 271 | act_layer_props.extend( 272 | [layer_prop] * len(train_state_actor.params["params"][layer]) 273 | ) 274 | else: 275 | act_layer_props.extend([layer_prop]) 276 | 277 | act_layer_props = jax.tree_util.tree_unflatten(act_param_tree, act_layer_props) 278 | 279 | crit_param_tree = jax.tree_util.tree_structure(train_state_critic.params) 280 | crit_layer_props = [] 281 | num_crit_layers = len(train_state_critic.params["params"]) 282 | for i, layer in enumerate(train_state_critic.params["params"]): 283 | layer_prop = i / (num_crit_layers - 1) 284 | if type(train_state_critic.params["params"][layer]) == dict: 285 | crit_layer_props.extend( 286 | [layer_prop] * len(train_state_critic.params["params"][layer]) 287 | ) 288 | else: 289 | crit_layer_props.extend([layer_prop]) 290 | 291 | crit_layer_props = jax.tree_util.tree_unflatten( 292 | crit_param_tree, crit_layer_props 293 | ) 294 | 295 | # INIT ENV 296 | all_rng = jax.random.split(_rng, config["NUM_ENVS"] + 1) 297 | rng, _rng = all_rng[0], all_rng[1:] 298 | obsv, env_state = env.reset(_rng, env_params) 299 | 300 | # TRAIN LOOP 301 | def _update_step(runner_state, unused): 302 | 303 | # COLLECT TRAJECTORIES 304 | def _env_step(runner_state, unused): 305 | ( 306 | train_state_actor, 307 | train_state_critic, 308 | env_state, 309 | last_obs, 310 | last_done, 311 | rng, 312 | last_param_abs, 313 | ) = runner_state 314 | rng, _rng = jax.random.split(rng) 315 | # SELECT ACTION 316 | pi, _ = actor.apply(train_state_actor.params, last_obs) 317 | value, _ = critic.apply(train_state_critic.params, last_obs) 318 | action = pi.sample(seed=_rng) 319 | 320 | log_prob = pi.log_prob(action) 321 | # STEP ENV 322 | rng, _rng = jax.random.split(rng) 323 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 324 | obsv, env_state, reward, done, info = env.step( 325 | rng_step, env_state, action, env_params 326 | ) 327 | transition = Transition( 328 | done, action, value, reward, log_prob, last_obs, info 329 | ) 330 | runner_state 
= ( 331 | train_state_actor, 332 | train_state_critic, 333 | env_state, 334 | obsv, 335 | done, 336 | rng, 337 | 0.0, 338 | ) 339 | 340 | return runner_state, transition 341 | 342 | runner_state, traj_batch = jax.lax.scan( 343 | _env_step, runner_state, None, config["NUM_STEPS"] 344 | ) 345 | 346 | # CALCULATE ADVANTAGE 347 | ( 348 | train_state_actor, 349 | train_state_critic, 350 | env_state, 351 | last_obs, 352 | last_done, 353 | rng, 354 | last_param_abs, 355 | ) = runner_state 356 | last_val, _ = critic.apply(train_state_critic.params, last_obs) 357 | last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val) 358 | 359 | def _calculate_gae(traj_batch, last_val): 360 | def _get_advantages(gae_and_next_value, transition): 361 | gae, next_value = gae_and_next_value 362 | done, value, reward = ( 363 | transition.done, 364 | transition.value, 365 | transition.reward, 366 | ) 367 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value 368 | gae = ( 369 | delta 370 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae 371 | ) 372 | return (gae, value), gae 373 | 374 | _, advantages = jax.lax.scan( 375 | _get_advantages, 376 | (jnp.zeros_like(last_val), last_val), 377 | traj_batch, 378 | reverse=True, 379 | unroll=16, 380 | ) 381 | return advantages, advantages + traj_batch.value 382 | 383 | advantages, targets = _calculate_gae(traj_batch, last_val) 384 | 385 | # UPDATE NETWORK 386 | def _update_epoch(update_state, unused): 387 | def _update_minbatch(train_state_key, batch_info): 388 | train_state_actor, train_state_critic, key = train_state_key 389 | key, key_ = jax.random.split(key) 390 | 391 | traj_batch, advantages, targets = batch_info 392 | 393 | def _loss_fn(actor_params, critic_params, traj_batch, gae, targets): 394 | # RERUN NETWORK 395 | pi, actor_activations = actor.apply( 396 | actor_params, traj_batch.obs 397 | ) 398 | value, critic_activations = critic.apply( 399 | critic_params, traj_batch.obs 400 | ) 401 | 402 | log_prob = pi.log_prob(traj_batch.action) 403 | 404 | # CALCULATE VALUE LOSS 405 | value_pred_clipped = traj_batch.value + ( 406 | value - traj_batch.value 407 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 408 | value_losses = jnp.square(value - targets) 409 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 410 | value_loss = ( 411 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 412 | ) 413 | 414 | # CALCULATE ACTOR LOSS 415 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 416 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 417 | loss_actor1 = ratio * gae 418 | loss_actor2 = ( 419 | jnp.clip( 420 | ratio, 421 | 1.0 - config["CLIP_EPS"], 422 | 1.0 + config["CLIP_EPS"], 423 | ) 424 | * gae 425 | ) 426 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 427 | loss_actor = loss_actor.mean() 428 | entropy = pi.entropy().mean() 429 | 430 | total_loss = ( 431 | loss_actor 432 | + config["VF_COEF"] * value_loss 433 | - config["ENT_COEF"] * entropy 434 | ) 435 | 436 | return total_loss, (actor_activations, critic_activations) 437 | 438 | training_prop = ( 439 | train_state_actor.iteration 440 | // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]) 441 | ) / (config["NUM_UPDATES"] - 1) 442 | batch_prop = ( 443 | (train_state_actor.iteration // config["NUM_MINIBATCHES"]) 444 | % config["UPDATE_EPOCHS"] 445 | ) / (config["UPDATE_EPOCHS"] - 1) 446 | 447 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, argnums=[0, 1]) 448 | (total_loss, (actor_activations, critic_activations)), ( 449 | actor_grads, 450 | 
critic_grads, 451 | ) = grad_fn( 452 | train_state_actor.params, 453 | train_state_critic.params, 454 | traj_batch, 455 | advantages, 456 | targets, 457 | ) 458 | 459 | key_actor, key_critic = jax.random.split(key_) 460 | actor_grads, _ = clip_opt.update(actor_grads, None) 461 | critic_grads, _ = clip_opt.update(critic_grads, None) 462 | actor_mask = {"kernel": 1, "bias": 1} 463 | critic_mask = {"kernel": 0, "bias": 0} 464 | 465 | # FOR NOW, HARD CODE MASK 466 | if config["CONTINUOUS"]: 467 | mask = { 468 | "actor": { 469 | "params": { 470 | "Dense_0": actor_mask, 471 | "Dense_1": actor_mask, 472 | "Dense_2": actor_mask, 473 | "log_std": 1, 474 | } 475 | }, 476 | "critic": { 477 | "params": { 478 | "Dense_0": critic_mask, 479 | "Dense_1": critic_mask, 480 | "Dense_2": critic_mask, 481 | } 482 | }, 483 | } 484 | 485 | else: 486 | 487 | mask = { 488 | "actor": { 489 | "params": { 490 | "Dense_0": actor_mask, 491 | "Dense_1": actor_mask, 492 | "Dense_2": actor_mask, 493 | } 494 | }, 495 | "critic": { 496 | "params": { 497 | "Dense_0": critic_mask, 498 | "Dense_1": critic_mask, 499 | "Dense_2": critic_mask, 500 | } 501 | }, 502 | } 503 | 504 | activations = { 505 | "actor": actor_activations, 506 | "critic": critic_activations, 507 | } 508 | grads = {"actor": actor_grads, "critic": critic_grads} 509 | layer_props = {"actor": act_layer_props, "critic": crit_layer_props} 510 | 511 | # APPLY OPTIMIZER 512 | ( 513 | train_state_actor, 514 | train_state_critic, 515 | actor_updates, 516 | critic_updates, 517 | ) = opt.update( 518 | train_state_actor, 519 | train_state_critic, 520 | grads, 521 | activations, 522 | key=key_actor, 523 | training_prop=training_prop, 524 | config=config, 525 | batch_prop=batch_prop, 526 | layer_props=layer_props, 527 | mask=mask, 528 | ) 529 | 530 | train_state_key_ = (train_state_actor, train_state_critic, key) 531 | return train_state_key_, ( 532 | total_loss, 533 | actor_updates, 534 | critic_updates, 535 | actor_activations, 536 | critic_activations, 537 | ) 538 | 539 | ( 540 | train_state_actor, 541 | train_state_critic, 542 | traj_batch, 543 | advantages, 544 | targets, 545 | rng, 546 | ) = update_state 547 | rng, _rng = jax.random.split(rng) 548 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] 549 | assert ( 550 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] 551 | ), "batch size must be equal to number of steps * number of envs" 552 | permutation = jax.random.permutation(_rng, batch_size) 553 | batch = (traj_batch, advantages, targets) 554 | batch = jax.tree_util.tree_map( 555 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch 556 | ) 557 | shuffled_batch = jax.tree_util.tree_map( 558 | lambda x: jnp.take(x, permutation, axis=0), batch 559 | ) 560 | minibatches = jax.tree_util.tree_map( 561 | lambda x: jnp.reshape( 562 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) 563 | ), 564 | shuffled_batch, 565 | ) 566 | 567 | train_state_key = (train_state_actor, train_state_critic, rng) 568 | train_state_key, total_loss_updates = jax.lax.scan( 569 | _update_minbatch, train_state_key, minibatches 570 | ) 571 | train_state_actor, train_state_critic, rng = train_state_key 572 | update_state = ( 573 | train_state_actor, 574 | train_state_critic, 575 | traj_batch, 576 | advantages, 577 | targets, 578 | rng, 579 | ) 580 | 581 | return update_state, total_loss_updates 582 | 583 | update_state = ( 584 | train_state_actor, 585 | train_state_critic, 586 | traj_batch, 587 | advantages, 588 | targets, 589 | rng, 590 | ) 591 | update_state, 
loss_update = jax.lax.scan( 592 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 593 | ) 594 | 595 | train_state_actor = update_state[0] 596 | train_state_critic = update_state[1] 597 | 598 | ( 599 | loss_info, 600 | actor_updates, 601 | critic_updates, 602 | actor_activations, 603 | critic_activations, 604 | ) = loss_update 605 | if config["VISUALISE"]: 606 | 607 | metric = traj_batch.info 608 | 609 | else: 610 | metric = dict() 611 | 612 | metric.update( 613 | { 614 | "returned_episode_returns": traj_batch.info[ 615 | "returned_episode_returns" 616 | ][-1].mean() 617 | } 618 | ) 619 | 620 | rng = update_state[-1] 621 | 622 | actor_size = len(jax.flatten_util.ravel_pytree(runner_state[0].params)[0]) 623 | critic_size = len(jax.flatten_util.ravel_pytree(runner_state[1].params)[0]) 624 | if config["VISUALISE"]: 625 | actor_abs = jax.flatten_util.ravel_pytree(actor_updates[1])[0] 626 | critic_abs = jax.flatten_util.ravel_pytree(actor_updates[1])[0] 627 | else: 628 | actor_abs = jax.flatten_util.ravel_pytree(actor_updates)[0][-1] 629 | critic_abs = jax.flatten_util.ravel_pytree(actor_updates)[0][-1] 630 | abs_param_mean = (actor_size * actor_abs + critic_size * critic_abs) / ( 631 | actor_size + critic_size 632 | ) 633 | runner_state = ( 634 | train_state_actor, 635 | train_state_critic, 636 | env_state, 637 | last_obs, 638 | last_done, 639 | rng, 640 | abs_param_mean.mean(), 641 | ) 642 | if config["VISUALISE"]: 643 | 644 | # CALCULATE DORMANCY FOR TRACKING 645 | def calc_dormancy(tensor_activations, tau=1e-7): 646 | tensor_activations = tensor_activations + 1e-11 647 | total_activations = jnp.abs(tensor_activations).sum(axis=-1) 648 | total_activations = jnp.tile( 649 | jnp.expand_dims(total_activations, -1), 650 | tensor_activations.shape[-1], 651 | ) 652 | dormancy = ( 653 | tensor_activations 654 | / total_activations 655 | * tensor_activations.shape[-1] 656 | ) 657 | tau_dormancy = dormancy < tau 658 | tau_dormancy = tau_dormancy.sum(axis=(-1)) 659 | return tau_dormancy 660 | 661 | full_activations = (actor_activations, critic_activations) 662 | dormancies = jax.tree_util.tree_map(calc_dormancy, full_activations) 663 | dormancies = jax.tree_util.tree_flatten(dormancies)[0] 664 | 665 | prop_dormancies = jnp.stack(dormancies).sum(axis=0) 666 | prop_dormancies = ( 667 | prop_dormancies 668 | / len(jax.flatten_util.ravel_pytree(full_activations)[0]) 669 | * prop_dormancies.shape[0] 670 | * prop_dormancies.shape[1] 671 | ) 672 | param_mean = ( 673 | actor_size * jax.flatten_util.ravel_pytree(actor_updates[0])[0] 674 | + critic_size * jax.flatten_util.ravel_pytree(critic_updates[0])[0] 675 | ) / (actor_size + critic_size) 676 | log_updates = { 677 | "param_mean": param_mean, 678 | "param_abs_mean": abs_param_mean, 679 | } 680 | metric.update({"dormancy": prop_dormancies}) 681 | metric.update(log_updates) 682 | 683 | return runner_state, metric 684 | 685 | rng, _rng = jax.random.split(rng) 686 | runner_state = ( 687 | train_state_actor, 688 | train_state_critic, 689 | env_state, 690 | obsv, 691 | jnp.zeros((config["NUM_ENVS"]), dtype=bool), 692 | _rng, 693 | 0.0, 694 | ) 695 | 696 | runner_state, metric = jax.lax.scan( 697 | _update_step, runner_state, None, config["NUM_UPDATES"] 698 | ) 699 | 700 | return runner_state, metric 701 | 702 | return train 703 | 704 | 705 | if __name__ == "__main__": 706 | 707 | parser = argparse.ArgumentParser() 708 | parser.add_argument("--num-generations", type=int, default=512) 709 | parser.add_argument("--envs", nargs="+", required=True) 710 | 
parser.add_argument("--lr", type=float, default=0.03) 711 | parser.add_argument("--popsize", type=int, default=64) 712 | parser.add_argument("--num-rollouts", type=int, default=1) 713 | parser.add_argument("--save-every-k", type=int, default=24) 714 | parser.add_argument("--noise-level", type=float, default=0.03) 715 | parser.add_argument( 716 | "--pmap", default=jax.local_device_count() > 1, action="store_true" 717 | ) 718 | parser.add_argument("--wandb-name", type=str, default="OPEN") 719 | parser.add_argument("--wandb-entity", type=str, default=None) 720 | parser.add_argument("--sigma-decay", type=float, default=0.999) 721 | parser.add_argument("--lr-decay", type=float, default=0.999) 722 | parser.add_argument("--larger", default=False, action="store_true") 723 | 724 | args = parser.parse_args() 725 | 726 | evo_config = { 727 | "ENV_NAME": args.envs, 728 | "POPULATION_SIZE": args.popsize, 729 | "NUM_GENERATIONS": args.num_generations, 730 | "NUM_ROLLOUTS": args.num_rollouts, 731 | "SAVE_EVERY_K": args.save_every_k, 732 | "NOISE_LEVEL": args.noise_level, 733 | "PMAP": args.pmap, 734 | "LR": args.lr, 735 | "num_GPUs": jax.local_device_count(), 736 | } 737 | 738 | all_configs = {k: all_configs[k] for k in evo_config["ENV_NAME"]} 739 | 740 | save_loc = "training" 741 | if not os.path.exists(save_loc): 742 | os.mkdir(save_loc) 743 | save_dir = f"{save_loc}/{str(datetime.now()).replace(' ', '_')}_optimizer" 744 | os.mkdir(f"{save_dir}") 745 | 746 | popsize = args.popsize 747 | num_generations = args.num_generations 748 | num_rollouts = args.num_rollouts 749 | save_every_k_gens = args.save_every_k 750 | 751 | wandb.init( 752 | project="OPEN", 753 | config=evo_config, 754 | name=args.wandb_name, 755 | entity=args.wandb_entity, 756 | ) 757 | 758 | if args.larger: 759 | meta_opt = GRU_Opt(hidden_size=32, gru_features=16) 760 | else: 761 | meta_opt = GRU_Opt(hidden_size=16, gru_features=8) 762 | params = meta_opt.init(jax.random.PRNGKey(0)) 763 | 764 | params = jax.tree.map(lambda x: jnp.zeros_like(x), params) 765 | 766 | devices = jax.devices() 767 | mesh = Mesh(devices, axis_names=("dim",)) 768 | sharding_p = NamedSharding( 769 | mesh, 770 | P( 771 | "dim", 772 | ), 773 | ) 774 | sharding_rng = NamedSharding( 775 | mesh, 776 | P( 777 | "dim", 778 | ), 779 | ) 780 | 781 | def make_rollout(train_fn): 782 | def single_rollout(rng_input, meta_params): 783 | params, metrics = train_fn(meta_params, rng_input) 784 | 785 | fitness = metrics["returned_episode_returns"][-1] 786 | return fitness 787 | 788 | vmap_rollout = jax.vmap(single_rollout, in_axes=(0, None)) 789 | rollout = jax.jit(jax.vmap(vmap_rollout, in_axes=(0, 0))) 790 | 791 | # if evo_config["PMAP"]: 792 | # rollout = jax.pmap(rollout) 793 | 794 | return rollout 795 | 796 | for k in all_configs.keys(): 797 | all_configs[k]["NUM_UPDATES"] = ( 798 | all_configs[k]["TOTAL_TIMESTEPS"] 799 | // all_configs[k]["NUM_STEPS"] 800 | // all_configs[k]["NUM_ENVS"] 801 | ) 802 | all_configs[k]["larger"] = args.larger 803 | 804 | rollouts = {k: make_rollout(make_train(v)) for k, v in all_configs.items()} 805 | 806 | rng = jax.random.PRNGKey(42) 807 | 808 | if args.envs == ["gridworld"]: 809 | fitness_shaping_fn = identity_fitness_shaping_fn 810 | else: 811 | fitness_shaping_fn = centered_rank_fitness_shaping_fn 812 | 813 | # In practice, you should set decay rate and transition steps to be how much/long you want to decay for. 814 | # We set transition steps = 1 and all other hparams to match the behaviour of the original evosax here. 
815 | strategy = Open_ES( 816 | population_size=args.popsize, 817 | optimizer=optax.adam( 818 | optax.schedules.exponential_decay( 819 | evo_config["LR"], 820 | decay_rate=args.lr_decay, 821 | end_value=0, 822 | transition_steps=1, 823 | ), 824 | b1=0.99, 825 | ), 826 | std_schedule=optax.schedules.exponential_decay( 827 | evo_config["NOISE_LEVEL"], 828 | decay_rate=args.sigma_decay, 829 | end_value=0, 830 | transition_steps=1, 831 | ), 832 | solution=params, 833 | fitness_shaping_fn=fitness_shaping_fn, 834 | ) 835 | 836 | es_params = strategy.default_params 837 | 838 | state = strategy.init(key=rng, params=es_params, mean=params) 839 | 840 | most_neg = {env: 0 for env in evo_config["ENV_NAME"]} 841 | 842 | fit_history = [] 843 | for gen in tqdm(range(num_generations)): 844 | 845 | rng, rng_ask, rng_eval = jax.random.split(rng, 3) 846 | x, state = jax.jit(strategy.ask)(rng_ask, state, es_params) 847 | x = jax.device_put(x, sharding_p) 848 | 849 | # Set up antithetic task sampling by repeating each optimizer. 850 | if args.envs == ["gridworld"]: 851 | x, state = jax.jit(strategy.ask)(rng_ask, state, es_params) 852 | new_orders = [ 853 | [i, int(args.popsize / 2) + i] for i in range(int(args.popsize / 2)) 854 | ] 855 | new_orders = jnp.array([x for y in new_orders for x in y]) 856 | x = x[new_orders] 857 | 858 | fit_info = {} 859 | 860 | wandb.log( 861 | { 862 | "evo/std":state.std, 863 | "evo/lr":evo_config["LR"]*(args.lr_decay**gen), 864 | }, 865 | step = gen 866 | ) 867 | 868 | all_fitness = [] 869 | 870 | for env in args.envs: 871 | 872 | rng, rng_eval = jax.random.split(rng) 873 | 874 | all_configs[env]["VISUALISE"] = False 875 | rollout = rollouts[env] 876 | 877 | # Antithetic task sampling for gridworld - antithetic perturbatiosn are evaluated on the same rng 878 | if args.envs == ["gridworld"]: 879 | batch_rng = jax.random.split(rng_eval, args.popsize / 2) 880 | batch_rng = jnp.repeat(batch_rng, 2, axis=0) 881 | 882 | else: 883 | batch_rng = jax.random.split(rng_eval, num_rollouts) 884 | batch_rng = jnp.tile(batch_rng, (args.popsize, 1, 1)) 885 | 886 | if args.pmap: 887 | batch_rng_pmap = jax.device_put(batch_rng, sharding_rng) 888 | fitness = rollout(batch_rng_pmap, x) 889 | fitness = jax.device_get(fitness) 890 | fitness = fitness.reshape(-1, evo_config["NUM_ROLLOUTS"]).mean(axis=1) 891 | 892 | else: 893 | batch_rng = jnp.reshape(batch_rng, (-1, num_rollouts, 2)) 894 | fitness = rollout(batch_rng, x) 895 | fitness = fitness.mean(axis=1) 896 | 897 | fitness = jnp.nan_to_num(fitness, nan=-100000) 898 | 899 | print(f"fitness: {fitness}") 900 | print(f"mean fitness_{env} = {jnp.mean(fitness):.3f}") 901 | print(f"fitness_spread at gen {gen} is {fitness.max()-fitness.min()}") 902 | 903 | fit_history.append(fitness.mean()) 904 | 905 | fitness_var = jnp.var(fitness) 906 | fitness_spread = fitness.max() - fitness.min() 907 | 908 | # PPO_TEMP is set to the return obtained by PPO with Adam, and is used to normalise across environments 909 | fit_norm = fitness / all_configs[env]["PPO_TEMP"] 910 | 911 | mean_norm = fit_norm.mean() 912 | 913 | wandb.log( 914 | { 915 | f"training/{env}/avg_fitness": fitness.mean(), 916 | f"training/{env}/fitness_histo": wandb.Histogram( 917 | fitness, num_bins=16 918 | ), 919 | f"training/{env}/fitness_spread": fitness_spread, 920 | f"training/{env}/fitness_variance": fitness_var, 921 | f"training/{env}/normalised_fitness": mean_norm, 922 | f"training/{env}/normalised_histo": wandb.Histogram( 923 | fit_norm, num_bins=16 924 | ), 925 | 
f"training/{env}/best_fitness": jnp.max(fitness), 926 | f"training/{env}/worst_fitness": jnp.min(fitness), 927 | }, 928 | step=gen, 929 | ) 930 | 931 | all_fitness.append(fit_norm) 932 | 933 | fitnesses = jnp.stack(all_fitness, axis=0) 934 | fitnesses_mean = jnp.mean(fitnesses, axis=0) 935 | 936 | wandb.log( 937 | { 938 | "training/average_over_env/normalised_fitness": fitnesses_mean.mean(), 939 | "training/average_over_env/normalised_histo": wandb.Histogram( 940 | fitnesses_mean, num_bins=16 941 | ), 942 | }, 943 | step=gen, 944 | ) 945 | 946 | param_sum = state.mean.sum() 947 | param_abs_sum = jnp.abs(state.mean).sum() 948 | param_abs_mean = jnp.abs(state.mean).mean() 949 | 950 | wandb.log( 951 | { 952 | "param_stats/param_sum": state.mean.sum(), 953 | "param_stats/param_abs_sum": jnp.abs(state.mean).sum(), 954 | "param_stats/param_abs_mean": jnp.abs(state.mean).mean(), 955 | "param_stats/param_mean": state.mean.mean(), 956 | }, 957 | step=gen, 958 | ) 959 | 960 | # normalization for antithetic task sampling 961 | if args.envs == ["gridworld"]: 962 | first_greater = jnp.greater(fitness[::2], fitness[1::2]) 963 | rank_fitness = jnp.zeros_like(fitness) 964 | rank_fitness = rank_fitness.at[::2].set(-1 * first_greater.astype(float)) 965 | fitnesses_mean = rank_fitness.at[1::2].set( 966 | first_greater.astype(float) - 1.0 967 | ) 968 | 969 | # use standard key as update is deterministic 970 | # multiply fitness by -1 since the new evosax is only set to minimise! 971 | x = jax.device_get(x) 972 | state, metrics = jax.jit(strategy.tell)( 973 | jax.random.PRNGKey(0), x, -1 * fitnesses_mean, state, es_params 974 | ) 975 | wandb.log( 976 | {"training/average_over_env/best_fitness": -1 * metrics["best_fitness"]}, step=gen 977 | ) 978 | 979 | if gen % save_every_k_gens == 0: 980 | print("SAVING & EVALUATING!") 981 | jnp.save(osp.join(save_dir, f"curr_param_{gen}.npy"), state.mean) 982 | np.save(osp.join(save_dir, f"fit_history.npy"), np.array(fit_history)) 983 | 984 | wandb.save( 985 | osp.join(save_dir, f"curr_param_{gen}.npy"), 986 | base_path=save_dir, 987 | ) 988 | if args.envs == ["gridworld"]: 989 | plot_envs = gridtypes 990 | grids = True 991 | else: 992 | plot_envs = args.envs 993 | grids = False 994 | 995 | time.sleep(1) 996 | make_plot( 997 | exp_names=wandb.run.id, 998 | exp_nums=gen, 999 | envs=plot_envs, 1000 | num_runs=8, 1001 | pmap=args.pmap, 1002 | larger=args.larger, 1003 | training=True, 1004 | grids=grids, 1005 | params=strategy.get_mean(state), 1006 | mesh=mesh, 1007 | ) 1008 | -------------------------------------------------------------------------------- /rl_optimizer/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from gymnax.environments import spaces, environment 5 | from typing import NamedTuple, Optional, Tuple, Union 6 | from brax import envs 7 | from flax import struct 8 | from functools import partial 9 | from gymnax.wrappers.purerl import GymnaxWrapper 10 | import chex 11 | from jax import vjp, flatten_util 12 | from jax.tree_util import tree_flatten 13 | 14 | 15 | class GymnaxGymWrapper: 16 | def __init__(self, env, env_params, config): 17 | self.env = env 18 | self.env_params = env_params 19 | if config["CONTINUOUS"]: 20 | self.action_space = env.action_space(self.env_params).shape[0] 21 | else: 22 | self.action_space = env.action_space(self.env_params).n 23 | 24 | self.observation_space = env.observation_space(self.env_params).shape 25 | 26 | def 
reset(self, rng, params=None): 27 | rng, _rng = jax.random.split(rng) 28 | obs, env_state = self.env.reset(_rng, self.env_params) 29 | state = (env_state, rng) 30 | return obs, state 31 | 32 | def step(self, key, state, action, params=None): 33 | env_state, rng = state 34 | rng, _rng = jax.random.split(rng) 35 | obs, env_state, reward, done, info = self.env.step( 36 | _rng, env_state, action, self.env_params 37 | ) 38 | state = (env_state, rng) 39 | return obs, state, reward, done, info 40 | 41 | 42 | class MetaGymnaxGymWrapper: 43 | def __init__(self, env, env_param_generator): 44 | self.env = env 45 | self.env_param_generator = env_param_generator 46 | self.observation_space = env.observation_space(self.env_params).shape 47 | self.action_space = env.action_space(self.env_params).n 48 | 49 | def reset(self, rng): 50 | rng, _rng = jax.random.split(rng) 51 | env_params = self.env_param_generator(_rng) 52 | rng, _rng = jax.random.split(rng) 53 | obs, env_state = self.env.reset(_rng, env_params) 54 | state = (env_state, env_params, rng) 55 | return obs, state 56 | 57 | def step(self, state, action): 58 | env_state, env_params, rng = state 59 | rng, _rng = jax.random.split(rng) 60 | obs, env_state, reward, done, info = self.env.step( 61 | _rng, env_state, action, env_params 62 | ) 63 | state = (env_state, env_params, rng) 64 | return obs, state, reward, done, info 65 | 66 | 67 | class FlatWrapper: 68 | def __init__(self, env): 69 | self.env = env 70 | self.observation_space = np.prod(env.observation_space) 71 | self.action_space = env.action_space 72 | 73 | def reset(self, rng, params=None): 74 | obs, env_state = self.env.reset(rng, params) 75 | obs = jnp.reshape(obs, (self.observation_space,)) 76 | return obs, env_state 77 | 78 | def step(self, key, state, action, params=None): 79 | obs, state, reward, done, info = self.env.step(key, state, action, params) 80 | obs = jnp.reshape(obs, (self.observation_space,)) 81 | return obs, state, reward, done, info 82 | 83 | 84 | class EpisodeStats(NamedTuple): 85 | episode_returns: float 86 | episode_lengths: int 87 | returned_episode_returns: float 88 | returned_episode_lengths: int 89 | 90 | 91 | class GymnaxLogWrapper(GymnaxWrapper): 92 | def __init__(self, env): 93 | self.env = env 94 | self.observation_space = env.observation_space 95 | self.action_space = env.action_space 96 | 97 | def reset(self, rng, params=None): 98 | obs, env_state = self.env.reset(rng, params) 99 | state = (env_state, EpisodeStats(0, 0, 0, 0)) 100 | return obs, state 101 | 102 | def step(self, key, state, action, params=None): 103 | # def step(self, state, action): 104 | env_state, episode_stats = state 105 | obs, env_state, reward, done, info = self.env.step( 106 | key, env_state, action, params 107 | ) 108 | new_episode_return = episode_stats.episode_returns + reward 109 | new_episode_length = episode_stats.episode_lengths + 1 110 | new_episode_stats = EpisodeStats( 111 | episode_returns=new_episode_return * (1 - done), 112 | episode_lengths=new_episode_length * (1 - done), 113 | returned_episode_returns=episode_stats.returned_episode_returns * (1 - done) 114 | + new_episode_return * done, 115 | returned_episode_lengths=episode_stats.returned_episode_lengths * (1 - done) 116 | + new_episode_length * done, 117 | ) 118 | state = (env_state, new_episode_stats) 119 | info = {} 120 | info["returned_episode_returns"] = new_episode_stats.returned_episode_returns 121 | info["returned_episode_lengths"] = new_episode_stats.returned_episode_lengths 122 | return obs, state, reward, 
done, info 123 | 124 | 125 | class EvalStats(NamedTuple): 126 | first_returned_episode_returns: float 127 | ever_done: bool 128 | 129 | 130 | class GymnaxLogEvalWrapper: 131 | def __init__(self, env): 132 | self.env = env 133 | self.observation_space = env.observation_space 134 | self.action_space = env.action_space 135 | 136 | def reset(self, rng, params=None): 137 | obs, env_state = self.env.reset(rng, params) 138 | state = (env_state, EvalStats(0, False)) 139 | return obs, state 140 | 141 | def step(self, key, state, action, params=None): 142 | env_state, episode_stats = state 143 | obs, env_state, reward, done, info = self.env.step( 144 | key, env_state, action, params 145 | ) 146 | ever_done = jnp.logical_or(episode_stats.ever_done, done) 147 | episode_return = episode_stats.first_returned_episode_returns + reward * ( 148 | 1 - ever_done 149 | ) 150 | new_episode_stats = EvalStats( 151 | first_returned_episode_returns=episode_return, 152 | ever_done=ever_done, 153 | ) 154 | state = (env_state, new_episode_stats) 155 | info = {} 156 | info["first_returned_episode_returns"] = ( 157 | new_episode_stats.first_returned_episode_returns 158 | ) 159 | info["ever_done"] = new_episode_stats.ever_done 160 | return obs, state, reward, done, info 161 | 162 | 163 | class AutoResetEnvWrapper(GymnaxWrapper): 164 | """Provides standard auto-reset functionality, providing the same behaviour as Gymnax-default.""" 165 | 166 | def __init__(self, env: environment.Environment): 167 | super().__init__(env) 168 | 169 | @partial(jax.jit, static_argnums=(0, 2)) 170 | def reset( 171 | self, key, params: Optional[environment.EnvParams] = None 172 | ) -> Tuple[chex.Array, environment.EnvState]: 173 | return self._env.reset(key, params) 174 | 175 | @partial(jax.jit, static_argnums=(0, 4)) 176 | def step(self, rng, state, action, params=None): 177 | 178 | rng, _rng = jax.random.split(rng) 179 | obs_st, state_st, reward, done, info = self._env.step( 180 | _rng, state, action, params 181 | ) 182 | 183 | rng, _rng = jax.random.split(rng) 184 | obs_re, state_re = self._env.reset(_rng, params) 185 | 186 | # Auto-reset environment based on termination 187 | def auto_reset(done, state_re, state_st, obs_re, obs_st): 188 | state = jax.tree_util.tree_map( 189 | lambda x, y: jax.lax.select(done, x, y), state_re, state_st 190 | ) 191 | obs = jax.lax.select(done, obs_re, obs_st) 192 | 193 | return obs, state 194 | 195 | obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st) 196 | 197 | return obs, state, reward, done, info 198 | 199 | 200 | class BraxGymnaxWrapper: 201 | def __init__(self, env_name, backend="positional"): 202 | # def __init__(self, env_name, backend="generalized"): 203 | 204 | # ****** BACKEND CURRENTLY NOT IMPLEMENTED 205 | 206 | env = envs.get_environment(env_name=env_name, backend=backend) 207 | env = envs.wrappers.training.EpisodeWrapper( 208 | env, episode_length=1000, action_repeat=1 209 | ) 210 | env = envs.wrappers.training.AutoResetWrapper(env) 211 | self._env = env 212 | self.action_size = env.action_size 213 | self.observation_size = (env.observation_size,) 214 | 215 | def reset(self, key, params=None): 216 | # print(f'bye: {key}') 217 | state = self._env.reset(key) 218 | return state.obs, state 219 | 220 | def step(self, key, state, action, params): 221 | next_state = self._env.step(state, action) 222 | return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {} 223 | 224 | def observation_space(self, params): 225 | return spaces.Box( 226 | low=-jnp.inf, 227 | 
high=jnp.inf, 228 | shape=(self._env.observation_size,), 229 | ) 230 | 231 | def action_space(self, params): 232 | return spaces.Box( 233 | low=-1.0, 234 | high=1.0, 235 | shape=(self._env.action_size,), 236 | ) 237 | 238 | 239 | class GymnaxWrapper(object): 240 | """Base class for Gymnax wrappers.""" 241 | 242 | def __init__(self, env): 243 | self._env = env 244 | 245 | # provide proxy access to regular attributes of wrapped object 246 | def __getattr__(self, name): 247 | return getattr(self._env, name) 248 | 249 | 250 | class ClipAction(GymnaxWrapper): 251 | def __init__(self, env): 252 | super().__init__(env) 253 | # self.action_lim = config["ACTION_LIM"] 254 | 255 | def step(self, key, state, action, params=None): 256 | """TODO: FIX""" 257 | # old_action = action 258 | action = jnp.clip(action, -1, 1) 259 | # jax.debug.print('{old} -> {new}', old=old_action, new=action) 260 | # action = jnp.clip(action, -1.0, 1.0) 261 | # jax.debug.print('{action}', action=action) 262 | return self._env.step(key, state, action, params) 263 | 264 | 265 | class TransformObservation(GymnaxWrapper): 266 | def __init__(self, env, transform_obs): 267 | super().__init__(env) 268 | self.transform_obs = transform_obs 269 | 270 | def reset(self, key, params=None): 271 | obs, state = self._env.reset(key, params) 272 | return self.transform_obs(obs), state 273 | 274 | def step(self, key, state, action, params=None): 275 | obs, state, reward, done, info = self._env.step(key, state, action, params) 276 | return self.transform_obs(obs), state, reward, done, info 277 | 278 | 279 | class TransformReward(GymnaxWrapper): 280 | def __init__(self, env, transform_reward): 281 | super().__init__(env) 282 | self.transform_reward = transform_reward 283 | 284 | def step(self, key, state, action, params=None): 285 | obs, state, reward, done, info = self._env.step(key, state, action, params) 286 | return obs, state, self.transform_reward(reward), done, info 287 | 288 | 289 | class VecEnv(GymnaxWrapper): 290 | def __init__(self, env): 291 | super().__init__(env) 292 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None)) 293 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) 294 | 295 | 296 | @struct.dataclass 297 | class NormalizeObsEnvState: 298 | mean: jnp.ndarray 299 | var: jnp.ndarray 300 | count: float 301 | env_state: environment.EnvState 302 | 303 | 304 | class NormalizeObservation(GymnaxWrapper): 305 | def __init__(self, env): 306 | super().__init__(env) 307 | 308 | def reset(self, key, params=None): 309 | obs, state = self._env.reset(key, params) 310 | state = NormalizeObsEnvState( 311 | mean=jnp.zeros_like(obs), 312 | var=jnp.ones_like(obs), 313 | count=1e-4, 314 | env_state=state, 315 | ) 316 | batch_mean = jnp.mean(obs, axis=0) 317 | batch_var = jnp.var(obs, axis=0) 318 | batch_count = obs.shape[0] 319 | 320 | delta = batch_mean - state.mean 321 | tot_count = state.count + batch_count 322 | 323 | new_mean = state.mean + delta * batch_count / tot_count 324 | m_a = state.var * state.count 325 | m_b = batch_var * batch_count 326 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 327 | new_var = M2 / tot_count 328 | new_count = tot_count 329 | 330 | state = NormalizeObsEnvState( 331 | mean=new_mean, 332 | var=new_var, 333 | count=new_count, 334 | env_state=state.env_state, 335 | ) 336 | 337 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state 338 | 339 | def step(self, key, state, action, params=None): 340 | obs, env_state, reward, done, info = self._env.step( 341 | 
key, state.env_state, action, params 342 | ) 343 | 344 | batch_mean = jnp.mean(obs, axis=0) 345 | batch_var = jnp.var(obs, axis=0) 346 | batch_count = obs.shape[0] 347 | 348 | delta = batch_mean - state.mean 349 | tot_count = state.count + batch_count 350 | 351 | new_mean = state.mean + delta * batch_count / tot_count 352 | m_a = state.var * state.count 353 | m_b = batch_var * batch_count 354 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 355 | new_var = M2 / tot_count 356 | new_count = tot_count 357 | 358 | state = NormalizeObsEnvState( 359 | mean=new_mean, 360 | var=new_var, 361 | count=new_count, 362 | env_state=env_state, 363 | ) 364 | return ( 365 | (obs - state.mean) / jnp.sqrt(state.var + 1e-8), 366 | state, 367 | reward, 368 | done, 369 | info, 370 | ) 371 | 372 | 373 | @struct.dataclass 374 | class NormalizeRewEnvState: 375 | mean: jnp.ndarray 376 | var: jnp.ndarray 377 | count: float 378 | return_val: float 379 | env_state: environment.EnvState 380 | 381 | 382 | class NormalizeReward(GymnaxWrapper): 383 | def __init__(self, env, gamma): 384 | super().__init__(env) 385 | self.gamma = gamma 386 | 387 | def reset(self, key, params=None): 388 | obs, state = self._env.reset(key, params) 389 | batch_count = obs.shape[0] 390 | state = NormalizeRewEnvState( 391 | mean=0.0, 392 | var=1.0, 393 | count=1e-4, 394 | return_val=jnp.zeros((batch_count,)), 395 | env_state=state, 396 | ) 397 | return obs, state 398 | 399 | def step(self, key, state, action, params=None): 400 | obs, env_state, reward, done, info = self._env.step( 401 | key, state.env_state, action, params 402 | ) 403 | return_val = state.return_val * self.gamma * (1 - done) + reward 404 | 405 | batch_mean = jnp.mean(return_val, axis=0) 406 | batch_var = jnp.var(return_val, axis=0) 407 | batch_count = obs.shape[0] 408 | 409 | delta = batch_mean - state.mean 410 | tot_count = state.count + batch_count 411 | 412 | new_mean = state.mean + delta * batch_count / tot_count 413 | m_a = state.var * state.count 414 | m_b = batch_var * batch_count 415 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 416 | new_var = M2 / tot_count 417 | new_count = tot_count 418 | 419 | state = NormalizeRewEnvState( 420 | mean=new_mean, 421 | var=new_var, 422 | count=new_count, 423 | return_val=return_val, 424 | env_state=env_state, 425 | ) 426 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info 427 | 428 | 429 | # Use parameter Reshaper from old gymnax 430 | def ravel_pytree(pytree): 431 | leaves, _ = tree_flatten(pytree) 432 | flat, _ = vjp(ravel_list, *leaves) 433 | return flat 434 | 435 | 436 | def ravel_list(*lst): 437 | return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) 438 | 439 | 440 | class ParameterReshaper(object): 441 | def __init__( 442 | self, 443 | placeholder_params: Union[chex.ArrayTree, chex.Array], 444 | n_devices: Optional[int] = None, 445 | verbose: bool = True, 446 | ): 447 | """Reshape flat parameters vectors into generation eval shape.""" 448 | # Get network shape to reshape 449 | self.placeholder_params = placeholder_params 450 | 451 | # Set total parameters depending on type of placeholder params 452 | flat, self.unravel_pytree = flatten_util.ravel_pytree(placeholder_params) 453 | self.total_params = flat.shape[0] 454 | self.reshape_single = jax.jit(self.unravel_pytree) 455 | 456 | if n_devices is None: 457 | self.n_devices = jax.local_device_count() 458 | else: 459 | self.n_devices = n_devices 460 | if self.n_devices > 1 and 
verbose: 461 | print( 462 | f"ParameterReshaper: {self.n_devices} devices detected. Please" 463 | " make sure that the ES population size divides evenly across" 464 | " the number of devices to pmap/parallelize over." 465 | ) 466 | 467 | if verbose: 468 | print( 469 | f"ParameterReshaper: {self.total_params} parameters detected" 470 | " for optimization." 471 | ) 472 | 473 | def reshape(self, x: chex.Array) -> chex.ArrayTree: 474 | """Perform reshaping for a 2D matrix (pop_members, params).""" 475 | vmap_shape = jax.vmap(self.reshape_single) 476 | if self.n_devices > 1: 477 | x = self.split_params_for_pmap(x) 478 | map_shape = jax.pmap(vmap_shape) 479 | else: 480 | map_shape = vmap_shape 481 | return map_shape(x) 482 | 483 | def multi_reshape(self, x: chex.Array) -> chex.ArrayTree: 484 | """Reshape parameters lying already on different devices.""" 485 | # No reshaping required! 486 | vmap_shape = jax.vmap(self.reshape_single) 487 | return jax.pmap(vmap_shape)(x) 488 | 489 | def flatten(self, x: chex.ArrayTree) -> chex.Array: 490 | """Reshaping pytree parameters into flat array.""" 491 | vmap_flat = jax.vmap(ravel_pytree) 492 | if self.n_devices > 1: 493 | # Flattening of pmap paramater trees to apply vmap flattening 494 | def map_flat(x): 495 | x_re = jax.tree_util.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), x) 496 | return vmap_flat(x_re) 497 | 498 | else: 499 | map_flat = vmap_flat 500 | flat = map_flat(x) 501 | return flat 502 | 503 | def multi_flatten(self, x: chex.Array) -> chex.ArrayTree: 504 | """Flatten parameters lying remaining on different devices.""" 505 | # No reshaping required! 506 | vmap_flat = jax.vmap(ravel_pytree) 507 | return jax.pmap(vmap_flat)(x) 508 | 509 | def split_params_for_pmap(self, param: chex.Array) -> chex.Array: 510 | """Helper reshapes param (bs, #params) into (#dev, bs/#dev, #params).""" 511 | return jnp.stack(jnp.split(param, self.n_devices)) 512 | 513 | @property 514 | def vmap_dict(self) -> chex.ArrayTree: 515 | """Get a dictionary specifying axes to vmap over.""" 516 | vmap_dict = jax.tree_util.tree_map(lambda x: 0, self.placeholder_params) 517 | return vmap_dict 518 | -------------------------------------------------------------------------------- /run_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if GPU names are provided 4 | if [ -z "$1" ]; then 5 | echo "Please provide GPU names (e.g., 0,1,2)" 6 | exit 1 7 | fi 8 | 9 | # GPU names passed as the first argument 10 | GPU_NAMES=$1 11 | 12 | # Run the Docker command with the specified GPUs. shm-size is needed for sharding, as Docker defaults to tiny RAM. Feel free to change this if it causes issues. 
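# Illustrative example (hypothetical value): if multi-GPU sharding still exhausts
# shared memory, the --shm-size flag in the command below could be raised, e.g. --shm-size=16g.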
13 | docker run -it --rm --gpus '"device='$GPU_NAMES'"' -v $(pwd):/rl_optimizer -w /rl_optimizer/rl_optimizer --shm-size=5g open -------------------------------------------------------------------------------- /setup/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt update && apt install python3-pip -y 7 | RUN pip3 install --upgrade pip 8 | RUN apt-get install tmux -y 9 | RUN apt-get install vim -y 10 | RUN apt install libglfw3-dev -y 11 | RUN apt install libglfw3 -y 12 | RUN apt-get update && apt-get install -y git 13 | RUN pip3 install --upgrade pip setuptools wheel 14 | 15 | COPY requirements.txt /tmp/requirements.txt 16 | # Need to use specific cuda versions for jax 17 | ARG USE_CUDA=true 18 | RUN if [ "$USE_CUDA" = true ] ; \ 19 | then pip install "jax[cuda12]>=0.4.25, <0.6.0" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \ 20 | fi 21 | RUN pip3 install -r /tmp/requirements.txt 22 | 23 | ARG WANDB_API 24 | ARG USERNAME 25 | ARG USER_UID 26 | ARG USER_GID 27 | 28 | RUN groupadd --gid $USER_GID $USERNAME \ 29 | && useradd --uid $USER_UID --gid $USER_GID --create-home $USERNAME \ 30 | && chown -R $USER_UID:$USER_GID /home/$USERNAME 31 | 32 | USER $USERNAME 33 | 34 | ENV WANDB_API_KEY=$WANDB_API 35 | 36 | WORKDIR rl_optimizer/ 37 | 38 | CMD ["/bin/bash"] 39 | -------------------------------------------------------------------------------- /setup/build_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default image name 4 | IMAGE_NAME="open" 5 | 6 | # Get current user info 7 | USER_NAME=$(id -un) 8 | USER_ID=$(id -u) 9 | GROUP_ID=$(id -g) 10 | 11 | # Check if WANDB_API key is provided 12 | if [ -z "$1" ]; then 13 | echo "Please provide WANDB_API_KEY key as argument" 14 | exit 1 15 | fi 16 | WANDB_API=$1 17 | 18 | # Build the Docker image with current user info 19 | docker build \ 20 | --build-arg USERNAME="${USER_NAME}" \ 21 | --build-arg USER_UID="${USER_ID}" \ 22 | --build-arg USER_GID="${GROUP_ID}" \ 23 | --build-arg WANDB_API="${WANDB_API}" \ 24 | -t ${IMAGE_NAME} \ 25 | . -------------------------------------------------------------------------------- /setup/requirements.txt: -------------------------------------------------------------------------------- 1 | brax>=0.9.0 2 | distrax @ git+https://github.com/google-deepmind/distrax # distrax release doesn't support jax > 0.4.13 3 | evosax @ git+https://github.com/RobertTLange/evosax.git 4 | flax 5 | gymnasium 6 | gymnax==0.0.6 # Later gymnax versions have bugs that prevent running certain envs 7 | jaxlib 8 | numpy 9 | optax @ git+https://github.com/google-deepmind/optax.git 10 | tqdm 11 | wandb 12 | gin-config 13 | seaborn --------------------------------------------------------------------------------
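A minimal usage sketch for the Docker setup above, assuming Docker with NVIDIA GPU
support is installed; the W&B API key and GPU indices are placeholders:

    # Build the "open" image from inside setup/, so the build context contains the
    # Dockerfile and requirements.txt
    cd setup && bash build_docker.sh <YOUR_WANDB_API_KEY> && cd ..

    # Launch a container on GPUs 0 and 1; the repository root is mounted at
    # /rl_optimizer and the working directory is /rl_optimizer/rl_optimizer
    bash run_docker.sh 0,1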