├── .gitignore ├── LICENSE ├── README.md ├── RESOURCES.md ├── docs ├── cartpole_plot_parallel.png ├── cartpole_plot_parallel_old.png ├── cartpole_plot_seconds.png ├── minatar_plot_parallel.png ├── minatar_plot_parallel_old.png └── minatar_plot_seconds.png ├── examples ├── brax_minatar.ipynb └── walkthrough.ipynb ├── purejaxrl ├── dpo_continuous_action.py ├── dqn.py ├── experimental │ └── s5 │ │ ├── README.md │ │ ├── ppo_s5.py │ │ ├── s5.py │ │ └── wrappers.py ├── ppo.py ├── ppo_continuous_action.py ├── ppo_minigrid.py ├── ppo_rnn.py └── wrappers.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Chris Lu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PureJaxRL (End-to-End RL Training in Pure Jax) 2 | 3 | [](https://github.com/luchris429/purejaxrl/LICENSE) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb) 6 | 7 | PureJaxRL is a high-performance, end-to-end Jax Reinforcement Learning (RL) implementation. When running many agents in parallel on GPUs, our implementation is over 1000x faster than standard PyTorch RL implementations. Unlike other Jax RL implementations, we implement the *entire training pipeline in JAX*, including the environment. This allows us to get significant speedups through JIT compilation and by avoiding CPU-GPU data transfer. It also results in easier debugging because the system is fully synchronous. More importantly, this code allows you to use jax to `jit`, `vmap`, `pmap`, and `scan` entire RL training pipelines. With this, we can: 8 | 9 | - 🏃 Efficiently run tons of seeds in parallel on one GPU 10 | - 💻 Perform rapid hyperparameter tuning 11 | - 🦎 Discover new RL algorithms with meta-evolution 12 | 13 | For more details, visit the accompanying blog post: https://chrislu.page/blog/meta-disco/ 14 | 15 | This notebook walks through the basic usage: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb) 16 | 17 | ## CHECK OUT [RESOURCES.MD](https://github.com/luchris429/purejaxrl/blob/main/RESOURCES.md) to see github repos that are part of the Jax RL Ecosystem! 18 | 19 | ## Performance 20 | 21 | Without vectorization, our implementation runs 10x faster than [CleanRL's PyTorch baselines](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py), as shown in the single-thread performance plot. 22 | 23 | Cartpole | Minatar-Breakout 24 | :-------------------------:|:-------------------------: 25 | ![](docs/cartpole_plot_seconds.png) | ![](docs/minatar_plot_seconds.png) 26 | 27 | 28 | With vectorized training, we can train 2048 PPO agents in half the time it takes to train a single PyTorch PPO agent on a single GPU. The vectorized agent training allows for simultaneous training across multiple seeds, rapid hyperparameter tuning, and even evolutionary Meta-RL. 29 | 30 | Vectorised Cartpole | Vectorised Minatar-Breakout 31 | :-------------------------:|:-------------------------: 32 | ![](docs/cartpole_plot_parallel.png) | ![](docs/minatar_plot_parallel.png) 33 | 34 | 35 | ## Code Philosophy 36 | 37 | PureJaxRL is inspired by [CleanRL](https://github.com/vwxyzjn/cleanrl), providing high-quality single-file implementations with research-friendly features. Like CleanRL, this is not a modular library and is not meant to be imported. The repository focuses on simplicity and clarity in its implementations, making it an excellent resource for researchers and practitioners. 38 | 39 | ## Installation 40 | 41 | Install dependencies using the requirements.txt file: 42 | 43 | ``` 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | In order to use JAX on your accelerators, you can find more details in the [JAX documentation](https://github.com/google/jax#installation). 
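
Once the dependencies are installed, the key pattern is that each script's `make_train(config)` returns a training function that is pure in the random key, so the *entire* training loop can be `jit`-compiled and `vmap`-ed over seeds. The sketch below mirrors the bottom of `purejaxrl/dqn.py` (the local import and the seed count are illustrative; the PPO scripts follow the same pattern with their own `config` dictionaries):

```python
import jax
from dqn import make_train  # illustrative: run from inside purejaxrl/ so the local import resolves

# Hyperparameters copied from dqn.py's main(), with NUM_SEEDS raised to train several agents at once.
config = {
    "NUM_ENVS": 10,
    "BUFFER_SIZE": 10000,
    "BUFFER_BATCH_SIZE": 128,
    "TOTAL_TIMESTEPS": 5e5,
    "EPSILON_START": 1.0,
    "EPSILON_FINISH": 0.05,
    "EPSILON_ANNEAL_TIME": 25e4,
    "TARGET_UPDATE_INTERVAL": 500,
    "LR": 2.5e-4,
    "LEARNING_STARTS": 10000,
    "TRAINING_INTERVAL": 10,
    "LR_LINEAR_DECAY": False,
    "GAMMA": 0.99,
    "TAU": 1.0,
    "ENV_NAME": "CartPole-v1",
    "SEED": 0,
    "NUM_SEEDS": 16,
    "WANDB_MODE": "disabled",
    "ENTITY": "",
    "PROJECT": "",
}

rngs = jax.random.split(jax.random.PRNGKey(config["SEED"]), config["NUM_SEEDS"])
train_vjit = jax.jit(jax.vmap(make_train(config)))  # vectorise the whole training loop over seeds
outs = jax.block_until_ready(train_vjit(rngs))      # one compiled call trains all NUM_SEEDS agents
```
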
48 |
49 | ## Example Usage
50 |
51 | [`examples/walkthrough.ipynb`](https://github.com/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb) walks through the basic usage. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb)
52 |
53 | [`examples/brax_minatar.ipynb`](https://github.com/luchris429/purejaxrl/blob/main/examples/brax_minatar.ipynb) walks through using PureJaxRL for Brax and MinAtar. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/brax_minatar.ipynb)
54 |
55 | ## Related Work
56 |
57 | Check out the list of [RESOURCES](https://github.com/luchris429/purejaxrl/blob/main/RESOURCES.md) to see libraries that are closely related to PureJaxRL!
58 |
59 | The following repositories and projects were precursors to `purejaxrl`:
60 |
61 | - [Model-Free Opponent Shaping](https://arxiv.org/abs/2205.01447) (ICML 2022) (https://github.com/luchris429/Model-Free-Opponent-Shaping)
62 |
63 | - [Discovered Policy Optimisation](https://arxiv.org/abs/2210.05639) (NeurIPS 2022) (https://github.com/luchris429/discovered-policy-optimisation)
64 |
65 | - [Adversarial Cheap Talk](https://arxiv.org/abs/2211.11030) (ICML 2023) (https://github.com/luchris429/adversarial-cheap-talk)
66 |
67 | ## Citation
68 |
69 | If you use PureJaxRL in your work, please cite the following paper:
70 |
71 | ```
72 | @article{lu2022discovered,
73 | title={Discovered policy optimisation},
74 | author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
75 | journal={Advances in Neural Information Processing Systems},
76 | volume={35},
77 | pages={16455--16468},
78 | year={2022}
79 | }
80 | ```
81 |
-------------------------------------------------------------------------------- /RESOURCES.md: --------------------------------------------------------------------------------
1 | # PureJaxRL Resources
2 |
3 | Last year, I released [PureJaxRL](https://github.com/luchris429/purejaxrl), a simple repository that implements RL algorithms entirely end-to-end in JAX, which enables speedups of up to 4000x in RL training. PureJaxRL, in turn, was inspired by multiple projects, including [CleanRL](https://github.com/vwxyzjn/cleanrl) and [Gymnax](https://github.com/RobertTLange/gymnax). Since the release of PureJaxRL, a large number of projects related to or inspired by PureJaxRL have come out, vastly expanding its use cases beyond standard single-agent RL settings. This curated list contains those projects alongside other relevant implementations of algorithms, environments, tools, and tutorials.
4 |
5 | To understand more about the benefits of PureJaxRL, I recommend viewing the [original blog post](https://chrislu.page/blog/meta-disco/) or [tweet thread](https://x.com/_chris_lu_/status/1643992216413831171).
6 |
7 | The PureJaxRL repository can be found here:
8 |
9 | [https://github.com/luchris429/purejaxrl/](https://github.com/luchris429/purejaxrl/).
10 |
11 | The format of the list is from [awesome](https://github.com/sindresorhus/awesome) and [awesome-jax](https://github.com/n2cholas/awesome-jax). While this list is curated, it is certainly not complete. If you have a repository you would like to add, please contribute!
12 |
13 | If you find this resource useful, please *star* the repo! It helps establish and grow the end-to-end JAX RL community.
14 | 15 | ## Contents 16 | 17 | - [Algorithms](#algorithms) 18 | - [Environments](#environments) 19 | - [Related Components](#components) 20 | - [Tutorials and Blog Posts](#tutorials-and-blog-posts) 21 | - [Related Papers](#papers) 22 | 23 | ## Algorithms 24 | 25 | ### End-to-End JAX RL Implementations 26 | 27 | - [purejaxrl](https://github.com/luchris429/purejaxrl) - Classic and simple end-to-end RL training in pure JAX. 28 | 29 | - [rejax](https://github.com/keraJLi/rejax) - Modular and importable end-to-end JAX RL training. 30 | 31 | - [Stoix](https://github.com/EdanToledo/Stoix) - End-to-end JAX RL training with advanced logging, configs, and more. 32 | 33 | - [purejaxql](https://github.com/mttga/purejaxql/) - Simple single-file end-to-end JAX baselines for Q-Learning. 34 | 35 | - [jym](https://github.com/rpegoud/jym) - Educational and beginner-friendly end-to-end JAX RL training. 36 | 37 | ### Jax RL (But Not End-to-End) Repos 38 | 39 | - [cleanrl](https://github.com/vwxyzjn/cleanrl) - Clean implementations of RL Algorithms (in both PyTorch and JAX!). 40 | 41 | - [jaxrl](https://github.com/ikostrikov/jaxrl) - JAX implementation of algorithms for Deep Reinforcement Learning with continuous action spaces. 42 | 43 | - [rlbase](https://github.com/kvfrans/rlbase_stable) - Single-file JAX implementations of Deep RL algorithms. 44 | 45 | ### Multi-Agent RL 46 | 47 | - [JaxMARL](https://github.com/FLAIROx/JaxMARL) - Multi-Agent RL Algorithms and Environments in pure JAX. 48 | 49 | - [Mava](https://github.com/instadeepai/Mava) - Multi-Agent RL Algorithms in pure JAX (previously tensorflow-based algorithms). 50 | 51 | - [pax](https://github.com/ucl-dark/pax) - Scalable Opponent Shaping Algorithms in pure JAX. 52 | 53 | ### Offline RL 54 | 55 | - [JAX-CORL](https://github.com/nissymori/JAX-CORL) - Single-file implementations of offline RL algorithms in JAX. 56 | 57 | ### Inverse-RL 58 | 59 | - [jaxirl](https://github.com/FLAIROx/jaxirl) - Pure JAX for Inverse Reinforcement Learning. 60 | 61 | ### Unsupervised Environment Design 62 | 63 | - [minimax](https://github.com/facebookresearch/minimax) - Canonical implementations of UED algorithms in pure JAX, including SSM-based acceleration. 64 | 65 | - [jaxued](https://github.com/DramaCow/jaxued) - Single-file implementations of UED algorithms in pure JAX. 66 | 67 | ### Quality-Diversity 68 | 69 | - [QDax](https://github.com/adaptive-intelligent-robotics/QDax) - Quality-Diversity algorithms in pure JAX. 70 | 71 | ### Partially-Observed RL 72 | 73 | - [popjaxrl](https://github.com/luchris429/popjaxrl) - Partially-observed RL environments (POPGym) and architectures (incl. SSM's) in pure JAX. 74 | 75 | ### Meta-Learning RL Objectives 76 | 77 | - [groove](https://github.com/EmptyJackson/groove) - Library for [LPG-like](https://arxiv.org/abs/2007.08794) meta-RL in Pure JAX. 78 | 79 | - [discovered-policy-optimisation](https://github.com/luchris429/discovered-policy-optimisation) - Library for [LPO](https://arxiv.org/abs/2210.05639) meta-RL in Pure JAX. 80 | 81 | - [rl-learned-optimization](https://github.com/AlexGoldie/rl-learned-optimization) - Library for [OPEN](https://arxiv.org/abs/2407.07082) in Pure JAX. 82 | 83 | ## Environments 84 | 85 | - [gymnax](https://github.com/RobertTLange/gymnax) - Classic RL environments in JAX. 86 | 87 | - [brax](https://github.com/google/brax) - Continuous control environments in JAX. 88 | 89 | - [JaxMARL](https://github.com/FLAIROx/JaxMARL) - Multi-agent algorithms and environments in pure JAX. 
90 |
91 | - [jumanji](https://github.com/instadeepai/jumanji) - Suite of unique RL environments in JAX.
92 |
93 | - [pgx](https://github.com/sotetsuk/pgx) - Suite of popular board games in JAX.
94 |
95 | - [popjaxrl](https://github.com/luchris429/popjaxrl) - Partially-observed RL environments (POPGym) in JAX.
96 |
97 | - [waymax](https://github.com/waymo-research/waymax) - Self-driving car simulator in JAX.
98 |
99 | - [Craftax](https://github.com/MichaelTMatthews/Craftax) - A challenging Crafter-like and NetHack-inspired benchmark in JAX.
100 |
101 | - [xland-minigrid](https://github.com/corl-team/xland-minigrid) - A large-scale meta-RL environment in JAX.
102 |
103 | - [navix](https://github.com/epignatelli/navix) - Classic minigrid environments in JAX.
104 |
105 | - [autoverse](https://github.com/smearle/autoverse) - A fast, evolvable description language for reinforcement learning environments.
106 |
107 | - [qdx](https://github.com/jolle-ag/qdx) - Quantum Error Correction with JAX.
108 |
109 | - [matrax](https://github.com/instadeepai/matrax) - Matrix games in JAX.
110 |
111 | - [AlphaTrade](https://github.com/KangOxford/AlphaTrade) - Limit Order Book (LOB) in JAX.
112 |
113 | ## Relevant Tools and Components
114 |
115 | - [evosax](https://github.com/RobertTLange/evosax) - Evolution strategies in JAX.
116 |
117 | - [evojax](https://github.com/google/evojax) - Evolution strategies in JAX.
118 |
119 | - [flashbax](https://github.com/instadeepai/flashbax) - Accelerated replay buffers in JAX.
120 |
121 | - [dejax](https://github.com/hr0nix/dejax) - Accelerated replay buffers in JAX.
122 |
123 | - [rlax](https://github.com/google-deepmind/rlax) - RL components and building blocks in JAX.
124 |
125 | - [mctx](https://github.com/google-deepmind/mctx) - Monte Carlo tree search in JAX.
126 |
127 | - [distrax](https://github.com/google-deepmind/distrax) - Distributions in JAX.
128 |
129 | - [optax](https://github.com/google-deepmind/optax) - Gradient-based optimizers in JAX.
130 |
131 | - [flax](https://github.com/google/flax) - Neural networks in JAX.
132 |
133 | ## Tutorials and Blog Posts
134 |
135 | - [Achieving 4000x Speedups with PureJaxRL](https://chrislu.page/blog/meta-disco/) - A blog post on how JAX can massively speed up RL training through vectorisation.
136 |
137 | - [Breaking down State-of-the-Art PPO Implementations in JAX](https://towardsdatascience.com/breaking-down-state-of-the-art-ppo-implementations-in-jax-6f102c06c149) - A blog post explaining PureJaxRL's PPO implementation in depth.
138 |
139 | - [A Gentle Introduction to Deep Reinforcement Learning in JAX](https://towardsdatascience.com/a-gentle-introduction-to-deep-reinforcement-learning-in-jax-c1e45a179b92) - A JAX tutorial on Deep RL.
140 |
141 | - [Writing an RL Environment in JAX](https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba) - A JAX tutorial on making environments.
142 |
143 | - [Getting started with JAX (MLPs, CNNs & RNNs)](https://roberttlange.com/posts/2020/03/blog-post-10/) - A basic JAX neural network tutorial.
144 | 145 | - [awesome-jax](https://github.com/n2cholas/awesome-jax) - A list of useful libraries in JAX 146 | -------------------------------------------------------------------------------- /docs/cartpole_plot_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_parallel.png -------------------------------------------------------------------------------- /docs/cartpole_plot_parallel_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_parallel_old.png -------------------------------------------------------------------------------- /docs/cartpole_plot_seconds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_seconds.png -------------------------------------------------------------------------------- /docs/minatar_plot_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_parallel.png -------------------------------------------------------------------------------- /docs/minatar_plot_parallel_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_parallel_old.png -------------------------------------------------------------------------------- /docs/minatar_plot_seconds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_seconds.png -------------------------------------------------------------------------------- /purejaxrl/dpo_continuous_action.py: -------------------------------------------------------------------------------- 1 | """Re-implementation of Discovered Policy Optimisation (DPO) 2 | 3 | https://arxiv.org/abs/2210.05639 4 | 5 | This differs from PPO in just a few lines of the policy objective. 6 | 7 | Please refer to the paper for more details. 
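
Concretely (matching `_loss_fn` below): with probability ratio r = exp(log_prob - old_log_prob)
and normalised advantage A, the clipped PPO surrogate is replaced by a "drift" penalty

    drift = relu((r - 1) * A - alpha * tanh((r - 1) * A / alpha))   if A >= 0
    drift = relu(log(r) * A - beta * tanh(log(r) * A / beta))       if A < 0

and the actor loss is -(r * A - drift).mean(), with alpha = DPO_ALPHA and beta = DPO_BETA from the config.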
8 | """ 9 | import jax 10 | import jax.numpy as jnp 11 | import flax.linen as nn 12 | import numpy as np 13 | import optax 14 | from flax.linen.initializers import constant, orthogonal 15 | from typing import Sequence, NamedTuple, Any 16 | from flax.training.train_state import TrainState 17 | import distrax 18 | from wrappers import ( 19 | LogWrapper, 20 | BraxGymnaxWrapper, 21 | VecEnv, 22 | NormalizeVecObservation, 23 | NormalizeVecReward, 24 | ClipAction, 25 | ) 26 | 27 | 28 | class ActorCritic(nn.Module): 29 | action_dim: Sequence[int] 30 | activation: str = "tanh" 31 | 32 | @nn.compact 33 | def __call__(self, x): 34 | if self.activation == "relu": 35 | activation = nn.relu 36 | else: 37 | activation = nn.tanh 38 | actor_mean = nn.Dense( 39 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 40 | )(x) 41 | actor_mean = activation(actor_mean) 42 | actor_mean = nn.Dense( 43 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 44 | )(actor_mean) 45 | actor_mean = activation(actor_mean) 46 | actor_mean = nn.Dense( 47 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 48 | )(actor_mean) 49 | actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) 50 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) 51 | 52 | critic = nn.Dense( 53 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 54 | )(x) 55 | critic = activation(critic) 56 | critic = nn.Dense( 57 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 58 | )(critic) 59 | critic = activation(critic) 60 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 61 | critic 62 | ) 63 | 64 | return pi, jnp.squeeze(critic, axis=-1) 65 | 66 | 67 | class Transition(NamedTuple): 68 | done: jnp.ndarray 69 | action: jnp.ndarray 70 | value: jnp.ndarray 71 | reward: jnp.ndarray 72 | log_prob: jnp.ndarray 73 | obs: jnp.ndarray 74 | info: jnp.ndarray 75 | 76 | 77 | def make_train(config): 78 | config["NUM_UPDATES"] = ( 79 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 80 | ) 81 | config["MINIBATCH_SIZE"] = ( 82 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 83 | ) 84 | env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None 85 | env = LogWrapper(env) 86 | env = ClipAction(env) 87 | env = VecEnv(env) 88 | if config["NORMALIZE_ENV"]: 89 | env = NormalizeVecObservation(env) 90 | env = NormalizeVecReward(env, config["GAMMA"]) 91 | 92 | def linear_schedule(count): 93 | frac = ( 94 | 1.0 95 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 96 | / config["NUM_UPDATES"] 97 | ) 98 | return config["LR"] * frac 99 | 100 | def train(rng): 101 | # INIT NETWORK 102 | network = ActorCritic( 103 | env.action_space(env_params).shape[0], activation=config["ACTIVATION"] 104 | ) 105 | rng, _rng = jax.random.split(rng) 106 | init_x = jnp.zeros(env.observation_space(env_params).shape) 107 | network_params = network.init(_rng, init_x) 108 | if config["ANNEAL_LR"]: 109 | tx = optax.chain( 110 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 111 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 112 | ) 113 | else: 114 | tx = optax.chain( 115 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 116 | optax.adam(config["LR"], eps=1e-5), 117 | ) 118 | train_state = TrainState.create( 119 | apply_fn=network.apply, 120 | params=network_params, 121 | tx=tx, 122 | ) 123 | 124 | # INIT ENV 125 | rng, _rng = jax.random.split(rng) 126 | reset_rng = 
jax.random.split(_rng, config["NUM_ENVS"]) 127 | obsv, env_state = env.reset(reset_rng, env_params) 128 | 129 | # TRAIN LOOP 130 | def _update_step(runner_state, unused): 131 | # COLLECT TRAJECTORIES 132 | def _env_step(runner_state, unused): 133 | train_state, env_state, last_obs, rng = runner_state 134 | 135 | # SELECT ACTION 136 | rng, _rng = jax.random.split(rng) 137 | pi, value = network.apply(train_state.params, last_obs) 138 | action = pi.sample(seed=_rng) 139 | log_prob = pi.log_prob(action) 140 | 141 | # STEP ENV 142 | rng, _rng = jax.random.split(rng) 143 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 144 | obsv, env_state, reward, done, info = env.step( 145 | rng_step, env_state, action, env_params 146 | ) 147 | transition = Transition( 148 | done, action, value, reward, log_prob, last_obs, info 149 | ) 150 | runner_state = (train_state, env_state, obsv, rng) 151 | return runner_state, transition 152 | 153 | runner_state, traj_batch = jax.lax.scan( 154 | _env_step, runner_state, None, config["NUM_STEPS"] 155 | ) 156 | 157 | # CALCULATE ADVANTAGE 158 | train_state, env_state, last_obs, rng = runner_state 159 | _, last_val = network.apply(train_state.params, last_obs) 160 | 161 | def _calculate_gae(traj_batch, last_val): 162 | def _get_advantages(gae_and_next_value, transition): 163 | gae, next_value = gae_and_next_value 164 | done, value, reward = ( 165 | transition.done, 166 | transition.value, 167 | transition.reward, 168 | ) 169 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value 170 | gae = ( 171 | delta 172 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae 173 | ) 174 | return (gae, value), gae 175 | 176 | _, advantages = jax.lax.scan( 177 | _get_advantages, 178 | (jnp.zeros_like(last_val), last_val), 179 | traj_batch, 180 | reverse=True, 181 | unroll=16, 182 | ) 183 | return advantages, advantages + traj_batch.value 184 | 185 | advantages, targets = _calculate_gae(traj_batch, last_val) 186 | 187 | # UPDATE NETWORK 188 | def _update_epoch(update_state, unused): 189 | def _update_minbatch(train_state, batch_info): 190 | traj_batch, advantages, targets = batch_info 191 | 192 | def _loss_fn(params, traj_batch, gae, targets): 193 | # RERUN NETWORK 194 | pi, value = network.apply(params, traj_batch.obs) 195 | log_prob = pi.log_prob(traj_batch.action) 196 | 197 | # CALCULATE VALUE LOSS 198 | value_pred_clipped = traj_batch.value + ( 199 | value - traj_batch.value 200 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 201 | value_losses = jnp.square(value - targets) 202 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 203 | value_loss = ( 204 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 205 | ) 206 | 207 | # CALCULATE ACTOR LOSS 208 | alpha = config["DPO_ALPHA"] 209 | beta = config["DPO_BETA"] 210 | log_diff = log_prob - traj_batch.log_prob 211 | ratio = jnp.exp(log_diff) 212 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 213 | is_pos = (gae >= 0.0).astype("float32") 214 | r1 = ratio - 1.0 215 | drift1 = nn.relu(r1 * gae - alpha * nn.tanh(r1 * gae / alpha)) 216 | drift2 = nn.relu( 217 | log_diff * gae - beta * nn.tanh(log_diff * gae / beta) 218 | ) 219 | drift = drift1 * is_pos + drift2 * (1 - is_pos) 220 | loss_actor = -(ratio * gae - drift).mean() 221 | entropy = pi.entropy().mean() 222 | 223 | total_loss = ( 224 | loss_actor 225 | + config["VF_COEF"] * value_loss 226 | - config["ENT_COEF"] * entropy 227 | ) 228 | return total_loss, (value_loss, loss_actor, entropy) 229 | 230 | grad_fn = 
jax.value_and_grad(_loss_fn, has_aux=True) 231 | total_loss, grads = grad_fn( 232 | train_state.params, traj_batch, advantages, targets 233 | ) 234 | train_state = train_state.apply_gradients(grads=grads) 235 | return train_state, total_loss 236 | 237 | train_state, traj_batch, advantages, targets, rng = update_state 238 | rng, _rng = jax.random.split(rng) 239 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] 240 | assert ( 241 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] 242 | ), "batch size must be equal to number of steps * number of envs" 243 | permutation = jax.random.permutation(_rng, batch_size) 244 | batch = (traj_batch, advantages, targets) 245 | batch = jax.tree_util.tree_map( 246 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch 247 | ) 248 | shuffled_batch = jax.tree_util.tree_map( 249 | lambda x: jnp.take(x, permutation, axis=0), batch 250 | ) 251 | minibatches = jax.tree_util.tree_map( 252 | lambda x: jnp.reshape( 253 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) 254 | ), 255 | shuffled_batch, 256 | ) 257 | train_state, total_loss = jax.lax.scan( 258 | _update_minbatch, train_state, minibatches 259 | ) 260 | update_state = (train_state, traj_batch, advantages, targets, rng) 261 | return update_state, total_loss 262 | 263 | update_state = (train_state, traj_batch, advantages, targets, rng) 264 | update_state, loss_info = jax.lax.scan( 265 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 266 | ) 267 | train_state = update_state[0] 268 | metric = traj_batch.info 269 | rng = update_state[-1] 270 | if config.get("DEBUG"): 271 | 272 | def callback(info): 273 | return_values = info["returned_episode_returns"][ 274 | info["returned_episode"] 275 | ] 276 | timesteps = ( 277 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 278 | ) 279 | for t in range(len(timesteps)): 280 | print( 281 | f"global step={timesteps[t]}, episodic return={return_values[t]}" 282 | ) 283 | 284 | jax.debug.callback(callback, metric) 285 | 286 | runner_state = (train_state, env_state, last_obs, rng) 287 | return runner_state, metric 288 | 289 | rng, _rng = jax.random.split(rng) 290 | runner_state = (train_state, env_state, obsv, _rng) 291 | runner_state, metric = jax.lax.scan( 292 | _update_step, runner_state, None, config["NUM_UPDATES"] 293 | ) 294 | return {"runner_state": runner_state, "metrics": metric} 295 | 296 | return train 297 | 298 | 299 | if __name__ == "__main__": 300 | config = { 301 | "LR": 3e-4, 302 | "NUM_ENVS": 2048, 303 | "NUM_STEPS": 10, 304 | "TOTAL_TIMESTEPS": 5e7, 305 | "UPDATE_EPOCHS": 4, 306 | "NUM_MINIBATCHES": 32, 307 | "GAMMA": 0.99, 308 | "GAE_LAMBDA": 0.95, 309 | "CLIP_EPS": 0.2, 310 | "DPO_ALPHA": 2.0, 311 | "DPO_BETA": 0.6, 312 | "ENT_COEF": 0.0, 313 | "VF_COEF": 0.5, 314 | "MAX_GRAD_NORM": 0.5, 315 | "ACTIVATION": "tanh", 316 | "ENV_NAME": "hopper", 317 | "ANNEAL_LR": False, 318 | "NORMALIZE_ENV": True, 319 | "DEBUG": True, 320 | } 321 | rng = jax.random.PRNGKey(30) 322 | train_jit = jax.jit(make_train(config)) 323 | out = train_jit(rng) 324 | -------------------------------------------------------------------------------- /purejaxrl/dqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | PureJaxRL version of CleanRL's DQN: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py 3 | """ 4 | import os 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import chex 9 | import flax 10 | import wandb 11 | import optax 12 | import flax.linen as nn 13 | 
from flax.training.train_state import TrainState 14 | from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper 15 | import gymnax 16 | import flashbax as fbx 17 | 18 | 19 | class QNetwork(nn.Module): 20 | action_dim: int 21 | 22 | @nn.compact 23 | def __call__(self, x: jnp.ndarray): 24 | x = nn.Dense(120)(x) 25 | x = nn.relu(x) 26 | x = nn.Dense(84)(x) 27 | x = nn.relu(x) 28 | x = nn.Dense(self.action_dim)(x) 29 | return x 30 | 31 | 32 | @chex.dataclass(frozen=True) 33 | class TimeStep: 34 | obs: chex.Array 35 | action: chex.Array 36 | reward: chex.Array 37 | done: chex.Array 38 | 39 | 40 | class CustomTrainState(TrainState): 41 | target_network_params: flax.core.FrozenDict 42 | timesteps: int 43 | n_updates: int 44 | 45 | 46 | def make_train(config): 47 | 48 | config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"] 49 | 50 | basic_env, env_params = gymnax.make(config["ENV_NAME"]) 51 | env = FlattenObservationWrapper(basic_env) 52 | env = LogWrapper(env) 53 | 54 | vmap_reset = lambda n_envs: lambda rng: jax.vmap(env.reset, in_axes=(0, None))( 55 | jax.random.split(rng, n_envs), env_params 56 | ) 57 | vmap_step = lambda n_envs: lambda rng, env_state, action: jax.vmap( 58 | env.step, in_axes=(0, 0, 0, None) 59 | )(jax.random.split(rng, n_envs), env_state, action, env_params) 60 | 61 | def train(rng): 62 | 63 | # INIT ENV 64 | rng, _rng = jax.random.split(rng) 65 | init_obs, env_state = vmap_reset(config["NUM_ENVS"])(_rng) 66 | 67 | # INIT BUFFER 68 | buffer = fbx.make_flat_buffer( 69 | max_length=config["BUFFER_SIZE"], 70 | min_length=config["BUFFER_BATCH_SIZE"], 71 | sample_batch_size=config["BUFFER_BATCH_SIZE"], 72 | add_sequences=False, 73 | add_batch_size=config["NUM_ENVS"], 74 | ) 75 | buffer = buffer.replace( 76 | init=jax.jit(buffer.init), 77 | add=jax.jit(buffer.add, donate_argnums=0), 78 | sample=jax.jit(buffer.sample), 79 | can_sample=jax.jit(buffer.can_sample), 80 | ) 81 | rng = jax.random.PRNGKey(0) # use a dummy rng here 82 | _action = basic_env.action_space().sample(rng) 83 | _, _env_state = env.reset(rng, env_params) 84 | _obs, _, _reward, _done, _ = env.step(rng, _env_state, _action, env_params) 85 | _timestep = TimeStep(obs=_obs, action=_action, reward=_reward, done=_done) 86 | buffer_state = buffer.init(_timestep) 87 | 88 | # INIT NETWORK AND OPTIMIZER 89 | network = QNetwork(action_dim=env.action_space(env_params).n) 90 | rng, _rng = jax.random.split(rng) 91 | init_x = jnp.zeros(env.observation_space(env_params).shape) 92 | network_params = network.init(_rng, init_x) 93 | 94 | def linear_schedule(count): 95 | frac = 1.0 - (count / config["NUM_UPDATES"]) 96 | return config["LR"] * frac 97 | 98 | lr = linear_schedule if config.get("LR_LINEAR_DECAY", False) else config["LR"] 99 | tx = optax.adam(learning_rate=lr) 100 | 101 | train_state = CustomTrainState.create( 102 | apply_fn=network.apply, 103 | params=network_params, 104 | target_network_params=jax.tree_map(lambda x: jnp.copy(x), network_params), 105 | tx=tx, 106 | timesteps=0, 107 | n_updates=0, 108 | ) 109 | 110 | # epsilon-greedy exploration 111 | def eps_greedy_exploration(rng, q_vals, t): 112 | rng_a, rng_e = jax.random.split( 113 | rng, 2 114 | ) # a key for sampling random actions and one for picking 115 | eps = jnp.clip( # get epsilon 116 | ( 117 | (config["EPSILON_FINISH"] - config["EPSILON_START"]) 118 | / config["EPSILON_ANNEAL_TIME"] 119 | ) 120 | * t 121 | + config["EPSILON_START"], 122 | config["EPSILON_FINISH"], 123 | ) 124 | greedy_actions = jnp.argmax(q_vals, axis=-1) # get 
the greedy actions 125 | chosed_actions = jnp.where( 126 | jax.random.uniform(rng_e, greedy_actions.shape) 127 | < eps, # pick the actions that should be random 128 | jax.random.randint( 129 | rng_a, shape=greedy_actions.shape, minval=0, maxval=q_vals.shape[-1] 130 | ), # sample random actions, 131 | greedy_actions, 132 | ) 133 | return chosed_actions 134 | 135 | # TRAINING LOOP 136 | def _update_step(runner_state, unused): 137 | 138 | train_state, buffer_state, env_state, last_obs, rng = runner_state 139 | 140 | # STEP THE ENV 141 | rng, rng_a, rng_s = jax.random.split(rng, 3) 142 | q_vals = network.apply(train_state.params, last_obs) 143 | action = eps_greedy_exploration( 144 | rng_a, q_vals, train_state.timesteps 145 | ) # explore with epsilon greedy_exploration 146 | obs, env_state, reward, done, info = vmap_step(config["NUM_ENVS"])( 147 | rng_s, env_state, action 148 | ) 149 | train_state = train_state.replace( 150 | timesteps=train_state.timesteps + config["NUM_ENVS"] 151 | ) # update timesteps count 152 | 153 | # BUFFER UPDATE 154 | timestep = TimeStep(obs=last_obs, action=action, reward=reward, done=done) 155 | buffer_state = buffer.add(buffer_state, timestep) 156 | 157 | # NETWORKS UPDATE 158 | def _learn_phase(train_state, rng): 159 | 160 | learn_batch = buffer.sample(buffer_state, rng).experience 161 | 162 | q_next_target = network.apply( 163 | train_state.target_network_params, learn_batch.second.obs 164 | ) # (batch_size, num_actions) 165 | q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,) 166 | target = ( 167 | learn_batch.first.reward 168 | + (1 - learn_batch.first.done) * config["GAMMA"] * q_next_target 169 | ) 170 | 171 | def _loss_fn(params): 172 | q_vals = network.apply( 173 | params, learn_batch.first.obs 174 | ) # (batch_size, num_actions) 175 | chosen_action_qvals = jnp.take_along_axis( 176 | q_vals, 177 | jnp.expand_dims(learn_batch.first.action, axis=-1), 178 | axis=-1, 179 | ).squeeze(axis=-1) 180 | return jnp.mean((chosen_action_qvals - target) ** 2) 181 | 182 | loss, grads = jax.value_and_grad(_loss_fn)(train_state.params) 183 | train_state = train_state.apply_gradients(grads=grads) 184 | train_state = train_state.replace(n_updates=train_state.n_updates + 1) 185 | return train_state, loss 186 | 187 | rng, _rng = jax.random.split(rng) 188 | is_learn_time = ( 189 | (buffer.can_sample(buffer_state)) 190 | & ( # enough experience in buffer 191 | train_state.timesteps > config["LEARNING_STARTS"] 192 | ) 193 | & ( # pure exploration phase ended 194 | train_state.timesteps % config["TRAINING_INTERVAL"] == 0 195 | ) # training interval 196 | ) 197 | train_state, loss = jax.lax.cond( 198 | is_learn_time, 199 | lambda train_state, rng: _learn_phase(train_state, rng), 200 | lambda train_state, rng: (train_state, jnp.array(0.0)), # do nothing 201 | train_state, 202 | _rng, 203 | ) 204 | 205 | # update target network 206 | train_state = jax.lax.cond( 207 | train_state.timesteps % config["TARGET_UPDATE_INTERVAL"] == 0, 208 | lambda train_state: train_state.replace( 209 | target_network_params=optax.incremental_update( 210 | train_state.params, 211 | train_state.target_network_params, 212 | config["TAU"], 213 | ) 214 | ), 215 | lambda train_state: train_state, 216 | operand=train_state, 217 | ) 218 | 219 | metrics = { 220 | "timesteps": train_state.timesteps, 221 | "updates": train_state.n_updates, 222 | "loss": loss.mean(), 223 | "returns": info["returned_episode_returns"].mean(), 224 | } 225 | 226 | # report on wandb if required 227 | if config.get("WANDB_MODE", 
"disabled") == "online": 228 | 229 | def callback(metrics): 230 | if metrics["timesteps"] % 100 == 0: 231 | wandb.log(metrics) 232 | 233 | jax.debug.callback(callback, metrics) 234 | 235 | runner_state = (train_state, buffer_state, env_state, obs, rng) 236 | 237 | return runner_state, metrics 238 | 239 | # train 240 | rng, _rng = jax.random.split(rng) 241 | runner_state = (train_state, buffer_state, env_state, init_obs, _rng) 242 | 243 | runner_state, metrics = jax.lax.scan( 244 | _update_step, runner_state, None, config["NUM_UPDATES"] 245 | ) 246 | return {"runner_state": runner_state, "metrics": metrics} 247 | 248 | return train 249 | 250 | 251 | def main(): 252 | 253 | config = { 254 | "NUM_ENVS": 10, 255 | "BUFFER_SIZE": 10000, 256 | "BUFFER_BATCH_SIZE": 128, 257 | "TOTAL_TIMESTEPS": 5e5, 258 | "EPSILON_START": 1.0, 259 | "EPSILON_FINISH": 0.05, 260 | "EPSILON_ANNEAL_TIME": 25e4, 261 | "TARGET_UPDATE_INTERVAL": 500, 262 | "LR": 2.5e-4, 263 | "LEARNING_STARTS": 10000, 264 | "TRAINING_INTERVAL": 10, 265 | "LR_LINEAR_DECAY": False, 266 | "GAMMA": 0.99, 267 | "TAU": 1.0, 268 | "ENV_NAME": "CartPole-v1", 269 | "SEED": 0, 270 | "NUM_SEEDS": 1, 271 | "WANDB_MODE": "disabled", # set to online to activate wandb 272 | "ENTITY": "", 273 | "PROJECT": "", 274 | } 275 | 276 | wandb.init( 277 | entity=config["ENTITY"], 278 | project=config["PROJECT"], 279 | tags=["DQN", config["ENV_NAME"].upper(), f"jax_{jax.__version__}"], 280 | name=f'purejaxrl_dqn_{config["ENV_NAME"]}', 281 | config=config, 282 | mode=config["WANDB_MODE"], 283 | ) 284 | 285 | rng = jax.random.PRNGKey(config["SEED"]) 286 | rngs = jax.random.split(rng, config["NUM_SEEDS"]) 287 | train_vjit = jax.jit(jax.vmap(make_train(config))) 288 | outs = jax.block_until_ready(train_vjit(rngs)) 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /purejaxrl/experimental/s5/README.md: -------------------------------------------------------------------------------- 1 | # PPO S5 2 | 3 | This is a re-implementation of the architecture from [this paper](https://arxiv.org/abs/2303.03982). 4 | 5 | This is currently a work-in-progress since the code in its current state needs to be cleaned. 
6 | 7 | If you use this code in an academic paper, please cite: 8 | 9 | 10 | ``` 11 | @article{lu2023structured, 12 | title={Structured State Space Models for In-Context Reinforcement Learning}, 13 | author={Lu, Chris and Schroecker, Yannick and Gu, Albert and Parisotto, Emilio and Foerster, Jakob and Singh, Satinder and Behbahani, Feryal}, 14 | journal={arXiv preprint arXiv:2303.03982}, 15 | year={2023} 16 | } 17 | 18 | @article{smith2022simplified, 19 | title={Simplified state space layers for sequence modeling}, 20 | author={Smith, Jimmy TH and Warrington, Andrew and Linderman, Scott W}, 21 | journal={arXiv preprint arXiv:2208.04933}, 22 | year={2022} 23 | } 24 | 25 | @article{lu2022discovered, 26 | title={Discovered policy optimisation}, 27 | author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob}, 28 | journal={Advances in Neural Information Processing Systems}, 29 | volume={35}, 30 | pages={16455--16468}, 31 | year={2022} 32 | } 33 | ``` -------------------------------------------------------------------------------- /purejaxrl/experimental/s5/ppo_s5.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import numpy as np 5 | import optax 6 | from flax.linen.initializers import constant, orthogonal 7 | from typing import Sequence, NamedTuple, Any, Dict 8 | from flax.training.train_state import TrainState 9 | import distrax 10 | import gymnax 11 | from wrappers import FlattenObservationWrapper, LogWrapper 12 | from gymnax.environments import spaces 13 | from s5 import init_S5SSM, make_DPLR_HiPPO, StackedEncoderModel 14 | 15 | d_model = 256 16 | ssm_size = 256 17 | C_init = "lecun_normal" 18 | discretization="zoh" 19 | dt_min=0.001 20 | dt_max=0.1 21 | n_layers = 4 22 | conj_sym=True 23 | clip_eigs=False 24 | bidirectional=False 25 | 26 | blocks = 1 27 | block_size = int(ssm_size / blocks) 28 | 29 | Lambda, _, B, V, B_orig = make_DPLR_HiPPO(ssm_size) 30 | 31 | block_size = block_size // 2 32 | ssm_size = ssm_size // 2 33 | 34 | Lambda = Lambda[:block_size] 35 | V = V[:, :block_size] 36 | 37 | Vinv = V.conj().T 38 | 39 | 40 | ssm_init_fn = init_S5SSM(H=d_model, 41 | P=ssm_size, 42 | Lambda_re_init=Lambda.real, 43 | Lambda_im_init=Lambda.imag, 44 | V=V, 45 | Vinv=Vinv, 46 | C_init=C_init, 47 | discretization=discretization, 48 | dt_min=dt_min, 49 | dt_max=dt_max, 50 | conj_sym=conj_sym, 51 | clip_eigs=clip_eigs, 52 | bidirectional=bidirectional) 53 | 54 | class ActorCriticS5(nn.Module): 55 | action_dim: Sequence[int] 56 | config: Dict 57 | 58 | def setup(self): 59 | self.encoder_0 = nn.Dense(128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)) 60 | self.encoder_1 = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)) 61 | 62 | self.action_body_0 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0)) 63 | self.action_body_1 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0)) 64 | self.action_decoder = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)) 65 | 66 | self.value_body_0 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0)) 67 | self.value_body_1 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0)) 68 | self.value_decoder = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)) 69 | 70 | self.s5 = StackedEncoderModel( 71 | ssm=ssm_init_fn, 72 | d_model=d_model, 73 | 
n_layers=n_layers, 74 | activation="half_glu1", 75 | ) 76 | 77 | def __call__(self, hidden, x): 78 | obs, dones = x 79 | embedding = self.encoder_0(obs) 80 | embedding = nn.leaky_relu(embedding) 81 | embedding = self.encoder_1(embedding) 82 | embedding = nn.leaky_relu(embedding) 83 | 84 | hidden, embedding = self.s5(hidden, embedding, dones) 85 | 86 | actor_mean = self.action_body_0(embedding) 87 | actor_mean = nn.leaky_relu(actor_mean) 88 | actor_mean = self.action_body_1(actor_mean) 89 | actor_mean = nn.leaky_relu(actor_mean) 90 | actor_mean = self.action_decoder(actor_mean) 91 | 92 | pi = distrax.Categorical(logits=actor_mean) 93 | 94 | critic = self.value_body_0(embedding) 95 | critic = nn.leaky_relu(critic) 96 | critic = self.value_body_1(critic) 97 | critic = nn.leaky_relu(critic) 98 | critic = self.value_decoder(critic) 99 | 100 | return hidden, pi, jnp.squeeze(critic, axis=-1) 101 | 102 | class Transition(NamedTuple): 103 | done: jnp.ndarray 104 | action: jnp.ndarray 105 | value: jnp.ndarray 106 | reward: jnp.ndarray 107 | log_prob: jnp.ndarray 108 | obs: jnp.ndarray 109 | info: jnp.ndarray 110 | 111 | 112 | def make_train(config): 113 | config["NUM_UPDATES"] = ( 114 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 115 | ) 116 | config["MINIBATCH_SIZE"] = ( 117 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 118 | ) 119 | env, env_params = gymnax.make(config["ENV_NAME"]) 120 | env = FlattenObservationWrapper(env) 121 | env = LogWrapper(env) 122 | 123 | def linear_schedule(count): 124 | frac = ( 125 | 1.0 126 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 127 | / config["NUM_UPDATES"] 128 | ) 129 | return config["LR"] * frac 130 | 131 | def train(rng): 132 | # INIT NETWORK 133 | network = ActorCriticS5(env.action_space(env_params).n, config=config) 134 | rng, _rng = jax.random.split(rng) 135 | init_x = ( 136 | jnp.zeros( 137 | (1, config["NUM_ENVS"], *env.observation_space(env_params).shape) 138 | ), 139 | jnp.zeros((1, config["NUM_ENVS"])), 140 | ) 141 | init_hstate = StackedEncoderModel.initialize_carry(config["NUM_ENVS"], ssm_size, n_layers) 142 | network_params = network.init(_rng, init_hstate, init_x) 143 | if config["ANNEAL_LR"]: 144 | tx = optax.chain( 145 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 146 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 147 | ) 148 | else: 149 | tx = optax.chain( 150 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 151 | optax.adam(config["LR"], eps=1e-5), 152 | ) 153 | train_state = TrainState.create( 154 | apply_fn=network.apply, 155 | params=network_params, 156 | tx=tx, 157 | ) 158 | 159 | # INIT ENV 160 | rng, _rng = jax.random.split(rng) 161 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) 162 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params) 163 | init_hstate = StackedEncoderModel.initialize_carry(config["NUM_ENVS"], ssm_size, n_layers) 164 | 165 | # TRAIN LOOP 166 | def _update_step(runner_state, unused): 167 | # COLLECT TRAJECTORIES 168 | def _env_step(runner_state, unused): 169 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state 170 | rng, _rng = jax.random.split(rng) 171 | 172 | # SELECT ACTION 173 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) 174 | hstate, pi, value = network.apply(train_state.params, hstate, ac_in) 175 | action = pi.sample(seed=_rng) 176 | log_prob = pi.log_prob(action) 177 | value, action, log_prob = ( 178 | value.squeeze(0), 179 | 
action.squeeze(0), 180 | log_prob.squeeze(0), 181 | ) 182 | 183 | # STEP ENV 184 | rng, _rng = jax.random.split(rng) 185 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 186 | obsv, env_state, reward, done, info = jax.vmap( 187 | env.step, in_axes=(0, 0, 0, None) 188 | )(rng_step, env_state, action, env_params) 189 | transition = Transition( 190 | last_done, action, value, reward, log_prob, last_obs, info 191 | ) 192 | runner_state = (train_state, env_state, obsv, done, hstate, rng) 193 | return runner_state, transition 194 | 195 | initial_hstate = runner_state[-2] 196 | runner_state, traj_batch = jax.lax.scan( 197 | _env_step, runner_state, None, config["NUM_STEPS"] 198 | ) 199 | 200 | # CALCULATE ADVANTAGE 201 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state 202 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) 203 | _, _, last_val = network.apply(train_state.params, hstate, ac_in) 204 | last_val = last_val.squeeze(0) 205 | def _calculate_gae(traj_batch, last_val, last_done): 206 | def _get_advantages(carry, transition): 207 | gae, next_value, next_done = carry 208 | done, value, reward = transition.done, transition.value, transition.reward 209 | delta = reward + config["GAMMA"] * next_value * (1 - next_done) - value 210 | gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae 211 | return (gae, value, done), gae 212 | _, advantages = jax.lax.scan(_get_advantages, (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16) 213 | return advantages, advantages + traj_batch.value 214 | advantages, targets = _calculate_gae(traj_batch, last_val, last_done) 215 | 216 | # UPDATE NETWORK 217 | def _update_epoch(update_state, unused): 218 | def _update_minbatch(train_state, batch_info): 219 | init_hstate, traj_batch, advantages, targets = batch_info 220 | 221 | def _loss_fn(params, init_hstate, traj_batch, gae, targets): 222 | # RERUN NETWORK 223 | _, pi, value = network.apply( 224 | params, init_hstate, (traj_batch.obs, traj_batch.done) 225 | ) 226 | log_prob = pi.log_prob(traj_batch.action) 227 | 228 | # CALCULATE VALUE LOSS 229 | value_pred_clipped = traj_batch.value + ( 230 | value - traj_batch.value 231 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 232 | value_losses = jnp.square(value - targets) 233 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 234 | value_loss = ( 235 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 236 | ) 237 | 238 | # CALCULATE ACTOR LOSS 239 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 240 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 241 | loss_actor1 = ratio * gae 242 | loss_actor2 = ( 243 | jnp.clip( 244 | ratio, 245 | 1.0 - config["CLIP_EPS"], 246 | 1.0 + config["CLIP_EPS"], 247 | ) 248 | * gae 249 | ) 250 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 251 | loss_actor = loss_actor.mean() 252 | entropy = pi.entropy().mean() 253 | 254 | total_loss = ( 255 | loss_actor 256 | + config["VF_COEF"] * value_loss 257 | - config["ENT_COEF"] * entropy 258 | ) 259 | return total_loss, (value_loss, loss_actor, entropy) 260 | 261 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 262 | total_loss, grads = grad_fn( 263 | train_state.params, init_hstate, traj_batch, advantages, targets 264 | ) 265 | train_state = train_state.apply_gradients(grads=grads) 266 | return train_state, total_loss 267 | 268 | ( 269 | train_state, 270 | init_hstate, 271 | traj_batch, 272 | advantages, 273 | targets, 274 | rng, 275 | ) = update_state 276 | 277 
| rng, _rng = jax.random.split(rng) 278 | permutation = jax.random.permutation(_rng, config["NUM_ENVS"]) 279 | batch = (init_hstate, traj_batch, advantages, targets) 280 | 281 | shuffled_batch = jax.tree_util.tree_map( 282 | lambda x: jnp.take(x, permutation, axis=1), batch 283 | ) 284 | 285 | minibatches = jax.tree_util.tree_map( 286 | lambda x: jnp.swapaxes( 287 | jnp.reshape( 288 | x, 289 | [x.shape[0], config["NUM_MINIBATCHES"], -1] 290 | + list(x.shape[2:]), 291 | ), 292 | 1, 293 | 0, 294 | ), 295 | shuffled_batch, 296 | ) 297 | 298 | train_state, total_loss = jax.lax.scan( 299 | _update_minbatch, train_state, minibatches 300 | ) 301 | update_state = ( 302 | train_state, 303 | init_hstate, 304 | traj_batch, 305 | advantages, 306 | targets, 307 | rng, 308 | ) 309 | return update_state, total_loss 310 | 311 | init_hstate = initial_hstate # TBH 312 | update_state = ( 313 | train_state, 314 | init_hstate, 315 | traj_batch, 316 | advantages, 317 | targets, 318 | rng, 319 | ) 320 | update_state, loss_info = jax.lax.scan( 321 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 322 | ) 323 | train_state = update_state[0] 324 | metric = traj_batch.info 325 | rng = update_state[-1] 326 | if config.get("DEBUG"): 327 | def callback(info): 328 | return_values = info["returned_episode_returns"][info["returned_episode"]] 329 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 330 | for t in range(len(timesteps)): 331 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}") 332 | jax.debug.callback(callback, metric) 333 | 334 | runner_state = (train_state, env_state, last_obs, last_done, hstate, rng) 335 | return runner_state, metric 336 | 337 | rng, _rng = jax.random.split(rng) 338 | runner_state = ( 339 | train_state, 340 | env_state, 341 | obsv, 342 | jnp.zeros((config["NUM_ENVS"]), dtype=bool), 343 | init_hstate, 344 | _rng, 345 | ) 346 | runner_state, metric = jax.lax.scan( 347 | _update_step, runner_state, None, config["NUM_UPDATES"] 348 | ) 349 | return {"runner_state": runner_state, "metric": metric} 350 | 351 | return train 352 | 353 | 354 | if __name__ == "__main__": 355 | config = { 356 | "LR": 2.5e-4, 357 | "NUM_ENVS": 4, 358 | "NUM_STEPS": 128, 359 | "TOTAL_TIMESTEPS": 5e5, 360 | "UPDATE_EPOCHS": 4, 361 | "NUM_MINIBATCHES": 4, 362 | "GAMMA": 0.99, 363 | "GAE_LAMBDA": 0.95, 364 | "CLIP_EPS": 0.2, 365 | "ENT_COEF": 0.01, 366 | "VF_COEF": 0.5, 367 | "MAX_GRAD_NORM": 0.5, 368 | "ENV_NAME": "CartPole-v1", 369 | "ANNEAL_LR": True, 370 | "DEBUG": True, 371 | } 372 | 373 | rng = jax.random.PRNGKey(30) 374 | train_jit = jax.jit(make_train(config)) 375 | out = train_jit(rng) 376 | -------------------------------------------------------------------------------- /purejaxrl/experimental/s5/s5.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/lindermanlab/S5""" 2 | 3 | from functools import partial 4 | import jax 5 | import jax.numpy as np 6 | import jax.numpy as jnp 7 | from flax import linen as nn 8 | from jax.nn.initializers import lecun_normal, normal 9 | from jax import random 10 | from jax.numpy.linalg import eigh 11 | 12 | class SequenceLayer(nn.Module): 13 | """ Defines a single S5 layer, with S5 SSM, nonlinearity, 14 | dropout, batch/layer norm, etc. 15 | Args: 16 | ssm (nn.Module): the SSM to be used (i.e. 
S5 ssm) 17 | dropout (float32): dropout rate 18 | d_model (int32): this is the feature size of the layer inputs and outputs 19 | we usually refer to this size as H 20 | activation (string): Type of activation function to use 21 | training (bool): whether in training mode or not 22 | prenorm (bool): apply prenorm if true or postnorm if false 23 | batchnorm (bool): apply batchnorm if true or layernorm if false 24 | bn_momentum (float32): the batchnorm momentum if batchnorm is used 25 | step_rescale (float32): allows for uniformly changing the timescale parameter, 26 | e.g. after training on a different resolution for 27 | the speech commands benchmark 28 | """ 29 | ssm: nn.Module 30 | # dropout: float 31 | d_model: int 32 | activation: str = "gelu" 33 | # training: bool = True 34 | # prenorm: bool = False 35 | # batchnorm: bool = False 36 | # bn_momentum: float = 0.90 37 | step_rescale: float = 1.0 38 | 39 | def setup(self): 40 | """Initializes the ssm, batch/layer norm and dropout 41 | """ 42 | self.seq = self.ssm(step_rescale=self.step_rescale) 43 | 44 | if self.activation in ["full_glu"]: 45 | self.out1 = nn.Dense(self.d_model) 46 | self.out2 = nn.Dense(self.d_model) 47 | elif self.activation in ["half_glu1", "half_glu2"]: 48 | self.out2 = nn.Dense(self.d_model) 49 | 50 | # if self.batchnorm: 51 | # self.norm = nn.BatchNorm(use_running_average=not self.training, 52 | # momentum=self.bn_momentum, axis_name='batch') 53 | # else: 54 | # self.norm = nn.LayerNorm() 55 | 56 | # self.drop = nn.Dropout( 57 | # self.dropout, 58 | # broadcast_dims=[0], 59 | # deterministic=not self.training, 60 | # ) 61 | self.drop = lambda x: x 62 | 63 | def __call__(self, hidden, x, d): 64 | """ 65 | Compute the LxH output of S5 layer given an LxH input. 66 | Args: 67 | x (float32): input sequence (L, d_model) 68 | d (bool): reset signal (L,) 69 | Returns: 70 | output sequence (float32): (L, d_model) 71 | """ 72 | skip = x 73 | # if self.prenorm: 74 | # x = self.norm(x) 75 | # hidden, x = self.seq(hidden, x, d) 76 | hidden, x = jax.vmap(self.seq, in_axes=1, out_axes=1)(hidden, x, d) 77 | # hidden = jnp.swapaxes(hidden, 1, 0) 78 | 79 | if self.activation in ["full_glu"]: 80 | x = self.drop(nn.gelu(x)) 81 | x = self.out1(x) * jax.nn.sigmoid(self.out2(x)) 82 | x = self.drop(x) 83 | elif self.activation in ["half_glu1"]: 84 | x = self.drop(nn.gelu(x)) 85 | x = x * jax.nn.sigmoid(self.out2(x)) 86 | x = self.drop(x) 87 | elif self.activation in ["half_glu2"]: 88 | # Only apply GELU to the gate input 89 | x1 = self.drop(nn.gelu(x)) 90 | x = x * jax.nn.sigmoid(self.out2(x1)) 91 | x = self.drop(x) 92 | elif self.activation in ["gelu"]: 93 | x = self.drop(nn.gelu(x)) 94 | else: 95 | raise NotImplementedError( 96 | "Activation: {} not implemented".format(self.activation)) 97 | 98 | x = skip + x 99 | # if not self.prenorm: 100 | # x = self.norm(x) 101 | return hidden, x 102 | 103 | @staticmethod 104 | def initialize_carry(batch_size, hidden_size): 105 | # Use a dummy key since the default state init fn is just zeros. 106 | # return nn.LSTMCell.initialize_carry( 107 | # jax.random.PRNGKey(0), (batch_size,), hidden_size) 108 | return jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64) 109 | 110 | def log_step_initializer(dt_min=0.001, dt_max=0.1): 111 | """ Initialize the learnable timescale Delta by sampling 112 | uniformly between dt_min and dt_max. 
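        Concretely, the returned init function draws u ~ Uniform[0, 1) and returns
        log_step = log(dt_min) + u * (log(dt_max) - log(dt_min)), so the step size
        Delta = exp(log_step) is log-uniformly distributed on [dt_min, dt_max).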
113 | Args: 114 | dt_min (float32): minimum value 115 | dt_max (float32): maximum value 116 | Returns: 117 | init function 118 | """ 119 | def init(key, shape): 120 | """ Init function 121 | Args: 122 | key: jax random key 123 | shape tuple: desired shape 124 | Returns: 125 | sampled log_step (float32) 126 | """ 127 | return random.uniform(key, shape) * ( 128 | np.log(dt_max) - np.log(dt_min) 129 | ) + np.log(dt_min) 130 | 131 | return init 132 | 133 | 134 | def init_log_steps(key, input): 135 | """ Initialize an array of learnable timescale parameters 136 | Args: 137 | key: jax random key 138 | input: tuple containing the array shape H and 139 | dt_min and dt_max 140 | Returns: 141 | initialized array of timescales (float32): (H,) 142 | """ 143 | H, dt_min, dt_max = input 144 | log_steps = [] 145 | for i in range(H): 146 | key, skey = random.split(key) 147 | log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) 148 | log_steps.append(log_step) 149 | 150 | return np.array(log_steps) 151 | 152 | 153 | def init_VinvB(init_fun, rng, shape, Vinv): 154 | """ Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. 155 | Note we will parameterize this with two different matrices for complex 156 | numbers. 157 | Args: 158 | init_fun: the initialization function to use, e.g. lecun_normal() 159 | rng: jax random key to be used with init function. 160 | shape (tuple): desired shape (P,H) 161 | Vinv: (complex64) the inverse eigenvectors used for initialization 162 | Returns: 163 | B_tilde (complex64) of shape (P,H,2) 164 | """ 165 | B = init_fun(rng, shape) 166 | VinvB = Vinv @ B 167 | VinvB_real = VinvB.real 168 | VinvB_imag = VinvB.imag 169 | return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) 170 | 171 | 172 | def trunc_standard_normal(key, shape): 173 | """ Sample C with a truncated normal distribution with standard deviation 1. 174 | Args: 175 | key: jax random key 176 | shape (tuple): desired shape, of length 3, (H,P,_) 177 | Returns: 178 | sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) 179 | """ 180 | H, P, _ = shape 181 | Cs = [] 182 | for i in range(H): 183 | key, skey = random.split(key) 184 | C = lecun_normal()(skey, shape=(1, P, 2)) 185 | Cs.append(C) 186 | return np.array(Cs)[:, 0] 187 | 188 | 189 | def init_CV(init_fun, rng, shape, V): 190 | """ Initialize C_tilde=CV. First sample C. Then compute CV. 191 | Note we will parameterize this with two different matrices for complex 192 | numbers. 193 | Args: 194 | init_fun: the initialization function to use, e.g. lecun_normal() 195 | rng: jax random key to be used with init function. 196 | shape (tuple): desired shape (H,P) 197 | V: (complex64) the eigenvectors used for initialization 198 | Returns: 199 | C_tilde (complex64) of shape (H,P,2) 200 | """ 201 | C_ = init_fun(rng, shape) 202 | C = C_[..., 0] + 1j * C_[..., 1] 203 | CV = C @ V 204 | CV_real = CV.real 205 | CV_imag = CV.imag 206 | return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) 207 | 208 | 209 | # Discretization functions 210 | def discretize_bilinear(Lambda, B_tilde, Delta): 211 | """ Discretize a diagonalized, continuous-time linear SSM 212 | using bilinear transform method. 
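        Because Lambda is diagonal, the transform reduces to elementwise updates:
        Lambda_bar = (1 - Delta/2 * Lambda)^(-1) * (1 + Delta/2 * Lambda) and
        B_bar = (1 - Delta/2 * Lambda)^(-1) * Delta * B_tilde, which is exactly
        what the BL, Lambda_bar and B_bar lines below compute.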
213 | Args: 214 | Lambda (complex64): diagonal state matrix (P,) 215 | B_tilde (complex64): input matrix (P, H) 216 | Delta (float32): discretization step sizes (P,) 217 | Returns: 218 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) 219 | """ 220 | Identity = np.ones(Lambda.shape[0]) 221 | 222 | BL = 1 / (Identity - (Delta / 2.0) * Lambda) 223 | Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) 224 | B_bar = (BL * Delta)[..., None] * B_tilde 225 | return Lambda_bar, B_bar 226 | 227 | 228 | def discretize_zoh(Lambda, B_tilde, Delta): 229 | """ Discretize a diagonalized, continuous-time linear SSM 230 | using zero-order hold method. 231 | Args: 232 | Lambda (complex64): diagonal state matrix (P,) 233 | B_tilde (complex64): input matrix (P, H) 234 | Delta (float32): discretization step sizes (P,) 235 | Returns: 236 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) 237 | """ 238 | Identity = np.ones(Lambda.shape[0]) 239 | Lambda_bar = np.exp(Lambda * Delta) 240 | B_bar = (1/Lambda * (Lambda_bar-Identity))[..., None] * B_tilde 241 | return Lambda_bar, B_bar 242 | 243 | 244 | # Parallel scan operations 245 | @jax.vmap 246 | def binary_operator(q_i, q_j): 247 | """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. 248 | Args: 249 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,) 250 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,) 251 | Returns: 252 | new element ( A_out, Bu_out ) 253 | """ 254 | A_i, b_i = q_i 255 | A_j, b_j = q_j 256 | return A_j * A_i, A_j * b_i + b_j 257 | 258 | # Parallel scan operations 259 | @jax.vmap 260 | def binary_operator_reset(q_i, q_j): 261 | """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. 262 | Args: 263 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,) 264 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,) 265 | Returns: 266 | new element ( A_out, Bu_out ) 267 | """ 268 | A_i, b_i, c_i = q_i 269 | A_j, b_j, c_j = q_j 270 | return ( 271 | (A_j * A_i)*(1 - c_j) + A_j * c_j, 272 | (A_j * b_i + b_j)*(1 - c_j) + b_j * c_j, 273 | c_i * (1 - c_j) + c_j, 274 | ) 275 | 276 | 277 | 278 | def apply_ssm(Lambda_bar, B_bar, C_tilde, hidden, input_sequence, resets, conj_sym, bidirectional): 279 | """ Compute the LxH output of discretized SSM given an LxH input. 
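        Concretely, the associative scan evaluates the diagonal recurrence
        x_k = Lambda_bar * x_{k-1} + B_bar @ u_k, restarted from B_bar @ u_k
        whenever resets[k] is set, with the carried `hidden` prepended as x_0.
        It returns the final state as the new hidden carry together with
        y_k = (C_tilde @ x_k).real, scaled by 2 when conj_sym is True.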
280 | Args: 281 | Lambda_bar (complex64): discretized diagonal state matrix (P,) 282 | B_bar (complex64): discretized input matrix (P, H) 283 | C_tilde (complex64): output matrix (H, P) 284 | input_sequence (float32): input sequence of features (L, H) 285 | reset (bool): input sequence of features (L,) 286 | conj_sym (bool): whether conjugate symmetry is enforced 287 | bidirectional (bool): whether bidirectional setup is used, 288 | Note for this case C_tilde will have 2P cols 289 | Returns: 290 | ys (float32): the SSM outputs (S5 layer preactivations) (L, H) 291 | """ 292 | Lambda_elements = Lambda_bar * jnp.ones((input_sequence.shape[0], 293 | Lambda_bar.shape[0])) 294 | Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence) 295 | 296 | Lambda_elements = jnp.concatenate([ 297 | jnp.ones((1, Lambda_bar.shape[0])), 298 | Lambda_elements, 299 | ]) 300 | 301 | Bu_elements = jnp.concatenate([ 302 | hidden, 303 | Bu_elements, 304 | ]) 305 | 306 | resets = jnp.concatenate([ 307 | jnp.zeros(1), 308 | resets, 309 | ]) 310 | 311 | 312 | _, xs, _ = jax.lax.associative_scan(binary_operator_reset, (Lambda_elements, Bu_elements, resets)) 313 | xs = xs[1:] 314 | 315 | if conj_sym: 316 | return xs[np.newaxis, -1], jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs) 317 | else: 318 | return xs[np.newaxis, -1], jax.vmap(lambda x: (C_tilde @ x).real)(xs) 319 | 320 | 321 | class S5SSM(nn.Module): 322 | Lambda_re_init: np.DeviceArray 323 | Lambda_im_init: np.DeviceArray 324 | V: np.DeviceArray 325 | Vinv: np.DeviceArray 326 | 327 | H: int 328 | P: int 329 | C_init: str 330 | discretization: str 331 | dt_min: float 332 | dt_max: float 333 | conj_sym: bool = True 334 | clip_eigs: bool = False 335 | bidirectional: bool = False 336 | step_rescale: float = 1.0 337 | 338 | """ The S5 SSM 339 | Args: 340 | Lambda_re_init (complex64): Real part of init diag state matrix (P,) 341 | Lambda_im_init (complex64): Imag part of init diag state matrix (P,) 342 | V (complex64): Eigenvectors used for init (P,P) 343 | Vinv (complex64): Inverse eigenvectors used for init (P,P) 344 | H (int32): Number of features of input seq 345 | P (int32): state size 346 | C_init (string): Specifies How C is initialized 347 | Options: [trunc_standard_normal: sample from truncated standard normal 348 | and then multiply by V, i.e. C_tilde=CV. 349 | lecun_normal: sample from Lecun_normal and then multiply by V. 350 | complex_normal: directly sample a complex valued output matrix 351 | from standard normal, does not multiply by V] 352 | conj_sym (bool): Whether conjugate symmetry is enforced 353 | clip_eigs (bool): Whether to enforce left-half plane condition, i.e. 354 | constrain real part of eigenvalues to be negative. 355 | True recommended for autoregressive task/unbounded sequence lengths 356 | Discussed in https://arxiv.org/pdf/2206.11893.pdf. 357 | bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices 358 | discretization: (string) Specifies discretization method 359 | options: [zoh: zero-order hold method, 360 | bilinear: bilinear transform] 361 | dt_min: (float32): minimum value to draw timescale values from when 362 | initializing log_step 363 | dt_max: (float32): maximum value to draw timescale values from when 364 | initializing log_step 365 | step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. 
after training 366 | on a different resolution for the speech commands benchmark 367 | """ 368 | 369 | def setup(self): 370 | """Initializes parameters once and performs discretization each time 371 | the SSM is applied to a sequence 372 | """ 373 | 374 | if self.conj_sym: 375 | # Need to account for case where we actually sample real B and C, and then multiply 376 | # by the half sized Vinv and possibly V 377 | local_P = 2*self.P 378 | else: 379 | local_P = self.P 380 | 381 | # Initialize diagonal state to state matrix Lambda (eigenvalues) 382 | self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) 383 | self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) 384 | if self.clip_eigs: 385 | self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im 386 | else: 387 | self.Lambda = self.Lambda_re + 1j * self.Lambda_im 388 | 389 | # Initialize input to state (B) matrix 390 | B_init = lecun_normal() 391 | B_shape = (local_P, self.H) 392 | self.B = self.param("B", 393 | lambda rng, shape: init_VinvB(B_init, 394 | rng, 395 | shape, 396 | self.Vinv), 397 | B_shape) 398 | B_tilde = self.B[..., 0] + 1j * self.B[..., 1] 399 | 400 | # Initialize state to output (C) matrix 401 | if self.C_init in ["trunc_standard_normal"]: 402 | C_init = trunc_standard_normal 403 | C_shape = (self.H, local_P, 2) 404 | elif self.C_init in ["lecun_normal"]: 405 | C_init = lecun_normal() 406 | C_shape = (self.H, local_P, 2) 407 | elif self.C_init in ["complex_normal"]: 408 | C_init = normal(stddev=0.5 ** 0.5) 409 | else: 410 | raise NotImplementedError( 411 | "C_init method {} not implemented".format(self.C_init)) 412 | 413 | if self.C_init in ["complex_normal"]: 414 | if self.bidirectional: 415 | C = self.param("C", C_init, (self.H, 2 * self.P, 2)) 416 | self.C_tilde = C[..., 0] + 1j * C[..., 1] 417 | 418 | else: 419 | C = self.param("C", C_init, (self.H, self.P, 2)) 420 | self.C_tilde = C[..., 0] + 1j * C[..., 1] 421 | 422 | else: 423 | if self.bidirectional: 424 | self.C1 = self.param("C1", 425 | lambda rng, shape: init_CV(C_init, rng, shape, self.V), 426 | C_shape) 427 | self.C2 = self.param("C2", 428 | lambda rng, shape: init_CV(C_init, rng, shape, self.V), 429 | C_shape) 430 | 431 | C1 = self.C1[..., 0] + 1j * self.C1[..., 1] 432 | C2 = self.C2[..., 0] + 1j * self.C2[..., 1] 433 | self.C_tilde = np.concatenate((C1, C2), axis=-1) 434 | 435 | else: 436 | self.C = self.param("C", 437 | lambda rng, shape: init_CV(C_init, rng, shape, self.V), 438 | C_shape) 439 | 440 | self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] 441 | 442 | # Initialize feedthrough (D) matrix 443 | self.D = self.param("D", normal(stddev=1.0), (self.H,)) 444 | 445 | # Initialize learnable discretization timescale value 446 | self.log_step = self.param("log_step", 447 | init_log_steps, 448 | (self.P, self.dt_min, self.dt_max)) 449 | step = self.step_rescale * np.exp(self.log_step[:, 0]) 450 | 451 | # Discretize 452 | if self.discretization in ["zoh"]: 453 | self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step) 454 | elif self.discretization in ["bilinear"]: 455 | self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step) 456 | else: 457 | raise NotImplementedError("Discretization method {} not implemented".format(self.discretization)) 458 | 459 | def __call__(self, hidden, input_sequence, resets): 460 | """ 461 | Compute the LxH output of the S5 SSM given an LxH input sequence 462 | using a parallel scan. 
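        The discretized (Lambda_bar, B_bar) computed in setup() are applied via
        apply_ssm together with the carried hidden state, and the feedthrough
        term D * u is added to the scan output before returning.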
463 | Args: 464 | input_sequence (float32): input sequence (L, H) 465 | resets (bool): input sequence (L,) 466 | Returns: 467 | output sequence (float32): (L, H) 468 | """ 469 | hidden, ys = apply_ssm(self.Lambda_bar, 470 | self.B_bar, 471 | self.C_tilde, 472 | hidden, 473 | input_sequence, 474 | resets, 475 | self.conj_sym, 476 | self.bidirectional) 477 | # Add feedthrough matrix output Du; 478 | Du = jax.vmap(lambda u: self.D * u)(input_sequence) 479 | return hidden, ys + Du 480 | 481 | 482 | def init_S5SSM(H, 483 | P, 484 | Lambda_re_init, 485 | Lambda_im_init, 486 | V, 487 | Vinv, 488 | C_init, 489 | discretization, 490 | dt_min, 491 | dt_max, 492 | conj_sym, 493 | clip_eigs, 494 | bidirectional 495 | ): 496 | """Convenience function that will be used to initialize the SSM. 497 | Same arguments as defined in S5SSM above.""" 498 | return partial(S5SSM, 499 | H=H, 500 | P=P, 501 | Lambda_re_init=Lambda_re_init, 502 | Lambda_im_init=Lambda_im_init, 503 | V=V, 504 | Vinv=Vinv, 505 | C_init=C_init, 506 | discretization=discretization, 507 | dt_min=dt_min, 508 | dt_max=dt_max, 509 | conj_sym=conj_sym, 510 | clip_eigs=clip_eigs, 511 | bidirectional=bidirectional) 512 | 513 | 514 | def make_HiPPO(N): 515 | """ Create a HiPPO-LegS matrix. 516 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 517 | Args: 518 | N (int32): state size 519 | Returns: 520 | N x N HiPPO LegS matrix 521 | """ 522 | P = np.sqrt(1 + 2 * np.arange(N)) 523 | A = P[:, np.newaxis] * P[np.newaxis, :] 524 | A = np.tril(A) - np.diag(np.arange(N)) 525 | return -A 526 | 527 | 528 | def make_NPLR_HiPPO(N): 529 | """ 530 | Makes components needed for NPLR representation of HiPPO-LegS 531 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 532 | Args: 533 | N (int32): state size 534 | Returns: 535 | N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B 536 | """ 537 | # Make -HiPPO 538 | hippo = make_HiPPO(N) 539 | 540 | # Add in a rank 1 term. Makes it Normal. 541 | P = np.sqrt(np.arange(N) + 0.5) 542 | 543 | # HiPPO also specifies the B matrix 544 | B = np.sqrt(2 * np.arange(N) + 1.0) 545 | return hippo, P, B 546 | 547 | 548 | def make_DPLR_HiPPO(N): 549 | """ 550 | Makes components needed for DPLR representation of HiPPO-LegS 551 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 552 | Note, we will only use the diagonal part 553 | Args: 554 | N: 555 | Returns: 556 | eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, 557 | eigenvectors V, HiPPO B pre-conjugation 558 | """ 559 | A, P, B = make_NPLR_HiPPO(N) 560 | 561 | S = A + P[:, np.newaxis] * P[np.newaxis, :] 562 | 563 | S_diag = np.diagonal(S) 564 | Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) 565 | 566 | # Diagonalize S to V \Lambda V^* 567 | Lambda_imag, V = eigh(S * -1j) 568 | 569 | P = V.conj().T @ P 570 | B_orig = B 571 | B = V.conj().T @ B 572 | return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig 573 | 574 | class StackedEncoderModel(nn.Module): 575 | """ Defines a stack of S5 layers to be used as an encoder. 576 | Args: 577 | ssm (nn.Module): the SSM to be used (i.e. 
S5 ssm) 578 | d_model (int32): this is the feature size of the layer inputs and outputs 579 | we usually refer to this size as H 580 | n_layers (int32): the number of S5 layers to stack 581 | activation (string): Type of activation function to use 582 | dropout (float32): dropout rate 583 | training (bool): whether in training mode or not 584 | prenorm (bool): apply prenorm if true or postnorm if false 585 | batchnorm (bool): apply batchnorm if true or layernorm if false 586 | bn_momentum (float32): the batchnorm momentum if batchnorm is used 587 | step_rescale (float32): allows for uniformly changing the timescale parameter, 588 | e.g. after training on a different resolution for 589 | the speech commands benchmark 590 | """ 591 | ssm: nn.Module 592 | d_model: int 593 | n_layers: int 594 | activation: str = "gelu" 595 | 596 | def setup(self): 597 | """ 598 | Initializes a linear encoder and the stack of S5 layers. 599 | """ 600 | self.layers = [ 601 | SequenceLayer( 602 | ssm=self.ssm, 603 | d_model=self.d_model, 604 | activation=self.activation, 605 | ) 606 | for _ in range(self.n_layers) 607 | ] 608 | 609 | def __call__(self, hidden, x, d): 610 | """ 611 | Compute the LxH output of the stacked encoder given an Lxd_input 612 | input sequence. 613 | Args: 614 | x (float32): input sequence (L, d_input) 615 | Returns: 616 | output sequence (float32): (L, d_model) 617 | """ 618 | new_hiddens = [] 619 | for i, layer in enumerate(self.layers): 620 | new_h, x = layer(hidden[i], x, d) 621 | new_hiddens.append(new_h) 622 | 623 | return new_hiddens, x 624 | 625 | @staticmethod 626 | def initialize_carry(batch_size, hidden_size, n_layers): 627 | # Use a dummy key since the default state init fn is just zeros. 628 | return [jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64) for _ in range(n_layers)] -------------------------------------------------------------------------------- /purejaxrl/experimental/s5/wrappers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import chex 4 | import numpy as np 5 | from flax import struct 6 | from functools import partial 7 | from typing import Optional, Tuple, Union, Any 8 | from gymnax.environments import environment, spaces 9 | from brax import envs 10 | 11 | class GymnaxWrapper(object): 12 | """Base class for Gymnax wrappers.""" 13 | 14 | def __init__(self, env): 15 | self._env = env 16 | 17 | # provide proxy access to regular attributes of wrapped object 18 | def __getattr__(self, name): 19 | return getattr(self._env, name) 20 | 21 | class FlattenObservationWrapper(GymnaxWrapper): 22 | """Flatten the observations of the environment.""" 23 | 24 | def __init__(self, env: environment.Environment): 25 | super().__init__(env) 26 | 27 | def observation_space(self, params) -> spaces.Box: 28 | assert isinstance(self._env.observation_space(params), spaces.Box), "Only Box spaces are supported for now." 
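        # Report the flattened space: same bounds and dtype as the wrapped env,
        # with the original shape collapsed to a single axis of size prod(shape),
        # e.g. an (H, W, C) observation becomes (H*W*C,).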
29 | return spaces.Box( 30 | low=self._env.observation_space(params).low, 31 | high=self._env.observation_space(params).high, 32 | shape=(np.prod(self._env.observation_space(params).shape),), 33 | dtype=self._env.observation_space(params).dtype, 34 | ) 35 | 36 | @partial(jax.jit, static_argnums=(0,)) 37 | def reset( 38 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 39 | ) -> Tuple[chex.Array, environment.EnvState]: 40 | obs, state = self._env.reset(key, params) 41 | obs = jnp.reshape(obs, (-1,)) 42 | return obs, state 43 | 44 | @partial(jax.jit, static_argnums=(0,)) 45 | def step( 46 | self, 47 | key: chex.PRNGKey, 48 | state: environment.EnvState, 49 | action: Union[int, float], 50 | params: Optional[environment.EnvParams] = None, 51 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 52 | obs, state, reward, done, info = self._env.step(key, state, action, params) 53 | obs = jnp.reshape(obs, (-1,)) 54 | return obs, state, reward, done, info 55 | 56 | @struct.dataclass 57 | class LogEnvState: 58 | env_state: environment.EnvState 59 | episode_returns: float 60 | episode_lengths: int 61 | returned_episode_returns: float 62 | returned_episode_lengths: int 63 | timestep: int 64 | 65 | class LogWrapper(GymnaxWrapper): 66 | """Log the episode returns and lengths.""" 67 | 68 | def __init__(self, env: environment.Environment): 69 | super().__init__(env) 70 | 71 | @partial(jax.jit, static_argnums=(0,)) 72 | def reset( 73 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 74 | ) -> Tuple[chex.Array, environment.EnvState]: 75 | obs, env_state = self._env.reset(key, params) 76 | state = LogEnvState(env_state, 0, 0, 0, 0, 0) 77 | return obs, state 78 | 79 | @partial(jax.jit, static_argnums=(0,)) 80 | def step( 81 | self, 82 | key: chex.PRNGKey, 83 | state: environment.EnvState, 84 | action: Union[int, float], 85 | params: Optional[environment.EnvParams] = None, 86 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 87 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params) 88 | new_episode_return = state.episode_returns + reward 89 | new_episode_length = state.episode_lengths + 1 90 | state = LogEnvState( 91 | env_state = env_state, 92 | episode_returns = new_episode_return * (1 - done), 93 | episode_lengths = new_episode_length * (1 - done), 94 | returned_episode_returns = state.returned_episode_returns * (1 - done) + new_episode_return * done, 95 | returned_episode_lengths = state.returned_episode_lengths * (1 - done) + new_episode_length * done, 96 | timestep = state.timestep + 1, 97 | ) 98 | info["returned_episode_returns"] = state.returned_episode_returns 99 | info["returned_episode_lengths"] = state.returned_episode_lengths 100 | info["timestep"] = state.timestep 101 | info["returned_episode"] = done 102 | return obs, state, reward, done, info 103 | 104 | class BraxGymnaxWrapper: 105 | def __init__(self, env_name, backend="positional"): 106 | env = envs.get_environment(env_name=env_name, backend=backend) 107 | env = envs.wrapper.EpisodeWrapper(env, episode_length=1000, action_repeat=1) 108 | env = envs.wrapper.AutoResetWrapper(env) 109 | self._env = env 110 | self.action_size = env.action_size 111 | self.observation_size = (env.observation_size,) 112 | 113 | def reset(self, key, params=None): 114 | state = self._env.reset(key) 115 | return state.obs, state 116 | 117 | def step(self, key, state, action, params=None): 118 | next_state = self._env.step(state, action) 119 | return 
next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {} 120 | 121 | def observation_space(self, params): 122 | return spaces.Box( 123 | low=-jnp.inf, 124 | high=jnp.inf, 125 | shape=(self._env.observation_size,), 126 | ) 127 | 128 | def action_space(self, params): 129 | return spaces.Box( 130 | low=-1.0, 131 | high=1.0, 132 | shape=(self._env.action_size,), 133 | ) 134 | 135 | class ClipAction(GymnaxWrapper): 136 | def __init__(self, env, low=-1.0, high=1.0): 137 | super().__init__(env) 138 | self.low = low 139 | self.high = high 140 | 141 | def step(self, key, state, action, params=None): 142 | """TODO: In theory the below line should be the way to do this.""" 143 | # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high) 144 | action = jnp.clip(action, self.low, self.high) 145 | return self._env.step(key, state, action, params) 146 | 147 | class TransformObservation(GymnaxWrapper): 148 | def __init__(self, env, transform_obs): 149 | super().__init__(env) 150 | self.transform_obs = transform_obs 151 | 152 | def reset(self, key, params=None): 153 | obs, state = self._env.reset(key, params) 154 | return self.transform_obs(obs), state 155 | 156 | def step(self, key, state, action, params=None): 157 | obs, state, reward, done, info = self._env.step(key, state, action, params) 158 | return self.transform_obs(obs), state, reward, done, info 159 | 160 | class TransformReward(GymnaxWrapper): 161 | def __init__(self, env, transform_reward): 162 | super().__init__(env) 163 | self.transform_reward = transform_reward 164 | 165 | def step(self, key, state, action, params=None): 166 | obs, state, reward, done, info = self._env.step(key, state, action, params) 167 | return obs, state, self.transform_reward(reward), done, info 168 | 169 | 170 | class VecEnv(GymnaxWrapper): 171 | def __init__(self, env): 172 | super().__init__(env) 173 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None)) 174 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) 175 | 176 | @struct.dataclass 177 | class NormalizeVecObsEnvState: 178 | mean: jnp.ndarray 179 | var: jnp.ndarray 180 | count: float 181 | env_state: environment.EnvState 182 | 183 | class NormalizeVecObservation(GymnaxWrapper): 184 | def __init__(self, env): 185 | super().__init__(env) 186 | 187 | def reset(self, key, params=None): 188 | obs, state = self._env.reset(key, params) 189 | state = NormalizeVecObsEnvState( 190 | mean=jnp.zeros_like(obs), 191 | var=jnp.ones_like(obs), 192 | count=1e-4, 193 | env_state=state, 194 | ) 195 | batch_mean = jnp.mean(obs, axis=0) 196 | batch_var = jnp.var(obs, axis=0) 197 | batch_count = obs.shape[0] 198 | 199 | delta = batch_mean - state.mean 200 | tot_count = state.count + batch_count 201 | 202 | new_mean = state.mean + delta * batch_count / tot_count 203 | m_a = state.var * state.count 204 | m_b = batch_var * batch_count 205 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 206 | new_var = M2 / tot_count 207 | new_count = tot_count 208 | 209 | state = NormalizeVecObsEnvState( 210 | mean=new_mean, 211 | var=new_var, 212 | count=new_count, 213 | env_state=state.env_state, 214 | ) 215 | 216 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state 217 | 218 | def step(self, key, state, action, params=None): 219 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params) 220 | 221 | batch_mean = jnp.mean(obs, axis=0) 222 | batch_var = jnp.var(obs, axis=0) 223 | batch_count = obs.shape[0] 224 | 225 | delta 
= batch_mean - state.mean 226 | tot_count = state.count + batch_count 227 | 228 | new_mean = state.mean + delta * batch_count / tot_count 229 | m_a = state.var * state.count 230 | m_b = batch_var * batch_count 231 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 232 | new_var = M2 / tot_count 233 | new_count = tot_count 234 | 235 | state = NormalizeVecObsEnvState( 236 | mean=new_mean, 237 | var=new_var, 238 | count=new_count, 239 | env_state=env_state, 240 | ) 241 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state, reward, done, info 242 | 243 | 244 | @struct.dataclass 245 | class NormalizeVecRewEnvState: 246 | mean: jnp.ndarray 247 | var: jnp.ndarray 248 | count: float 249 | return_val: float 250 | env_state: environment.EnvState 251 | 252 | class NormalizeVecReward(GymnaxWrapper): 253 | 254 | def __init__(self, env, gamma): 255 | super().__init__(env) 256 | self.gamma = gamma 257 | 258 | def reset(self, key, params=None): 259 | obs, state = self._env.reset(key, params) 260 | batch_count = obs.shape[0] 261 | state = NormalizeVecRewEnvState( 262 | mean=0.0, 263 | var=1.0, 264 | count=1e-4, 265 | return_val=jnp.zeros((batch_count,)), 266 | env_state=state, 267 | ) 268 | return obs, state 269 | 270 | def step(self, key, state, action, params=None): 271 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params) 272 | return_val = (state.return_val * self.gamma * (1 - done) + reward) 273 | 274 | batch_mean = jnp.mean(return_val, axis=0) 275 | batch_var = jnp.var(return_val, axis=0) 276 | batch_count = obs.shape[0] 277 | 278 | delta = batch_mean - state.mean 279 | tot_count = state.count + batch_count 280 | 281 | new_mean = state.mean + delta * batch_count / tot_count 282 | m_a = state.var * state.count 283 | m_b = batch_var * batch_count 284 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 285 | new_var = M2 / tot_count 286 | new_count = tot_count 287 | 288 | state = NormalizeVecRewEnvState( 289 | mean=new_mean, 290 | var=new_var, 291 | count=new_count, 292 | return_val=return_val, 293 | env_state=env_state, 294 | ) 295 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info 296 | -------------------------------------------------------------------------------- /purejaxrl/ppo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import numpy as np 5 | import optax 6 | from flax.linen.initializers import constant, orthogonal 7 | from typing import Sequence, NamedTuple, Any 8 | from flax.training.train_state import TrainState 9 | import distrax 10 | import gymnax 11 | from wrappers import LogWrapper, FlattenObservationWrapper 12 | 13 | 14 | class ActorCritic(nn.Module): 15 | action_dim: Sequence[int] 16 | activation: str = "tanh" 17 | 18 | @nn.compact 19 | def __call__(self, x): 20 | if self.activation == "relu": 21 | activation = nn.relu 22 | else: 23 | activation = nn.tanh 24 | actor_mean = nn.Dense( 25 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 26 | )(x) 27 | actor_mean = activation(actor_mean) 28 | actor_mean = nn.Dense( 29 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 30 | )(actor_mean) 31 | actor_mean = activation(actor_mean) 32 | actor_mean = nn.Dense( 33 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 34 | )(actor_mean) 35 | pi = distrax.Categorical(logits=actor_mean) 36 | 37 | critic = nn.Dense( 38 | 64, 
kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 39 | )(x) 40 | critic = activation(critic) 41 | critic = nn.Dense( 42 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 43 | )(critic) 44 | critic = activation(critic) 45 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 46 | critic 47 | ) 48 | 49 | return pi, jnp.squeeze(critic, axis=-1) 50 | 51 | 52 | class Transition(NamedTuple): 53 | done: jnp.ndarray 54 | action: jnp.ndarray 55 | value: jnp.ndarray 56 | reward: jnp.ndarray 57 | log_prob: jnp.ndarray 58 | obs: jnp.ndarray 59 | info: jnp.ndarray 60 | 61 | 62 | def make_train(config): 63 | config["NUM_UPDATES"] = ( 64 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 65 | ) 66 | config["MINIBATCH_SIZE"] = ( 67 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 68 | ) 69 | env, env_params = gymnax.make(config["ENV_NAME"]) 70 | env = FlattenObservationWrapper(env) 71 | env = LogWrapper(env) 72 | 73 | def linear_schedule(count): 74 | frac = ( 75 | 1.0 76 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 77 | / config["NUM_UPDATES"] 78 | ) 79 | return config["LR"] * frac 80 | 81 | def train(rng): 82 | # INIT NETWORK 83 | network = ActorCritic( 84 | env.action_space(env_params).n, activation=config["ACTIVATION"] 85 | ) 86 | rng, _rng = jax.random.split(rng) 87 | init_x = jnp.zeros(env.observation_space(env_params).shape) 88 | network_params = network.init(_rng, init_x) 89 | if config["ANNEAL_LR"]: 90 | tx = optax.chain( 91 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 92 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 93 | ) 94 | else: 95 | tx = optax.chain( 96 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 97 | optax.adam(config["LR"], eps=1e-5), 98 | ) 99 | train_state = TrainState.create( 100 | apply_fn=network.apply, 101 | params=network_params, 102 | tx=tx, 103 | ) 104 | 105 | # INIT ENV 106 | rng, _rng = jax.random.split(rng) 107 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) 108 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params) 109 | 110 | # TRAIN LOOP 111 | def _update_step(runner_state, unused): 112 | # COLLECT TRAJECTORIES 113 | def _env_step(runner_state, unused): 114 | train_state, env_state, last_obs, rng = runner_state 115 | 116 | # SELECT ACTION 117 | rng, _rng = jax.random.split(rng) 118 | pi, value = network.apply(train_state.params, last_obs) 119 | action = pi.sample(seed=_rng) 120 | log_prob = pi.log_prob(action) 121 | 122 | # STEP ENV 123 | rng, _rng = jax.random.split(rng) 124 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 125 | obsv, env_state, reward, done, info = jax.vmap( 126 | env.step, in_axes=(0, 0, 0, None) 127 | )(rng_step, env_state, action, env_params) 128 | transition = Transition( 129 | done, action, value, reward, log_prob, last_obs, info 130 | ) 131 | runner_state = (train_state, env_state, obsv, rng) 132 | return runner_state, transition 133 | 134 | runner_state, traj_batch = jax.lax.scan( 135 | _env_step, runner_state, None, config["NUM_STEPS"] 136 | ) 137 | 138 | # CALCULATE ADVANTAGE 139 | train_state, env_state, last_obs, rng = runner_state 140 | _, last_val = network.apply(train_state.params, last_obs) 141 | 142 | def _calculate_gae(traj_batch, last_val): 143 | def _get_advantages(gae_and_next_value, transition): 144 | gae, next_value = gae_and_next_value 145 | done, value, reward = ( 146 | transition.done, 147 | transition.value, 148 | transition.reward, 149 | ) 
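                    # One step of the GAE recursion, applied in reverse over the
                    # trajectory by the scan below: `next_value` and `gae` carry the
                    # value estimate and advantage of the following timestep, and
                    # the (1 - done) factors zero both terms across episode boundaries.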
150 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value 151 | gae = ( 152 | delta 153 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae 154 | ) 155 | return (gae, value), gae 156 | 157 | _, advantages = jax.lax.scan( 158 | _get_advantages, 159 | (jnp.zeros_like(last_val), last_val), 160 | traj_batch, 161 | reverse=True, 162 | unroll=16, 163 | ) 164 | return advantages, advantages + traj_batch.value 165 | 166 | advantages, targets = _calculate_gae(traj_batch, last_val) 167 | 168 | # UPDATE NETWORK 169 | def _update_epoch(update_state, unused): 170 | def _update_minbatch(train_state, batch_info): 171 | traj_batch, advantages, targets = batch_info 172 | 173 | def _loss_fn(params, traj_batch, gae, targets): 174 | # RERUN NETWORK 175 | pi, value = network.apply(params, traj_batch.obs) 176 | log_prob = pi.log_prob(traj_batch.action) 177 | 178 | # CALCULATE VALUE LOSS 179 | value_pred_clipped = traj_batch.value + ( 180 | value - traj_batch.value 181 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 182 | value_losses = jnp.square(value - targets) 183 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 184 | value_loss = ( 185 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 186 | ) 187 | 188 | # CALCULATE ACTOR LOSS 189 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 190 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 191 | loss_actor1 = ratio * gae 192 | loss_actor2 = ( 193 | jnp.clip( 194 | ratio, 195 | 1.0 - config["CLIP_EPS"], 196 | 1.0 + config["CLIP_EPS"], 197 | ) 198 | * gae 199 | ) 200 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 201 | loss_actor = loss_actor.mean() 202 | entropy = pi.entropy().mean() 203 | 204 | total_loss = ( 205 | loss_actor 206 | + config["VF_COEF"] * value_loss 207 | - config["ENT_COEF"] * entropy 208 | ) 209 | return total_loss, (value_loss, loss_actor, entropy) 210 | 211 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 212 | total_loss, grads = grad_fn( 213 | train_state.params, traj_batch, advantages, targets 214 | ) 215 | train_state = train_state.apply_gradients(grads=grads) 216 | return train_state, total_loss 217 | 218 | train_state, traj_batch, advantages, targets, rng = update_state 219 | rng, _rng = jax.random.split(rng) 220 | # Batching and Shuffling 221 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] 222 | assert ( 223 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] 224 | ), "batch size must be equal to number of steps * number of envs" 225 | permutation = jax.random.permutation(_rng, batch_size) 226 | batch = (traj_batch, advantages, targets) 227 | batch = jax.tree_util.tree_map( 228 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch 229 | ) 230 | shuffled_batch = jax.tree_util.tree_map( 231 | lambda x: jnp.take(x, permutation, axis=0), batch 232 | ) 233 | # Mini-batch Updates 234 | minibatches = jax.tree_util.tree_map( 235 | lambda x: jnp.reshape( 236 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) 237 | ), 238 | shuffled_batch, 239 | ) 240 | train_state, total_loss = jax.lax.scan( 241 | _update_minbatch, train_state, minibatches 242 | ) 243 | update_state = (train_state, traj_batch, advantages, targets, rng) 244 | return update_state, total_loss 245 | # Updating Training State and Metrics: 246 | update_state = (train_state, traj_batch, advantages, targets, rng) 247 | update_state, loss_info = jax.lax.scan( 248 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 249 | ) 250 | train_state = update_state[0] 251 | metric = 
traj_batch.info 252 | rng = update_state[-1] 253 | 254 | # Debugging mode 255 | if config.get("DEBUG"): 256 | def callback(info): 257 | return_values = info["returned_episode_returns"][info["returned_episode"]] 258 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 259 | for t in range(len(timesteps)): 260 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}") 261 | jax.debug.callback(callback, metric) 262 | 263 | runner_state = (train_state, env_state, last_obs, rng) 264 | return runner_state, metric 265 | 266 | rng, _rng = jax.random.split(rng) 267 | runner_state = (train_state, env_state, obsv, _rng) 268 | runner_state, metric = jax.lax.scan( 269 | _update_step, runner_state, None, config["NUM_UPDATES"] 270 | ) 271 | return {"runner_state": runner_state, "metrics": metric} 272 | 273 | return train 274 | 275 | 276 | if __name__ == "__main__": 277 | config = { 278 | "LR": 2.5e-4, 279 | "NUM_ENVS": 4, 280 | "NUM_STEPS": 128, 281 | "TOTAL_TIMESTEPS": 5e5, 282 | "UPDATE_EPOCHS": 4, 283 | "NUM_MINIBATCHES": 4, 284 | "GAMMA": 0.99, 285 | "GAE_LAMBDA": 0.95, 286 | "CLIP_EPS": 0.2, 287 | "ENT_COEF": 0.01, 288 | "VF_COEF": 0.5, 289 | "MAX_GRAD_NORM": 0.5, 290 | "ACTIVATION": "tanh", 291 | "ENV_NAME": "CartPole-v1", 292 | "ANNEAL_LR": True, 293 | "DEBUG": True, 294 | } 295 | rng = jax.random.PRNGKey(30) 296 | train_jit = jax.jit(make_train(config)) 297 | out = train_jit(rng) 298 | -------------------------------------------------------------------------------- /purejaxrl/ppo_continuous_action.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import numpy as np 5 | import optax 6 | from flax.linen.initializers import constant, orthogonal 7 | from typing import Sequence, NamedTuple, Any 8 | from flax.training.train_state import TrainState 9 | import distrax 10 | from wrappers import ( 11 | LogWrapper, 12 | BraxGymnaxWrapper, 13 | VecEnv, 14 | NormalizeVecObservation, 15 | NormalizeVecReward, 16 | ClipAction, 17 | ) 18 | 19 | 20 | class ActorCritic(nn.Module): 21 | action_dim: Sequence[int] 22 | activation: str = "tanh" 23 | 24 | @nn.compact 25 | def __call__(self, x): 26 | if self.activation == "relu": 27 | activation = nn.relu 28 | else: 29 | activation = nn.tanh 30 | actor_mean = nn.Dense( 31 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 32 | )(x) 33 | actor_mean = activation(actor_mean) 34 | actor_mean = nn.Dense( 35 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 36 | )(actor_mean) 37 | actor_mean = activation(actor_mean) 38 | actor_mean = nn.Dense( 39 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 40 | )(actor_mean) 41 | actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) 42 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) 43 | 44 | critic = nn.Dense( 45 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 46 | )(x) 47 | critic = activation(critic) 48 | critic = nn.Dense( 49 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 50 | )(critic) 51 | critic = activation(critic) 52 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 53 | critic 54 | ) 55 | 56 | return pi, jnp.squeeze(critic, axis=-1) 57 | 58 | 59 | class Transition(NamedTuple): 60 | done: jnp.ndarray 61 | action: jnp.ndarray 62 | value: jnp.ndarray 63 | reward: jnp.ndarray 64 | log_prob: jnp.ndarray 65 | 
obs: jnp.ndarray 66 | info: jnp.ndarray 67 | 68 | 69 | def make_train(config): 70 | config["NUM_UPDATES"] = ( 71 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 72 | ) 73 | config["MINIBATCH_SIZE"] = ( 74 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 75 | ) 76 | env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None 77 | env = LogWrapper(env) 78 | env = ClipAction(env) 79 | env = VecEnv(env) 80 | if config["NORMALIZE_ENV"]: 81 | env = NormalizeVecObservation(env) 82 | env = NormalizeVecReward(env, config["GAMMA"]) 83 | 84 | def linear_schedule(count): 85 | frac = ( 86 | 1.0 87 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 88 | / config["NUM_UPDATES"] 89 | ) 90 | return config["LR"] * frac 91 | 92 | def train(rng): 93 | # INIT NETWORK 94 | network = ActorCritic( 95 | env.action_space(env_params).shape[0], activation=config["ACTIVATION"] 96 | ) 97 | rng, _rng = jax.random.split(rng) 98 | init_x = jnp.zeros(env.observation_space(env_params).shape) 99 | network_params = network.init(_rng, init_x) 100 | if config["ANNEAL_LR"]: 101 | tx = optax.chain( 102 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 103 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 104 | ) 105 | else: 106 | tx = optax.chain( 107 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 108 | optax.adam(config["LR"], eps=1e-5), 109 | ) 110 | train_state = TrainState.create( 111 | apply_fn=network.apply, 112 | params=network_params, 113 | tx=tx, 114 | ) 115 | 116 | # INIT ENV 117 | rng, _rng = jax.random.split(rng) 118 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) 119 | obsv, env_state = env.reset(reset_rng, env_params) 120 | 121 | # TRAIN LOOP 122 | def _update_step(runner_state, unused): 123 | # COLLECT TRAJECTORIES 124 | def _env_step(runner_state, unused): 125 | train_state, env_state, last_obs, rng = runner_state 126 | 127 | # SELECT ACTION 128 | rng, _rng = jax.random.split(rng) 129 | pi, value = network.apply(train_state.params, last_obs) 130 | action = pi.sample(seed=_rng) 131 | log_prob = pi.log_prob(action) 132 | 133 | # STEP ENV 134 | rng, _rng = jax.random.split(rng) 135 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 136 | obsv, env_state, reward, done, info = env.step( 137 | rng_step, env_state, action, env_params 138 | ) 139 | transition = Transition( 140 | done, action, value, reward, log_prob, last_obs, info 141 | ) 142 | runner_state = (train_state, env_state, obsv, rng) 143 | return runner_state, transition 144 | 145 | runner_state, traj_batch = jax.lax.scan( 146 | _env_step, runner_state, None, config["NUM_STEPS"] 147 | ) 148 | 149 | # CALCULATE ADVANTAGE 150 | train_state, env_state, last_obs, rng = runner_state 151 | _, last_val = network.apply(train_state.params, last_obs) 152 | 153 | def _calculate_gae(traj_batch, last_val): 154 | def _get_advantages(gae_and_next_value, transition): 155 | gae, next_value = gae_and_next_value 156 | done, value, reward = ( 157 | transition.done, 158 | transition.value, 159 | transition.reward, 160 | ) 161 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value 162 | gae = ( 163 | delta 164 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae 165 | ) 166 | return (gae, value), gae 167 | 168 | _, advantages = jax.lax.scan( 169 | _get_advantages, 170 | (jnp.zeros_like(last_val), last_val), 171 | traj_batch, 172 | reverse=True, 173 | unroll=16, 174 | ) 175 | return advantages, advantages + traj_batch.value 176 | 177 | advantages, targets = 
_calculate_gae(traj_batch, last_val) 178 | 179 | # UPDATE NETWORK 180 | def _update_epoch(update_state, unused): 181 | def _update_minbatch(train_state, batch_info): 182 | traj_batch, advantages, targets = batch_info 183 | 184 | def _loss_fn(params, traj_batch, gae, targets): 185 | # RERUN NETWORK 186 | pi, value = network.apply(params, traj_batch.obs) 187 | log_prob = pi.log_prob(traj_batch.action) 188 | 189 | # CALCULATE VALUE LOSS 190 | value_pred_clipped = traj_batch.value + ( 191 | value - traj_batch.value 192 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 193 | value_losses = jnp.square(value - targets) 194 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 195 | value_loss = ( 196 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 197 | ) 198 | 199 | # CALCULATE ACTOR LOSS 200 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 201 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 202 | loss_actor1 = ratio * gae 203 | loss_actor2 = ( 204 | jnp.clip( 205 | ratio, 206 | 1.0 - config["CLIP_EPS"], 207 | 1.0 + config["CLIP_EPS"], 208 | ) 209 | * gae 210 | ) 211 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 212 | loss_actor = loss_actor.mean() 213 | entropy = pi.entropy().mean() 214 | 215 | total_loss = ( 216 | loss_actor 217 | + config["VF_COEF"] * value_loss 218 | - config["ENT_COEF"] * entropy 219 | ) 220 | return total_loss, (value_loss, loss_actor, entropy) 221 | 222 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 223 | total_loss, grads = grad_fn( 224 | train_state.params, traj_batch, advantages, targets 225 | ) 226 | train_state = train_state.apply_gradients(grads=grads) 227 | return train_state, total_loss 228 | 229 | train_state, traj_batch, advantages, targets, rng = update_state 230 | rng, _rng = jax.random.split(rng) 231 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] 232 | assert ( 233 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] 234 | ), "batch size must be equal to number of steps * number of envs" 235 | permutation = jax.random.permutation(_rng, batch_size) 236 | batch = (traj_batch, advantages, targets) 237 | batch = jax.tree_util.tree_map( 238 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch 239 | ) 240 | shuffled_batch = jax.tree_util.tree_map( 241 | lambda x: jnp.take(x, permutation, axis=0), batch 242 | ) 243 | minibatches = jax.tree_util.tree_map( 244 | lambda x: jnp.reshape( 245 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) 246 | ), 247 | shuffled_batch, 248 | ) 249 | train_state, total_loss = jax.lax.scan( 250 | _update_minbatch, train_state, minibatches 251 | ) 252 | update_state = (train_state, traj_batch, advantages, targets, rng) 253 | return update_state, total_loss 254 | 255 | update_state = (train_state, traj_batch, advantages, targets, rng) 256 | update_state, loss_info = jax.lax.scan( 257 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 258 | ) 259 | train_state = update_state[0] 260 | metric = traj_batch.info 261 | rng = update_state[-1] 262 | if config.get("DEBUG"): 263 | 264 | def callback(info): 265 | return_values = info["returned_episode_returns"][ 266 | info["returned_episode"] 267 | ] 268 | timesteps = ( 269 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 270 | ) 271 | for t in range(len(timesteps)): 272 | print( 273 | f"global step={timesteps[t]}, episodic return={return_values[t]}" 274 | ) 275 | 276 | jax.debug.callback(callback, metric) 277 | 278 | runner_state = (train_state, env_state, last_obs, rng) 279 | return 
runner_state, metric 280 | 281 | rng, _rng = jax.random.split(rng) 282 | runner_state = (train_state, env_state, obsv, _rng) 283 | runner_state, metric = jax.lax.scan( 284 | _update_step, runner_state, None, config["NUM_UPDATES"] 285 | ) 286 | return {"runner_state": runner_state, "metrics": metric} 287 | 288 | return train 289 | 290 | 291 | if __name__ == "__main__": 292 | config = { 293 | "LR": 3e-4, 294 | "NUM_ENVS": 2048, 295 | "NUM_STEPS": 10, 296 | "TOTAL_TIMESTEPS": 5e7, 297 | "UPDATE_EPOCHS": 4, 298 | "NUM_MINIBATCHES": 32, 299 | "GAMMA": 0.99, 300 | "GAE_LAMBDA": 0.95, 301 | "CLIP_EPS": 0.2, 302 | "ENT_COEF": 0.0, 303 | "VF_COEF": 0.5, 304 | "MAX_GRAD_NORM": 0.5, 305 | "ACTIVATION": "tanh", 306 | "ENV_NAME": "hopper", 307 | "ANNEAL_LR": False, 308 | "NORMALIZE_ENV": True, 309 | "DEBUG": True, 310 | } 311 | rng = jax.random.PRNGKey(30) 312 | train_jit = jax.jit(make_train(config)) 313 | out = train_jit(rng) 314 | -------------------------------------------------------------------------------- /purejaxrl/ppo_minigrid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import numpy as np 5 | import optax 6 | from flax.linen.initializers import constant, orthogonal 7 | from typing import Sequence, NamedTuple, Any 8 | from flax.training.train_state import TrainState 9 | import distrax 10 | import gymnax 11 | from wrappers import LogWrapper, FlattenObservationWrapper, NavixGymnaxWrapper 12 | 13 | 14 | class ActorCritic(nn.Module): 15 | action_dim: Sequence[int] 16 | activation: str = "tanh" 17 | 18 | @nn.compact 19 | def __call__(self, x): 20 | if self.activation == "relu": 21 | activation = nn.relu 22 | else: 23 | activation = nn.tanh 24 | actor_mean = nn.Dense( 25 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 26 | )(x) 27 | actor_mean = activation(actor_mean) 28 | actor_mean = nn.Dense( 29 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 30 | )(actor_mean) 31 | actor_mean = activation(actor_mean) 32 | actor_mean = nn.Dense( 33 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 34 | )(actor_mean) 35 | pi = distrax.Categorical(logits=actor_mean) 36 | 37 | critic = nn.Dense( 38 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 39 | )(x) 40 | critic = activation(critic) 41 | critic = nn.Dense( 42 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 43 | )(critic) 44 | critic = activation(critic) 45 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 46 | critic 47 | ) 48 | 49 | return pi, jnp.squeeze(critic, axis=-1) 50 | 51 | 52 | class Transition(NamedTuple): 53 | done: jnp.ndarray 54 | action: jnp.ndarray 55 | value: jnp.ndarray 56 | reward: jnp.ndarray 57 | log_prob: jnp.ndarray 58 | obs: jnp.ndarray 59 | info: jnp.ndarray 60 | 61 | 62 | def make_train(config): 63 | config["NUM_UPDATES"] = ( 64 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 65 | ) 66 | config["MINIBATCH_SIZE"] = ( 67 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 68 | ) 69 | env, env_params = NavixGymnaxWrapper(config["ENV_NAME"]), None 70 | env = FlattenObservationWrapper(env) 71 | env = LogWrapper(env) 72 | 73 | def linear_schedule(count): 74 | frac = ( 75 | 1.0 76 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 77 | / config["NUM_UPDATES"] 78 | ) 79 | return config["LR"] * frac 80 | 81 | def train(rng): 82 | # INIT NETWORK 83 
| network = ActorCritic( 84 | env.action_space(env_params).n, activation=config["ACTIVATION"] 85 | ) 86 | rng, _rng = jax.random.split(rng) 87 | init_x = jnp.zeros(env.observation_space(env_params).shape) 88 | # Initialize network parameters with a dummy observation 89 | network_params = network.init(_rng, init_x) 90 | if config["ANNEAL_LR"]: 91 | tx = optax.chain( 92 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 93 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 94 | ) 95 | else: 96 | tx = optax.chain( 97 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 98 | optax.adam(config["LR"], eps=1e-5), 99 | ) 100 | train_state = TrainState.create( 101 | apply_fn=network.apply, 102 | params=network_params, 103 | tx=tx, 104 | ) 105 | 106 | # INIT ENV 107 | rng, _rng = jax.random.split(rng) 108 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) 109 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params) 110 | 111 | # TRAIN LOOP 112 | def _update_step(runner_state, unused): 113 | # COLLECT TRAJECTORIES 114 | def _env_step(runner_state, unused): 115 | train_state, env_state, last_obs, rng = runner_state 116 | 117 | # SELECT ACTION 118 | rng, _rng = jax.random.split(rng) 119 | pi, value = network.apply(train_state.params, last_obs) 120 | action = pi.sample(seed=_rng) 121 | log_prob = pi.log_prob(action) 122 | 123 | # STEP ENV 124 | rng, _rng = jax.random.split(rng) 125 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 126 | obsv, env_state, reward, done, info = jax.vmap( 127 | env.step, in_axes=(0, 0, 0, None) 128 | )(rng_step, env_state, action, env_params) 129 | transition = Transition( 130 | done, action, value, reward, log_prob, last_obs, info 131 | ) 132 | runner_state = (train_state, env_state, obsv, rng) 133 | return runner_state, transition 134 | 135 | runner_state, traj_batch = jax.lax.scan( 136 | _env_step, runner_state, None, config["NUM_STEPS"] 137 | ) 138 | 139 | # CALCULATE ADVANTAGE 140 | train_state, env_state, last_obs, rng = runner_state 141 | _, last_val = network.apply(train_state.params, last_obs) 142 | 143 | def _calculate_gae(traj_batch, last_val): 144 | def _get_advantages(gae_and_next_value, transition): 145 | gae, next_value = gae_and_next_value 146 | done, value, reward = ( 147 | transition.done, 148 | transition.value, 149 | transition.reward, 150 | ) 151 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value 152 | gae = ( 153 | delta 154 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae 155 | ) 156 | return (gae, value), gae 157 | 158 | _, advantages = jax.lax.scan( 159 | _get_advantages, 160 | (jnp.zeros_like(last_val), last_val), 161 | traj_batch, 162 | reverse=True, 163 | unroll=16, 164 | ) 165 | return advantages, advantages + traj_batch.value 166 | 167 | advantages, targets = _calculate_gae(traj_batch, last_val) 168 | 169 | # UPDATE NETWORK 170 | def _update_epoch(update_state, unused): 171 | def _update_minbatch(train_state, batch_info): 172 | traj_batch, advantages, targets = batch_info 173 | 174 | def _loss_fn(params, traj_batch, gae, targets): 175 | # RERUN NETWORK 176 | pi, value = network.apply(params, traj_batch.obs) 177 | log_prob = pi.log_prob(traj_batch.action) 178 | 179 | # CALCULATE VALUE LOSS 180 | value_pred_clipped = traj_batch.value + ( 181 | value - traj_batch.value 182 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 183 | value_losses = jnp.square(value - targets) 184 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 185 | value_loss = ( 186 | 0.5 * jnp.maximum(value_losses, 
value_losses_clipped).mean() 187 | ) 188 | 189 | # CALCULATE ACTOR LOSS 190 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 191 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 192 | loss_actor1 = ratio * gae 193 | loss_actor2 = ( 194 | jnp.clip( 195 | ratio, 196 | 1.0 - config["CLIP_EPS"], 197 | 1.0 + config["CLIP_EPS"], 198 | ) 199 | * gae 200 | ) 201 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 202 | loss_actor = loss_actor.mean() 203 | entropy = pi.entropy().mean() 204 | 205 | total_loss = ( 206 | loss_actor 207 | + config["VF_COEF"] * value_loss 208 | - config["ENT_COEF"] * entropy 209 | ) 210 | return total_loss, (value_loss, loss_actor, entropy) 211 | 212 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 213 | total_loss, grads = grad_fn( 214 | train_state.params, traj_batch, advantages, targets 215 | ) 216 | train_state = train_state.apply_gradients(grads=grads) 217 | return train_state, total_loss 218 | 219 | train_state, traj_batch, advantages, targets, rng = update_state 220 | rng, _rng = jax.random.split(rng) 221 | # Batching and Shuffling 222 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] 223 | assert ( 224 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] 225 | ), "batch size must be equal to number of steps * number of envs" 226 | permutation = jax.random.permutation(_rng, batch_size) 227 | batch = (traj_batch, advantages, targets) 228 | batch = jax.tree_util.tree_map( 229 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch 230 | ) 231 | shuffled_batch = jax.tree_util.tree_map( 232 | lambda x: jnp.take(x, permutation, axis=0), batch 233 | ) 234 | # Mini-batch Updates 235 | minibatches = jax.tree_util.tree_map( 236 | lambda x: jnp.reshape( 237 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) 238 | ), 239 | shuffled_batch, 240 | ) 241 | train_state, total_loss = jax.lax.scan( 242 | _update_minbatch, train_state, minibatches 243 | ) 244 | update_state = (train_state, traj_batch, advantages, targets, rng) 245 | return update_state, total_loss 246 | # Updating Training State and Metrics: 247 | update_state = (train_state, traj_batch, advantages, targets, rng) 248 | update_state, loss_info = jax.lax.scan( 249 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 250 | ) 251 | train_state = update_state[0] 252 | metric = traj_batch.info 253 | rng = update_state[-1] 254 | 255 | # Debugging mode 256 | if config.get("DEBUG"): 257 | def callback(info): 258 | return_values = info["returned_episode_returns"][info["returned_episode"]] 259 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 260 | for t in range(len(timesteps)): 261 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}") 262 | jax.debug.callback(callback, metric) 263 | 264 | runner_state = (train_state, env_state, last_obs, rng) 265 | return runner_state, metric 266 | 267 | rng, _rng = jax.random.split(rng) 268 | runner_state = (train_state, env_state, obsv, _rng) 269 | runner_state, metric = jax.lax.scan( 270 | _update_step, runner_state, None, config["NUM_UPDATES"] 271 | ) 272 | return {"runner_state": runner_state, "metrics": metric} 273 | 274 | return train 275 | 276 | 277 | if __name__ == "__main__": 278 | config = { 279 | "LR": 2.5e-4, 280 | "NUM_ENVS": 16, 281 | "NUM_STEPS": 128, 282 | "TOTAL_TIMESTEPS": 1e6, 283 | "UPDATE_EPOCHS": 1, 284 | "NUM_MINIBATCHES": 8, 285 | "GAMMA": 0.99, 286 | "GAE_LAMBDA": 0.95, 287 | "CLIP_EPS": 0.2, 288 | "ENT_COEF": 0.01, 289 | "VF_COEF": 0.5, 290 | "MAX_GRAD_NORM": 0.5, 291 | 
"ACTIVATION": "tanh", 292 | "ENV_NAME": "Navix-DoorKey-5x5-v0", 293 | "ANNEAL_LR": True, 294 | "DEBUG": True, 295 | } 296 | rng = jax.random.PRNGKey(30) 297 | train_jit = jax.jit(make_train(config)) 298 | out = train_jit(rng) 299 | -------------------------------------------------------------------------------- /purejaxrl/ppo_rnn.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import numpy as np 5 | import optax 6 | import time 7 | from flax.linen.initializers import constant, orthogonal 8 | from typing import Sequence, NamedTuple, Any, Dict 9 | from flax.training.train_state import TrainState 10 | import distrax 11 | import gymnax 12 | import functools 13 | from gymnax.environments import spaces 14 | from wrappers import FlattenObservationWrapper, LogWrapper 15 | 16 | 17 | class ScannedRNN(nn.Module): 18 | @functools.partial( 19 | nn.scan, 20 | variable_broadcast="params", 21 | in_axes=0, 22 | out_axes=0, 23 | split_rngs={"params": False}, 24 | ) 25 | @nn.compact 26 | def __call__(self, carry, x): 27 | """Applies the module.""" 28 | rnn_state = carry 29 | ins, resets = x 30 | rnn_state = jnp.where( 31 | resets[:, np.newaxis], 32 | self.initialize_carry(ins.shape[0], ins.shape[1]), 33 | rnn_state, 34 | ) 35 | new_rnn_state, y = nn.GRUCell()(rnn_state, ins) 36 | return new_rnn_state, y 37 | 38 | @staticmethod 39 | def initialize_carry(batch_size, hidden_size): 40 | # Use a dummy key since the default state init fn is just zeros. 41 | return nn.GRUCell.initialize_carry( 42 | jax.random.PRNGKey(0), (batch_size,), hidden_size 43 | ) 44 | 45 | 46 | class ActorCriticRNN(nn.Module): 47 | action_dim: Sequence[int] 48 | config: Dict 49 | 50 | @nn.compact 51 | def __call__(self, hidden, x): 52 | obs, dones = x 53 | embedding = nn.Dense( 54 | 128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) 55 | )(obs) 56 | embedding = nn.relu(embedding) 57 | 58 | rnn_in = (embedding, dones) 59 | hidden, embedding = ScannedRNN()(hidden, rnn_in) 60 | 61 | actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))( 62 | embedding 63 | ) 64 | actor_mean = nn.relu(actor_mean) 65 | actor_mean = nn.Dense( 66 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) 67 | )(actor_mean) 68 | 69 | pi = distrax.Categorical(logits=actor_mean) 70 | 71 | critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))( 72 | embedding 73 | ) 74 | critic = nn.relu(critic) 75 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( 76 | critic 77 | ) 78 | 79 | return hidden, pi, jnp.squeeze(critic, axis=-1) 80 | 81 | 82 | class Transition(NamedTuple): 83 | done: jnp.ndarray 84 | action: jnp.ndarray 85 | value: jnp.ndarray 86 | reward: jnp.ndarray 87 | log_prob: jnp.ndarray 88 | obs: jnp.ndarray 89 | info: jnp.ndarray 90 | 91 | 92 | def make_train(config): 93 | config["NUM_UPDATES"] = ( 94 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] 95 | ) 96 | config["MINIBATCH_SIZE"] = ( 97 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] 98 | ) 99 | env, env_params = gymnax.make(config["ENV_NAME"]) 100 | env = FlattenObservationWrapper(env) 101 | env = LogWrapper(env) 102 | 103 | def linear_schedule(count): 104 | frac = ( 105 | 1.0 106 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) 107 | / config["NUM_UPDATES"] 108 | ) 109 | return config["LR"] * frac 110 | 111 | def train(rng): 112 | 
# INIT NETWORK 113 | network = ActorCriticRNN(env.action_space(env_params).n, config=config) 114 | rng, _rng = jax.random.split(rng) 115 | init_x = ( 116 | jnp.zeros( 117 | (1, config["NUM_ENVS"], *env.observation_space(env_params).shape) 118 | ), 119 | jnp.zeros((1, config["NUM_ENVS"])), 120 | ) 121 | init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128) 122 | network_params = network.init(_rng, init_hstate, init_x) 123 | if config["ANNEAL_LR"]: 124 | tx = optax.chain( 125 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 126 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 127 | ) 128 | else: 129 | tx = optax.chain( 130 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), 131 | optax.adam(config["LR"], eps=1e-5), 132 | ) 133 | train_state = TrainState.create( 134 | apply_fn=network.apply, 135 | params=network_params, 136 | tx=tx, 137 | ) 138 | 139 | # INIT ENV 140 | rng, _rng = jax.random.split(rng) 141 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) 142 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params) 143 | init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128) 144 | 145 | # TRAIN LOOP 146 | def _update_step(runner_state, unused): 147 | # COLLECT TRAJECTORIES 148 | def _env_step(runner_state, unused): 149 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state 150 | rng, _rng = jax.random.split(rng) 151 | 152 | # SELECT ACTION 153 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) 154 | hstate, pi, value = network.apply(train_state.params, hstate, ac_in) 155 | action = pi.sample(seed=_rng) 156 | log_prob = pi.log_prob(action) 157 | value, action, log_prob = ( 158 | value.squeeze(0), 159 | action.squeeze(0), 160 | log_prob.squeeze(0), 161 | ) 162 | 163 | # STEP ENV 164 | rng, _rng = jax.random.split(rng) 165 | rng_step = jax.random.split(_rng, config["NUM_ENVS"]) 166 | obsv, env_state, reward, done, info = jax.vmap( 167 | env.step, in_axes=(0, 0, 0, None) 168 | )(rng_step, env_state, action, env_params) 169 | transition = Transition( 170 | last_done, action, value, reward, log_prob, last_obs, info 171 | ) 172 | runner_state = (train_state, env_state, obsv, done, hstate, rng) 173 | return runner_state, transition 174 | 175 | initial_hstate = runner_state[-2] 176 | runner_state, traj_batch = jax.lax.scan( 177 | _env_step, runner_state, None, config["NUM_STEPS"] 178 | ) 179 | 180 | # CALCULATE ADVANTAGE 181 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state 182 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) 183 | _, _, last_val = network.apply(train_state.params, hstate, ac_in) 184 | last_val = last_val.squeeze(0) 185 | def _calculate_gae(traj_batch, last_val, last_done): 186 | def _get_advantages(carry, transition): 187 | gae, next_value, next_done = carry 188 | done, value, reward = transition.done, transition.value, transition.reward 189 | delta = reward + config["GAMMA"] * next_value * (1 - next_done) - value 190 | gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae 191 | return (gae, value, done), gae 192 | _, advantages = jax.lax.scan(_get_advantages, (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16) 193 | return advantages, advantages + traj_batch.value 194 | advantages, targets = _calculate_gae(traj_batch, last_val, last_done) 195 | 196 | # UPDATE NETWORK 197 | def _update_epoch(update_state, unused): 198 | def _update_minbatch(train_state, batch_info): 199 | init_hstate, 
traj_batch, advantages, targets = batch_info 200 | 201 | def _loss_fn(params, init_hstate, traj_batch, gae, targets): 202 | # RERUN NETWORK 203 | _, pi, value = network.apply( 204 | params, init_hstate[0], (traj_batch.obs, traj_batch.done) 205 | ) 206 | log_prob = pi.log_prob(traj_batch.action) 207 | 208 | # CALCULATE VALUE LOSS 209 | value_pred_clipped = traj_batch.value + ( 210 | value - traj_batch.value 211 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) 212 | value_losses = jnp.square(value - targets) 213 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 214 | value_loss = ( 215 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 216 | ) 217 | 218 | # CALCULATE ACTOR LOSS 219 | ratio = jnp.exp(log_prob - traj_batch.log_prob) 220 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 221 | loss_actor1 = ratio * gae 222 | loss_actor2 = ( 223 | jnp.clip( 224 | ratio, 225 | 1.0 - config["CLIP_EPS"], 226 | 1.0 + config["CLIP_EPS"], 227 | ) 228 | * gae 229 | ) 230 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 231 | loss_actor = loss_actor.mean() 232 | entropy = pi.entropy().mean() 233 | 234 | total_loss = ( 235 | loss_actor 236 | + config["VF_COEF"] * value_loss 237 | - config["ENT_COEF"] * entropy 238 | ) 239 | return total_loss, (value_loss, loss_actor, entropy) 240 | 241 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 242 | total_loss, grads = grad_fn( 243 | train_state.params, init_hstate, traj_batch, advantages, targets 244 | ) 245 | train_state = train_state.apply_gradients(grads=grads) 246 | return train_state, total_loss 247 | 248 | ( 249 | train_state, 250 | init_hstate, 251 | traj_batch, 252 | advantages, 253 | targets, 254 | rng, 255 | ) = update_state 256 | 257 | rng, _rng = jax.random.split(rng) 258 | permutation = jax.random.permutation(_rng, config["NUM_ENVS"]) 259 | batch = (init_hstate, traj_batch, advantages, targets) 260 | 261 | shuffled_batch = jax.tree_util.tree_map( 262 | lambda x: jnp.take(x, permutation, axis=1), batch 263 | ) 264 | 265 | minibatches = jax.tree_util.tree_map( 266 | lambda x: jnp.swapaxes( 267 | jnp.reshape( 268 | x, 269 | [x.shape[0], config["NUM_MINIBATCHES"], -1] 270 | + list(x.shape[2:]), 271 | ), 272 | 1, 273 | 0, 274 | ), 275 | shuffled_batch, 276 | ) 277 | 278 | train_state, total_loss = jax.lax.scan( 279 | _update_minbatch, train_state, minibatches 280 | ) 281 | update_state = ( 282 | train_state, 283 | init_hstate, 284 | traj_batch, 285 | advantages, 286 | targets, 287 | rng, 288 | ) 289 | return update_state, total_loss 290 | 291 | init_hstate = initial_hstate[None, :] # TBH 292 | update_state = ( 293 | train_state, 294 | init_hstate, 295 | traj_batch, 296 | advantages, 297 | targets, 298 | rng, 299 | ) 300 | update_state, loss_info = jax.lax.scan( 301 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"] 302 | ) 303 | train_state = update_state[0] 304 | metric = traj_batch.info 305 | rng = update_state[-1] 306 | if config.get("DEBUG"): 307 | 308 | def callback(info): 309 | return_values = info["returned_episode_returns"][ 310 | info["returned_episode"] 311 | ] 312 | timesteps = ( 313 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"] 314 | ) 315 | for t in range(len(timesteps)): 316 | print( 317 | f"global step={timesteps[t]}, episodic return={return_values[t]}" 318 | ) 319 | 320 | jax.debug.callback(callback, metric) 321 | 322 | runner_state = (train_state, env_state, last_obs, last_done, hstate, rng) 323 | return runner_state, metric 324 | 325 | rng, _rng = jax.random.split(rng) 
326 | runner_state = ( 327 | train_state, 328 | env_state, 329 | obsv, 330 | jnp.zeros((config["NUM_ENVS"]), dtype=bool), 331 | init_hstate, 332 | _rng, 333 | ) 334 | runner_state, metric = jax.lax.scan( 335 | _update_step, runner_state, None, config["NUM_UPDATES"] 336 | ) 337 | return {"runner_state": runner_state, "metric": metric} 338 | 339 | return train 340 | 341 | 342 | if __name__ == "__main__": 343 | config = { 344 | "LR": 2.5e-4, 345 | "NUM_ENVS": 4, 346 | "NUM_STEPS": 128, 347 | "TOTAL_TIMESTEPS": 5e5, 348 | "UPDATE_EPOCHS": 4, 349 | "NUM_MINIBATCHES": 4, 350 | "GAMMA": 0.99, 351 | "GAE_LAMBDA": 0.95, 352 | "CLIP_EPS": 0.2, 353 | "ENT_COEF": 0.01, 354 | "VF_COEF": 0.5, 355 | "MAX_GRAD_NORM": 0.5, 356 | "ENV_NAME": "CartPole-v1", 357 | "ANNEAL_LR": True, 358 | "DEBUG": True, 359 | } 360 | 361 | rng = jax.random.PRNGKey(30) 362 | train_jit = jax.jit(make_train(config)) 363 | out = train_jit(rng) 364 | -------------------------------------------------------------------------------- /purejaxrl/wrappers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import chex 4 | import numpy as np 5 | from flax import struct 6 | from functools import partial 7 | from typing import Optional, Tuple, Union, Any 8 | from gymnax.environments import environment, spaces 9 | from brax import envs 10 | from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper 11 | import navix as nx 12 | 13 | 14 | class GymnaxWrapper(object): 15 | """Base class for Gymnax wrappers.""" 16 | 17 | def __init__(self, env): 18 | self._env = env 19 | 20 | # provide proxy access to regular attributes of wrapped object 21 | def __getattr__(self, name): 22 | return getattr(self._env, name) 23 | 24 | 25 | class FlattenObservationWrapper(GymnaxWrapper): 26 | """Flatten the observations of the environment.""" 27 | 28 | def __init__(self, env: environment.Environment): 29 | super().__init__(env) 30 | 31 | def observation_space(self, params) -> spaces.Box: 32 | assert isinstance( 33 | self._env.observation_space(params), spaces.Box 34 | ), "Only Box spaces are supported for now." 
35 | return spaces.Box( 36 | low=self._env.observation_space(params).low, 37 | high=self._env.observation_space(params).high, 38 | shape=(np.prod(self._env.observation_space(params).shape),), 39 | dtype=self._env.observation_space(params).dtype, 40 | ) 41 | 42 | @partial(jax.jit, static_argnums=(0,)) 43 | def reset( 44 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 45 | ) -> Tuple[chex.Array, environment.EnvState]: 46 | obs, state = self._env.reset(key, params) 47 | obs = jnp.reshape(obs, (-1,)) 48 | return obs, state 49 | 50 | @partial(jax.jit, static_argnums=(0,)) 51 | def step( 52 | self, 53 | key: chex.PRNGKey, 54 | state: environment.EnvState, 55 | action: Union[int, float], 56 | params: Optional[environment.EnvParams] = None, 57 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 58 | obs, state, reward, done, info = self._env.step(key, state, action, params) 59 | obs = jnp.reshape(obs, (-1,)) 60 | return obs, state, reward, done, info 61 | 62 | 63 | @struct.dataclass 64 | class LogEnvState: 65 | env_state: environment.EnvState 66 | episode_returns: float 67 | episode_lengths: int 68 | returned_episode_returns: float 69 | returned_episode_lengths: int 70 | timestep: int 71 | 72 | 73 | class LogWrapper(GymnaxWrapper): 74 | """Log the episode returns and lengths.""" 75 | 76 | def __init__(self, env: environment.Environment): 77 | super().__init__(env) 78 | 79 | @partial(jax.jit, static_argnums=(0,)) 80 | def reset( 81 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 82 | ) -> Tuple[chex.Array, environment.EnvState]: 83 | obs, env_state = self._env.reset(key, params) 84 | state = LogEnvState(env_state, 0, 0, 0, 0, 0) 85 | return obs, state 86 | 87 | @partial(jax.jit, static_argnums=(0,)) 88 | def step( 89 | self, 90 | key: chex.PRNGKey, 91 | state: environment.EnvState, 92 | action: Union[int, float], 93 | params: Optional[environment.EnvParams] = None, 94 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 95 | obs, env_state, reward, done, info = self._env.step( 96 | key, state.env_state, action, params 97 | ) 98 | new_episode_return = state.episode_returns + reward 99 | new_episode_length = state.episode_lengths + 1 100 | state = LogEnvState( 101 | env_state=env_state, 102 | episode_returns=new_episode_return * (1 - done), 103 | episode_lengths=new_episode_length * (1 - done), 104 | returned_episode_returns=state.returned_episode_returns * (1 - done) 105 | + new_episode_return * done, 106 | returned_episode_lengths=state.returned_episode_lengths * (1 - done) 107 | + new_episode_length * done, 108 | timestep=state.timestep + 1, 109 | ) 110 | info["returned_episode_returns"] = state.returned_episode_returns 111 | info["returned_episode_lengths"] = state.returned_episode_lengths 112 | info["timestep"] = state.timestep 113 | info["returned_episode"] = done 114 | return obs, state, reward, done, info 115 | 116 | 117 | class BraxGymnaxWrapper: 118 | def __init__(self, env_name, backend="positional"): 119 | env = envs.get_environment(env_name=env_name, backend=backend) 120 | env = EpisodeWrapper(env, episode_length=1000, action_repeat=1) 121 | env = AutoResetWrapper(env) 122 | self._env = env 123 | self.action_size = env.action_size 124 | self.observation_size = (env.observation_size,) 125 | 126 | def reset(self, key, params=None): 127 | state = self._env.reset(key) 128 | return state.obs, state 129 | 130 | def step(self, key, state, action, params=None): 131 | next_state = self._env.step(state, action) 132 | 
return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {} 133 | 134 | def observation_space(self, params): 135 | return spaces.Box( 136 | low=-jnp.inf, 137 | high=jnp.inf, 138 | shape=(self._env.observation_size,), 139 | ) 140 | 141 | def action_space(self, params): 142 | return spaces.Box( 143 | low=-1.0, 144 | high=1.0, 145 | shape=(self._env.action_size,), 146 | ) 147 | 148 | class NavixGymnaxWrapper: 149 | def __init__(self, env_name): 150 | self._env = nx.make(env_name) 151 | 152 | def reset(self, key, params=None): 153 | timestep = self._env.reset(key) 154 | return timestep.observation, timestep 155 | 156 | def step(self, key, state, action, params=None): 157 | timestep = self._env.step(state, action) 158 | return timestep.observation, timestep, timestep.reward, timestep.is_done(), {} 159 | 160 | def observation_space(self, params): 161 | return spaces.Box( 162 | low=self._env.observation_space.minimum, 163 | high=self._env.observation_space.maximum, 164 | shape=(np.prod(self._env.observation_space.shape),), 165 | dtype=self._env.observation_space.dtype, 166 | ) 167 | 168 | def action_space(self, params): 169 | return spaces.Discrete( 170 | num_categories=self._env.action_space.maximum.item() + 1, 171 | ) 172 | 173 | 174 | class ClipAction(GymnaxWrapper): 175 | def __init__(self, env, low=-1.0, high=1.0): 176 | super().__init__(env) 177 | self.low = low 178 | self.high = high 179 | 180 | def step(self, key, state, action, params=None): 181 | """TODO: In theory the below line should be the way to do this.""" 182 | # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high) 183 | action = jnp.clip(action, self.low, self.high) 184 | return self._env.step(key, state, action, params) 185 | 186 | 187 | class TransformObservation(GymnaxWrapper): 188 | def __init__(self, env, transform_obs): 189 | super().__init__(env) 190 | self.transform_obs = transform_obs 191 | 192 | def reset(self, key, params=None): 193 | obs, state = self._env.reset(key, params) 194 | return self.transform_obs(obs), state 195 | 196 | def step(self, key, state, action, params=None): 197 | obs, state, reward, done, info = self._env.step(key, state, action, params) 198 | return self.transform_obs(obs), state, reward, done, info 199 | 200 | 201 | class TransformReward(GymnaxWrapper): 202 | def __init__(self, env, transform_reward): 203 | super().__init__(env) 204 | self.transform_reward = transform_reward 205 | 206 | def step(self, key, state, action, params=None): 207 | obs, state, reward, done, info = self._env.step(key, state, action, params) 208 | return obs, state, self.transform_reward(reward), done, info 209 | 210 | 211 | class VecEnv(GymnaxWrapper): 212 | def __init__(self, env): 213 | super().__init__(env) 214 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None)) 215 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) 216 | 217 | 218 | @struct.dataclass 219 | class NormalizeVecObsEnvState: 220 | mean: jnp.ndarray 221 | var: jnp.ndarray 222 | count: float 223 | env_state: environment.EnvState 224 | 225 | 226 | class NormalizeVecObservation(GymnaxWrapper): 227 | def __init__(self, env): 228 | super().__init__(env) 229 | 230 | def reset(self, key, params=None): 231 | obs, state = self._env.reset(key, params) 232 | state = NormalizeVecObsEnvState( 233 | mean=jnp.zeros_like(obs), 234 | var=jnp.ones_like(obs), 235 | count=1e-4, 236 | env_state=state, 237 | ) 238 | batch_mean = jnp.mean(obs, axis=0) 239 | batch_var = jnp.var(obs, axis=0) 240 | batch_count = 
obs.shape[0] 241 | 242 | delta = batch_mean - state.mean 243 | tot_count = state.count + batch_count 244 | 245 | new_mean = state.mean + delta * batch_count / tot_count 246 | m_a = state.var * state.count 247 | m_b = batch_var * batch_count 248 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 249 | new_var = M2 / tot_count 250 | new_count = tot_count 251 | 252 | state = NormalizeVecObsEnvState( 253 | mean=new_mean, 254 | var=new_var, 255 | count=new_count, 256 | env_state=state.env_state, 257 | ) 258 | 259 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state 260 | 261 | def step(self, key, state, action, params=None): 262 | obs, env_state, reward, done, info = self._env.step( 263 | key, state.env_state, action, params 264 | ) 265 | 266 | batch_mean = jnp.mean(obs, axis=0) 267 | batch_var = jnp.var(obs, axis=0) 268 | batch_count = obs.shape[0] 269 | 270 | delta = batch_mean - state.mean 271 | tot_count = state.count + batch_count 272 | 273 | new_mean = state.mean + delta * batch_count / tot_count 274 | m_a = state.var * state.count 275 | m_b = batch_var * batch_count 276 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 277 | new_var = M2 / tot_count 278 | new_count = tot_count 279 | 280 | state = NormalizeVecObsEnvState( 281 | mean=new_mean, 282 | var=new_var, 283 | count=new_count, 284 | env_state=env_state, 285 | ) 286 | return ( 287 | (obs - state.mean) / jnp.sqrt(state.var + 1e-8), 288 | state, 289 | reward, 290 | done, 291 | info, 292 | ) 293 | 294 | 295 | @struct.dataclass 296 | class NormalizeVecRewEnvState: 297 | mean: jnp.ndarray 298 | var: jnp.ndarray 299 | count: float 300 | return_val: float 301 | env_state: environment.EnvState 302 | 303 | 304 | class NormalizeVecReward(GymnaxWrapper): 305 | def __init__(self, env, gamma): 306 | super().__init__(env) 307 | self.gamma = gamma 308 | 309 | def reset(self, key, params=None): 310 | obs, state = self._env.reset(key, params) 311 | batch_count = obs.shape[0] 312 | state = NormalizeVecRewEnvState( 313 | mean=0.0, 314 | var=1.0, 315 | count=1e-4, 316 | return_val=jnp.zeros((batch_count,)), 317 | env_state=state, 318 | ) 319 | return obs, state 320 | 321 | def step(self, key, state, action, params=None): 322 | obs, env_state, reward, done, info = self._env.step( 323 | key, state.env_state, action, params 324 | ) 325 | return_val = state.return_val * self.gamma * (1 - done) + reward 326 | 327 | batch_mean = jnp.mean(return_val, axis=0) 328 | batch_var = jnp.var(return_val, axis=0) 329 | batch_count = obs.shape[0] 330 | 331 | delta = batch_mean - state.mean 332 | tot_count = state.count + batch_count 333 | 334 | new_mean = state.mean + delta * batch_count / tot_count 335 | m_a = state.var * state.count 336 | m_b = batch_var * batch_count 337 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count 338 | new_var = M2 / tot_count 339 | new_count = tot_count 340 | 341 | state = NormalizeVecRewEnvState( 342 | mean=new_mean, 343 | var=new_var, 344 | count=new_count, 345 | return_val=return_val, 346 | env_state=env_state, 347 | ) 348 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info 349 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax>=0.2.26 2 | jaxlib>=0.1.74 3 | gymnax 4 | evosax 5 | distrax 6 | optax 7 | flax 8 | numpy 9 | brax 10 | wandb 11 | flashbax 12 | navix 
--------------------------------------------------------------------------------