├── .gitignore
├── LICENSE
├── README.md
├── RESOURCES.md
├── docs
│   ├── cartpole_plot_parallel.png
│   ├── cartpole_plot_parallel_old.png
│   ├── cartpole_plot_seconds.png
│   ├── minatar_plot_parallel.png
│   ├── minatar_plot_parallel_old.png
│   └── minatar_plot_seconds.png
├── examples
│   ├── brax_minatar.ipynb
│   └── walkthrough.ipynb
├── purejaxrl
│   ├── dpo_continuous_action.py
│   ├── dqn.py
│   ├── experimental
│   │   └── s5
│   │       ├── README.md
│   │       ├── ppo_s5.py
│   │       ├── s5.py
│   │       └── wrappers.py
│   ├── ppo.py
│   ├── ppo_continuous_action.py
│   ├── ppo_minigrid.py
│   ├── ppo_rnn.py
│   └── wrappers.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[co]
2 | __pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Chris Lu
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PureJaxRL (End-to-End RL Training in Pure Jax)
2 |
3 | [License: Apache 2.0](https://github.com/luchris429/purejaxrl/LICENSE)
4 | [Code style: black](https://github.com/psf/black)
5 | [Open in Colab](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb)
6 |
7 | PureJaxRL is a high-performance, end-to-end JAX Reinforcement Learning (RL) implementation. When running many agents in parallel on GPUs, our implementation is over 1000x faster than standard PyTorch RL implementations. Unlike other JAX RL implementations, we implement the *entire training pipeline in JAX*, including the environment. This allows us to get significant speedups through JIT compilation and by avoiding CPU-GPU data transfer. It also results in easier debugging because the system is fully synchronous. More importantly, this code allows you to use JAX to `jit`, `vmap`, `pmap`, and `scan` entire RL training pipelines. With this, we can:
8 |
9 | - 🏃 Efficiently run tons of seeds in parallel on one GPU
10 | - 💻 Perform rapid hyperparameter tuning
11 | - 🦎 Discover new RL algorithms with meta-evolution
12 |
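
As a minimal sketch of this pattern (using a toy stand-in for the repo's `make_train`; the real PPO version lives in `purejaxrl/ppo.py`), the entire training function can be `vmap`-ed over seeds and JIT-compiled as one program:

```python
import jax
import jax.numpy as jnp

def make_train(config):
    # Toy stand-in for purejaxrl's make_train(config): the real one builds a full PPO loop.
    def train(rng):
        noise = jax.random.normal(rng, (config["steps"],))
        return {"metric": jnp.cumsum(noise)[-1]}  # placeholder "training result"
    return train

config = {"steps": 128}
rngs = jax.random.split(jax.random.PRNGKey(0), 256)  # 256 independent seeds
train_fn = jax.jit(jax.vmap(make_train(config)))     # vectorise + compile the whole pipeline
outs = jax.block_until_ready(train_fn(rngs))         # all 256 runs execute in parallel on-device
```

This mirrors how `purejaxrl/dqn.py` launches `jax.jit(jax.vmap(make_train(config)))` over a batch of seeds in its `main()`.
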
13 | For more details, visit the accompanying blog post: https://chrislu.page/blog/meta-disco/
14 |
15 | This notebook walks through the basic usage: [Open in Colab](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb)
16 |
17 | ## CHECK OUT [RESOURCES.MD](https://github.com/luchris429/purejaxrl/blob/main/RESOURCES.md) to see GitHub repos that are part of the JAX RL Ecosystem!
18 |
19 | ## Performance
20 |
21 | Without vectorization, our implementation runs 10x faster than [CleanRL's PyTorch baselines](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py), as shown in the single-thread performance plot.
22 |
23 | Cartpole | Minatar-Breakout
24 | :-------------------------:|:-------------------------:
25 | ![](docs/cartpole_plot_seconds.png) | ![](docs/minatar_plot_seconds.png)
26 |
27 |
28 | With vectorized training, we can train 2048 PPO agents in half the time it takes to train a single PyTorch PPO agent on a single GPU. The vectorized agent training allows for simultaneous training across multiple seeds, rapid hyperparameter tuning, and even evolutionary Meta-RL.
29 |
30 | Vectorised Cartpole | Vectorised Minatar-Breakout
31 | :-------------------------:|:-------------------------:
32 | ![](docs/cartpole_plot_parallel.png) | ![](docs/minatar_plot_parallel.png)
33 |
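
The hyperparameter sweeps mentioned above follow the same recipe: nest a `vmap` over a hyperparameter axis around the `vmap` over seeds. A minimal sketch, again with a toy stand-in for a full training run (all names here are illustrative):

```python
import jax
import jax.numpy as jnp

def train_with_lr(rng, lr):
    # Toy stand-in for a full training run parameterised by a learning rate.
    return lr * jax.random.uniform(rng)  # placeholder "final score"

lrs = jnp.array([1e-4, 3e-4, 1e-3])
rngs = jax.random.split(jax.random.PRNGKey(0), 8)  # 8 seeds per learning rate
# Inner vmap runs over seeds, outer vmap over learning rates.
sweep = jax.jit(jax.vmap(jax.vmap(train_with_lr, in_axes=(0, None)), in_axes=(None, 0)))
scores = sweep(rngs, lrs)  # shape (3, 8): one score per (learning rate, seed) pair
```
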
34 |
35 | ## Code Philosophy
36 |
37 | PureJaxRL is inspired by [CleanRL](https://github.com/vwxyzjn/cleanrl), providing high-quality single-file implementations with research-friendly features. Like CleanRL, this is not a modular library and is not meant to be imported. The repository focuses on simplicity and clarity in its implementations, making it an excellent resource for researchers and practitioners.
38 |
39 | ## Installation
40 |
41 | Install dependencies using the requirements.txt file:
42 |
43 | ```
44 | pip install -r requirements.txt
45 | ```
46 |
47 | To use JAX on your accelerators (GPU/TPU), see the installation instructions in the [JAX documentation](https://github.com/google/jax#installation).
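
Once JAX is installed for your hardware, a quick sanity check (assuming nothing beyond a working JAX install) is:

```python
import jax

# Lists the devices JAX can see; a CPU-only install shows a single CPU device.
print(jax.devices())
```
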
48 |
49 | ## Example Usage
50 |
51 | [`examples/walkthrough.ipynb`](https://github.com/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb) walks through the basic usage. [Open in Colab](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb)
52 |
53 | [`examples/brax_minatar.ipynb`](https://github.com/luchris429/purejaxrl/blob/main/examples/brax_minatar.ipynb) walks through using PureJaxRL for Brax and MinAtar. [Open in Colab](https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/brax_minatar.ipynb)
54 |
55 | ## Related Work
56 |
57 | Check out the list of [RESOURCES](https://github.com/luchris429/purejaxrl/blob/main/RESOURCES.md) to see libraries that are closely related to PureJaxRL!
58 |
59 | The following repositories and projects were pre-cursors to `purejaxrl`:
60 |
61 | - [Model-Free Opponent Shaping](https://arxiv.org/abs/2205.01447) (ICML 2022) (https://github.com/luchris429/Model-Free-Opponent-Shaping)
62 |
63 | - [Discovered Policy Optimisation](https://arxiv.org/abs/2210.05639) (NeurIPS 2022) (https://github.com/luchris429/discovered-policy-optimisation)
64 |
65 | - [Adversarial Cheap Talk](https://arxiv.org/abs/2211.11030) (ICML 2023) (https://github.com/luchris429/adversarial-cheap-talk)
66 |
67 | ## Citation
68 |
69 | If you use PureJaxRL in your work, please cite the following paper:
70 |
71 | ```
72 | @article{lu2022discovered,
73 | title={Discovered policy optimisation},
74 | author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
75 | journal={Advances in Neural Information Processing Systems},
76 | volume={35},
77 | pages={16455--16468},
78 | year={2022}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/RESOURCES.md:
--------------------------------------------------------------------------------
1 | # PureJaxRL Resources
2 |
3 | Last year, I released [PureJaxRL](https://github.com/luchris429/purejaxrl), a simple repository that implements RL algorithms entirely end-to-end in JAX, which enables speedups of up to 4000x in RL training. PureJaxRL, in turn, was inspired by multiple projects, including [CleanRL](https://github.com/vwxyzjn/cleanrl) and [Gymnax](https://github.com/RobertTLange/gymnax). Since the release of PureJaxRL, a large number of projects related to or inspired by PureJaxRL have come out, vastly expanding its use case from standard single-agent RL settings. This curated list contains those projects alongside other relevant implementations of algorithms, environments, tools, and tutorials.
4 |
5 | To understand more about the benefits of PureJaxRL, I recommend viewing the [original blog post](https://chrislu.page/blog/meta-disco/) or [tweet thread](https://x.com/_chris_lu_/status/1643992216413831171).
6 |
7 | The PureJaxRL repository can be found here:
8 |
9 | [https://github.com/luchris429/purejaxrl/](https://github.com/luchris429/purejaxrl/).
10 |
11 | The format of the list is from [awesome](https://github.com/sindresorhus/awesome) and [awesome-jax](https://github.com/n2cholas/awesome-jax). While this list is curated, it is certainly not complete. If you have a repository you would like to add, please contribute!
12 |
13 | If you find this resource useful, please *star* the repo! It helps establish and grow the end-to-end JAX RL community.
14 |
15 | ## Contents
16 |
17 | - [Algorithms](#algorithms)
18 | - [Environments](#environments)
19 | - [Relevant Tools and Components](#relevant-tools-and-components)
20 | - [Tutorials and Blog Posts](#tutorials-and-blog-posts)
21 | - [Related Papers](#papers)
22 |
23 | ## Algorithms
24 |
25 | ### End-to-End JAX RL Implementations
26 |
27 | - [purejaxrl](https://github.com/luchris429/purejaxrl) - Classic and simple end-to-end RL training in pure JAX.
28 |
29 | - [rejax](https://github.com/keraJLi/rejax) - Modular and importable end-to-end JAX RL training.
30 |
31 | - [Stoix](https://github.com/EdanToledo/Stoix) - End-to-end JAX RL training with advanced logging, configs, and more.
32 |
33 | - [purejaxql](https://github.com/mttga/purejaxql/) - Simple single-file end-to-end JAX baselines for Q-Learning.
34 |
35 | - [jym](https://github.com/rpegoud/jym) - Educational and beginner-friendly end-to-end JAX RL training.
36 |
37 | ### Jax RL (But Not End-to-End) Repos
38 |
39 | - [cleanrl](https://github.com/vwxyzjn/cleanrl) - Clean implementations of RL Algorithms (in both PyTorch and JAX!).
40 |
41 | - [jaxrl](https://github.com/ikostrikov/jaxrl) - JAX implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.
42 |
43 | - [rlbase](https://github.com/kvfrans/rlbase_stable) - Single-file JAX implementations of Deep RL algorithms.
44 |
45 | ### Multi-Agent RL
46 |
47 | - [JaxMARL](https://github.com/FLAIROx/JaxMARL) - Multi-Agent RL Algorithms and Environments in pure JAX.
48 |
49 | - [Mava](https://github.com/instadeepai/Mava) - Multi-Agent RL Algorithms in pure JAX (previously tensorflow-based algorithms).
50 |
51 | - [pax](https://github.com/ucl-dark/pax) - Scalable Opponent Shaping Algorithms in pure JAX.
52 |
53 | ### Offline RL
54 |
55 | - [JAX-CORL](https://github.com/nissymori/JAX-CORL) - Single-file implementations of offline RL algorithms in JAX.
56 |
57 | ### Inverse-RL
58 |
59 | - [jaxirl](https://github.com/FLAIROx/jaxirl) - Pure JAX for Inverse Reinforcement Learning.
60 |
61 | ### Unsupervised Environment Design
62 |
63 | - [minimax](https://github.com/facebookresearch/minimax) - Canonical implementations of UED algorithms in pure JAX, including SSM-based acceleration.
64 |
65 | - [jaxued](https://github.com/DramaCow/jaxued) - Single-file implementations of UED algorithms in pure JAX.
66 |
67 | ### Quality-Diversity
68 |
69 | - [QDax](https://github.com/adaptive-intelligent-robotics/QDax) - Quality-Diversity algorithms in pure JAX.
70 |
71 | ### Partially-Observed RL
72 |
73 | - [popjaxrl](https://github.com/luchris429/popjaxrl) - Partially-observed RL environments (POPGym) and architectures (incl. SSMs) in pure JAX.
74 |
75 | ### Meta-Learning RL Objectives
76 |
77 | - [groove](https://github.com/EmptyJackson/groove) - Library for [LPG-like](https://arxiv.org/abs/2007.08794) meta-RL in Pure JAX.
78 |
79 | - [discovered-policy-optimisation](https://github.com/luchris429/discovered-policy-optimisation) - Library for [LPO](https://arxiv.org/abs/2210.05639) meta-RL in Pure JAX.
80 |
81 | - [rl-learned-optimization](https://github.com/AlexGoldie/rl-learned-optimization) - Library for [OPEN](https://arxiv.org/abs/2407.07082) in Pure JAX.
82 |
83 | ## Environments
84 |
85 | - [gymnax](https://github.com/RobertTLange/gymnax) - Classic RL environments in JAX.
86 |
87 | - [brax](https://github.com/google/brax) - Continuous control environments in JAX.
88 |
89 | - [JaxMARL](https://github.com/FLAIROx/JaxMARL) - Multi-agent algorithms and environments in pure JAX.
90 |
91 | - [jumanji](https://github.com/instadeepai/jumanji) - Suite of unique RL environments in JAX.
92 |
93 | - [pgx](https://github.com/sotetsuk/pgx) - Suite of popular board games in JAX.
94 |
95 | - [popjaxrl](https://github.com/luchris429/popjaxrl) - Partially-observed RL environments (POPGym) in JAX.
96 |
97 | - [waymax](https://github.com/waymo-research/waymax) - Self-driving car simulator in JAX.
98 |
99 | - [Craftax](https://github.com/MichaelTMatthews/Craftax) - A challenging crafter-like and nethack-inspired benchmark in JAX.
100 |
101 | - [xland-minigrid](https://github.com/corl-team/xland-minigrid) - A large-scale meta-RL environment in JAX.
102 |
103 | - [navix](https://github.com/epignatelli/navix) - Classic minigrid environments in JAX.
104 |
105 | - [autoverse](https://github.com/smearle/autoverse) - A fast, evolvable description language for reinforcement learning environments.
106 |
107 | - [qdx](https://github.com/jolle-ag/qdx) - Quantum Error Correction with JAX.
108 |
109 | - [matrax](https://github.com/instadeepai/matrax) - Matrix games in JAX.
110 |
111 | - [AlphaTrade](https://github.com/KangOxford/AlphaTrade) - Limit Order Book (LOB) in JAX.
112 |
113 | ## Relevant Tools and Components
114 |
115 | - [evosax](https://github.com/RobertTLange/evosax) - Evolution strategies in JAX.
116 |
117 | - [evojax](https://github.com/google/evojax) - Evolution strategies in JAX.
118 |
119 | - [flashbax](https://github.com/instadeepai/flashbax) - Accelerated replay buffers in JAX.
120 |
121 | - [dejax](https://github.com/hr0nix/dejax) - Accelerated replay buffers in JAX.
122 |
123 | - [rlax](https://github.com/google-deepmind/rlax) - RL components and building blocks in JAX.
124 |
125 | - [mctx](https://github.com/google-deepmind/mctx) - Monte Carlo tree search in JAX.
126 |
127 | - [distrax](https://github.com/google-deepmind/distrax) - Distributions in JAX.
128 |
129 | - [optax](https://github.com/google-deepmind/optax) - Gradient-based optimizers in JAX.
130 |
131 | - [flax](https://github.com/google/flax) - Neural Networks in JAX.
132 |
133 | ## Tutorials and Blog Posts
134 |
135 | - [Achieving 4000x Speedups with PureJaxRL](https://chrislu.page/blog/meta-disco/) - A blog post on how JAX can massively speedup RL training through vectorisation.
136 |
137 | - [Breaking down State-of-the-Art PPO Implementations in JAX](https://towardsdatascience.com/breaking-down-state-of-the-art-ppo-implementations-in-jax-6f102c06c149) - A blog post explaining PureJaxRL's PPO Implementation in depth.
138 |
139 | - [A Gentle Introduction to Deep Reinforcement Learning in JAX](https://towardsdatascience.com/a-gentle-introduction-to-deep-reinforcement-learning-in-jax-c1e45a179b92) - A JAX tutorial on Deep RL.
140 |
141 | - [Writing an RL Environment in JAX](https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba) - A JAX tutorial on making environments.
142 |
143 | - [Getting started with JAX (MLPs, CNNs & RNNs)](https://roberttlange.com/posts/2020/03/blog-post-10/) - A basic JAX neural network tutorial.
144 |
145 | - [awesome-jax](https://github.com/n2cholas/awesome-jax) - A list of useful libraries in JAX.
146 |
--------------------------------------------------------------------------------
/docs/cartpole_plot_parallel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_parallel.png
--------------------------------------------------------------------------------
/docs/cartpole_plot_parallel_old.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_parallel_old.png
--------------------------------------------------------------------------------
/docs/cartpole_plot_seconds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/cartpole_plot_seconds.png
--------------------------------------------------------------------------------
/docs/minatar_plot_parallel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_parallel.png
--------------------------------------------------------------------------------
/docs/minatar_plot_parallel_old.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_parallel_old.png
--------------------------------------------------------------------------------
/docs/minatar_plot_seconds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luchris429/purejaxrl/31756b197773a52db763fdbe6d635e4b46522a73/docs/minatar_plot_seconds.png
--------------------------------------------------------------------------------
/purejaxrl/dpo_continuous_action.py:
--------------------------------------------------------------------------------
1 | """Re-implementation of Discovered Policy Optimisation (DPO)
2 |
3 | https://arxiv.org/abs/2210.05639
4 |
5 | This differs from PPO in just a few lines of the policy objective.
6 |
7 | Please refer to the paper for more details.
8 | """
9 | import jax
10 | import jax.numpy as jnp
11 | import flax.linen as nn
12 | import numpy as np
13 | import optax
14 | from flax.linen.initializers import constant, orthogonal
15 | from typing import Sequence, NamedTuple, Any
16 | from flax.training.train_state import TrainState
17 | import distrax
18 | from wrappers import (
19 | LogWrapper,
20 | BraxGymnaxWrapper,
21 | VecEnv,
22 | NormalizeVecObservation,
23 | NormalizeVecReward,
24 | ClipAction,
25 | )
26 |
27 |
28 | class ActorCritic(nn.Module):
29 | action_dim: Sequence[int]
30 | activation: str = "tanh"
31 |
32 | @nn.compact
33 | def __call__(self, x):
34 | if self.activation == "relu":
35 | activation = nn.relu
36 | else:
37 | activation = nn.tanh
38 | actor_mean = nn.Dense(
39 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
40 | )(x)
41 | actor_mean = activation(actor_mean)
42 | actor_mean = nn.Dense(
43 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
44 | )(actor_mean)
45 | actor_mean = activation(actor_mean)
46 | actor_mean = nn.Dense(
47 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
48 | )(actor_mean)
49 |         actor_logstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
50 |         pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))
51 |
52 | critic = nn.Dense(
53 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
54 | )(x)
55 | critic = activation(critic)
56 | critic = nn.Dense(
57 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
58 | )(critic)
59 | critic = activation(critic)
60 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
61 | critic
62 | )
63 |
64 | return pi, jnp.squeeze(critic, axis=-1)
65 |
66 |
67 | class Transition(NamedTuple):
68 | done: jnp.ndarray
69 | action: jnp.ndarray
70 | value: jnp.ndarray
71 | reward: jnp.ndarray
72 | log_prob: jnp.ndarray
73 | obs: jnp.ndarray
74 | info: jnp.ndarray
75 |
76 |
77 | def make_train(config):
78 | config["NUM_UPDATES"] = (
79 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
80 | )
81 | config["MINIBATCH_SIZE"] = (
82 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
83 | )
84 | env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
85 | env = LogWrapper(env)
86 | env = ClipAction(env)
87 | env = VecEnv(env)
88 | if config["NORMALIZE_ENV"]:
89 | env = NormalizeVecObservation(env)
90 | env = NormalizeVecReward(env, config["GAMMA"])
91 |
92 | def linear_schedule(count):
93 | frac = (
94 | 1.0
95 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
96 | / config["NUM_UPDATES"]
97 | )
98 | return config["LR"] * frac
99 |
100 | def train(rng):
101 | # INIT NETWORK
102 | network = ActorCritic(
103 | env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
104 | )
105 | rng, _rng = jax.random.split(rng)
106 | init_x = jnp.zeros(env.observation_space(env_params).shape)
107 | network_params = network.init(_rng, init_x)
108 | if config["ANNEAL_LR"]:
109 | tx = optax.chain(
110 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
111 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
112 | )
113 | else:
114 | tx = optax.chain(
115 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
116 | optax.adam(config["LR"], eps=1e-5),
117 | )
118 | train_state = TrainState.create(
119 | apply_fn=network.apply,
120 | params=network_params,
121 | tx=tx,
122 | )
123 |
124 | # INIT ENV
125 | rng, _rng = jax.random.split(rng)
126 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
127 | obsv, env_state = env.reset(reset_rng, env_params)
128 |
129 | # TRAIN LOOP
130 | def _update_step(runner_state, unused):
131 | # COLLECT TRAJECTORIES
132 | def _env_step(runner_state, unused):
133 | train_state, env_state, last_obs, rng = runner_state
134 |
135 | # SELECT ACTION
136 | rng, _rng = jax.random.split(rng)
137 | pi, value = network.apply(train_state.params, last_obs)
138 | action = pi.sample(seed=_rng)
139 | log_prob = pi.log_prob(action)
140 |
141 | # STEP ENV
142 | rng, _rng = jax.random.split(rng)
143 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
144 | obsv, env_state, reward, done, info = env.step(
145 | rng_step, env_state, action, env_params
146 | )
147 | transition = Transition(
148 | done, action, value, reward, log_prob, last_obs, info
149 | )
150 | runner_state = (train_state, env_state, obsv, rng)
151 | return runner_state, transition
152 |
153 | runner_state, traj_batch = jax.lax.scan(
154 | _env_step, runner_state, None, config["NUM_STEPS"]
155 | )
156 |
157 | # CALCULATE ADVANTAGE
158 | train_state, env_state, last_obs, rng = runner_state
159 | _, last_val = network.apply(train_state.params, last_obs)
160 |
161 | def _calculate_gae(traj_batch, last_val):
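                # Generalised Advantage Estimation via a reversed scan:
                #   A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1},
                #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t).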
162 | def _get_advantages(gae_and_next_value, transition):
163 | gae, next_value = gae_and_next_value
164 | done, value, reward = (
165 | transition.done,
166 | transition.value,
167 | transition.reward,
168 | )
169 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value
170 | gae = (
171 | delta
172 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
173 | )
174 | return (gae, value), gae
175 |
176 | _, advantages = jax.lax.scan(
177 | _get_advantages,
178 | (jnp.zeros_like(last_val), last_val),
179 | traj_batch,
180 | reverse=True,
181 | unroll=16,
182 | )
183 | return advantages, advantages + traj_batch.value
184 |
185 | advantages, targets = _calculate_gae(traj_batch, last_val)
186 |
187 | # UPDATE NETWORK
188 | def _update_epoch(update_state, unused):
189 | def _update_minbatch(train_state, batch_info):
190 | traj_batch, advantages, targets = batch_info
191 |
192 | def _loss_fn(params, traj_batch, gae, targets):
193 | # RERUN NETWORK
194 | pi, value = network.apply(params, traj_batch.obs)
195 | log_prob = pi.log_prob(traj_batch.action)
196 |
197 | # CALCULATE VALUE LOSS
198 | value_pred_clipped = traj_batch.value + (
199 | value - traj_batch.value
200 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
201 | value_losses = jnp.square(value - targets)
202 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
203 | value_loss = (
204 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
205 | )
206 |
207 | # CALCULATE ACTOR LOSS
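                        # DPO (https://arxiv.org/abs/2210.05639) swaps PPO's clipped surrogate for a
                        # "drift" penalty: for positive advantages, drift1 penalises the ratio rising
                        # too far above 1 (scale alpha); for negative advantages, drift2 penalises the
                        # log-ratio falling too far below 0 (scale beta).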
208 | alpha = config["DPO_ALPHA"]
209 | beta = config["DPO_BETA"]
210 | log_diff = log_prob - traj_batch.log_prob
211 | ratio = jnp.exp(log_diff)
212 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
213 | is_pos = (gae >= 0.0).astype("float32")
214 | r1 = ratio - 1.0
215 | drift1 = nn.relu(r1 * gae - alpha * nn.tanh(r1 * gae / alpha))
216 | drift2 = nn.relu(
217 | log_diff * gae - beta * nn.tanh(log_diff * gae / beta)
218 | )
219 | drift = drift1 * is_pos + drift2 * (1 - is_pos)
220 | loss_actor = -(ratio * gae - drift).mean()
221 | entropy = pi.entropy().mean()
222 |
223 | total_loss = (
224 | loss_actor
225 | + config["VF_COEF"] * value_loss
226 | - config["ENT_COEF"] * entropy
227 | )
228 | return total_loss, (value_loss, loss_actor, entropy)
229 |
230 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
231 | total_loss, grads = grad_fn(
232 | train_state.params, traj_batch, advantages, targets
233 | )
234 | train_state = train_state.apply_gradients(grads=grads)
235 | return train_state, total_loss
236 |
237 | train_state, traj_batch, advantages, targets, rng = update_state
238 | rng, _rng = jax.random.split(rng)
239 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
240 | assert (
241 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
242 | ), "batch size must be equal to number of steps * number of envs"
243 | permutation = jax.random.permutation(_rng, batch_size)
244 | batch = (traj_batch, advantages, targets)
245 | batch = jax.tree_util.tree_map(
246 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
247 | )
248 | shuffled_batch = jax.tree_util.tree_map(
249 | lambda x: jnp.take(x, permutation, axis=0), batch
250 | )
251 | minibatches = jax.tree_util.tree_map(
252 | lambda x: jnp.reshape(
253 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
254 | ),
255 | shuffled_batch,
256 | )
257 | train_state, total_loss = jax.lax.scan(
258 | _update_minbatch, train_state, minibatches
259 | )
260 | update_state = (train_state, traj_batch, advantages, targets, rng)
261 | return update_state, total_loss
262 |
263 | update_state = (train_state, traj_batch, advantages, targets, rng)
264 | update_state, loss_info = jax.lax.scan(
265 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
266 | )
267 | train_state = update_state[0]
268 | metric = traj_batch.info
269 | rng = update_state[-1]
270 | if config.get("DEBUG"):
271 |
272 | def callback(info):
273 | return_values = info["returned_episode_returns"][
274 | info["returned_episode"]
275 | ]
276 | timesteps = (
277 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
278 | )
279 | for t in range(len(timesteps)):
280 | print(
281 | f"global step={timesteps[t]}, episodic return={return_values[t]}"
282 | )
283 |
284 | jax.debug.callback(callback, metric)
285 |
286 | runner_state = (train_state, env_state, last_obs, rng)
287 | return runner_state, metric
288 |
289 | rng, _rng = jax.random.split(rng)
290 | runner_state = (train_state, env_state, obsv, _rng)
291 | runner_state, metric = jax.lax.scan(
292 | _update_step, runner_state, None, config["NUM_UPDATES"]
293 | )
294 | return {"runner_state": runner_state, "metrics": metric}
295 |
296 | return train
297 |
298 |
299 | if __name__ == "__main__":
300 | config = {
301 | "LR": 3e-4,
302 | "NUM_ENVS": 2048,
303 | "NUM_STEPS": 10,
304 | "TOTAL_TIMESTEPS": 5e7,
305 | "UPDATE_EPOCHS": 4,
306 | "NUM_MINIBATCHES": 32,
307 | "GAMMA": 0.99,
308 | "GAE_LAMBDA": 0.95,
309 | "CLIP_EPS": 0.2,
310 | "DPO_ALPHA": 2.0,
311 | "DPO_BETA": 0.6,
312 | "ENT_COEF": 0.0,
313 | "VF_COEF": 0.5,
314 | "MAX_GRAD_NORM": 0.5,
315 | "ACTIVATION": "tanh",
316 | "ENV_NAME": "hopper",
317 | "ANNEAL_LR": False,
318 | "NORMALIZE_ENV": True,
319 | "DEBUG": True,
320 | }
321 | rng = jax.random.PRNGKey(30)
322 | train_jit = jax.jit(make_train(config))
323 | out = train_jit(rng)
324 |
--------------------------------------------------------------------------------
/purejaxrl/dqn.py:
--------------------------------------------------------------------------------
1 | """
2 | PureJaxRL version of CleanRL's DQN: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py
3 | """
4 | import os
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | import chex
9 | import flax
10 | import wandb
11 | import optax
12 | import flax.linen as nn
13 | from flax.training.train_state import TrainState
14 | from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper
15 | import gymnax
16 | import flashbax as fbx
17 |
18 |
19 | class QNetwork(nn.Module):
20 | action_dim: int
21 |
22 | @nn.compact
23 | def __call__(self, x: jnp.ndarray):
24 | x = nn.Dense(120)(x)
25 | x = nn.relu(x)
26 | x = nn.Dense(84)(x)
27 | x = nn.relu(x)
28 | x = nn.Dense(self.action_dim)(x)
29 | return x
30 |
31 |
32 | @chex.dataclass(frozen=True)
33 | class TimeStep:
34 | obs: chex.Array
35 | action: chex.Array
36 | reward: chex.Array
37 | done: chex.Array
38 |
39 |
40 | class CustomTrainState(TrainState):
41 | target_network_params: flax.core.FrozenDict
42 | timesteps: int
43 | n_updates: int
44 |
45 |
46 | def make_train(config):
47 |
48 | config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"]
49 |
50 | basic_env, env_params = gymnax.make(config["ENV_NAME"])
51 | env = FlattenObservationWrapper(basic_env)
52 | env = LogWrapper(env)
53 |
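    # Vectorise reset/step over n_envs environments: split the rng into one key per
    # environment and vmap, sharing env_params across the batch.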
54 | vmap_reset = lambda n_envs: lambda rng: jax.vmap(env.reset, in_axes=(0, None))(
55 | jax.random.split(rng, n_envs), env_params
56 | )
57 | vmap_step = lambda n_envs: lambda rng, env_state, action: jax.vmap(
58 | env.step, in_axes=(0, 0, 0, None)
59 | )(jax.random.split(rng, n_envs), env_state, action, env_params)
60 |
61 | def train(rng):
62 |
63 | # INIT ENV
64 | rng, _rng = jax.random.split(rng)
65 | init_obs, env_state = vmap_reset(config["NUM_ENVS"])(_rng)
66 |
67 | # INIT BUFFER
68 | buffer = fbx.make_flat_buffer(
69 | max_length=config["BUFFER_SIZE"],
70 | min_length=config["BUFFER_BATCH_SIZE"],
71 | sample_batch_size=config["BUFFER_BATCH_SIZE"],
72 | add_sequences=False,
73 | add_batch_size=config["NUM_ENVS"],
74 | )
75 | buffer = buffer.replace(
76 | init=jax.jit(buffer.init),
77 | add=jax.jit(buffer.add, donate_argnums=0),
78 | sample=jax.jit(buffer.sample),
79 | can_sample=jax.jit(buffer.can_sample),
80 | )
81 |         dummy_rng = jax.random.PRNGKey(0)  # dummy rng used only to shape the buffer contents
82 |         _action = basic_env.action_space().sample(dummy_rng)
83 |         _, _env_state = env.reset(dummy_rng, env_params)
84 |         _obs, _, _reward, _done, _ = env.step(dummy_rng, _env_state, _action, env_params)
85 | _timestep = TimeStep(obs=_obs, action=_action, reward=_reward, done=_done)
86 | buffer_state = buffer.init(_timestep)
87 |
88 | # INIT NETWORK AND OPTIMIZER
89 | network = QNetwork(action_dim=env.action_space(env_params).n)
90 | rng, _rng = jax.random.split(rng)
91 | init_x = jnp.zeros(env.observation_space(env_params).shape)
92 | network_params = network.init(_rng, init_x)
93 |
94 | def linear_schedule(count):
95 | frac = 1.0 - (count / config["NUM_UPDATES"])
96 | return config["LR"] * frac
97 |
98 | lr = linear_schedule if config.get("LR_LINEAR_DECAY", False) else config["LR"]
99 | tx = optax.adam(learning_rate=lr)
100 |
101 | train_state = CustomTrainState.create(
102 | apply_fn=network.apply,
103 | params=network_params,
104 |             target_network_params=jax.tree_util.tree_map(lambda x: jnp.copy(x), network_params),
105 | tx=tx,
106 | timesteps=0,
107 | n_updates=0,
108 | )
109 |
110 | # epsilon-greedy exploration
111 | def eps_greedy_exploration(rng, q_vals, t):
112 | rng_a, rng_e = jax.random.split(
113 | rng, 2
114 | ) # a key for sampling random actions and one for picking
115 | eps = jnp.clip( # get epsilon
116 | (
117 | (config["EPSILON_FINISH"] - config["EPSILON_START"])
118 | / config["EPSILON_ANNEAL_TIME"]
119 | )
120 | * t
121 | + config["EPSILON_START"],
122 | config["EPSILON_FINISH"],
123 | )
124 | greedy_actions = jnp.argmax(q_vals, axis=-1) # get the greedy actions
125 |             chosen_actions = jnp.where(
126 | jax.random.uniform(rng_e, greedy_actions.shape)
127 | < eps, # pick the actions that should be random
128 | jax.random.randint(
129 | rng_a, shape=greedy_actions.shape, minval=0, maxval=q_vals.shape[-1]
130 | ), # sample random actions,
131 | greedy_actions,
132 | )
133 |             return chosen_actions
134 |
135 | # TRAINING LOOP
136 | def _update_step(runner_state, unused):
137 |
138 | train_state, buffer_state, env_state, last_obs, rng = runner_state
139 |
140 | # STEP THE ENV
141 | rng, rng_a, rng_s = jax.random.split(rng, 3)
142 | q_vals = network.apply(train_state.params, last_obs)
143 | action = eps_greedy_exploration(
144 | rng_a, q_vals, train_state.timesteps
145 | ) # explore with epsilon greedy_exploration
146 | obs, env_state, reward, done, info = vmap_step(config["NUM_ENVS"])(
147 | rng_s, env_state, action
148 | )
149 | train_state = train_state.replace(
150 | timesteps=train_state.timesteps + config["NUM_ENVS"]
151 | ) # update timesteps count
152 |
153 | # BUFFER UPDATE
154 | timestep = TimeStep(obs=last_obs, action=action, reward=reward, done=done)
155 | buffer_state = buffer.add(buffer_state, timestep)
156 |
157 | # NETWORKS UPDATE
158 | def _learn_phase(train_state, rng):
159 |
160 | learn_batch = buffer.sample(buffer_state, rng).experience
161 |
162 | q_next_target = network.apply(
163 | train_state.target_network_params, learn_batch.second.obs
164 | ) # (batch_size, num_actions)
165 | q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,)
166 | target = (
167 | learn_batch.first.reward
168 | + (1 - learn_batch.first.done) * config["GAMMA"] * q_next_target
169 | )
170 |
171 | def _loss_fn(params):
172 | q_vals = network.apply(
173 | params, learn_batch.first.obs
174 | ) # (batch_size, num_actions)
175 | chosen_action_qvals = jnp.take_along_axis(
176 | q_vals,
177 | jnp.expand_dims(learn_batch.first.action, axis=-1),
178 | axis=-1,
179 | ).squeeze(axis=-1)
180 | return jnp.mean((chosen_action_qvals - target) ** 2)
181 |
182 | loss, grads = jax.value_and_grad(_loss_fn)(train_state.params)
183 | train_state = train_state.apply_gradients(grads=grads)
184 | train_state = train_state.replace(n_updates=train_state.n_updates + 1)
185 | return train_state, loss
186 |
187 | rng, _rng = jax.random.split(rng)
188 | is_learn_time = (
189 | (buffer.can_sample(buffer_state))
190 | & ( # enough experience in buffer
191 | train_state.timesteps > config["LEARNING_STARTS"]
192 | )
193 | & ( # pure exploration phase ended
194 | train_state.timesteps % config["TRAINING_INTERVAL"] == 0
195 | ) # training interval
196 | )
197 | train_state, loss = jax.lax.cond(
198 | is_learn_time,
199 | lambda train_state, rng: _learn_phase(train_state, rng),
200 | lambda train_state, rng: (train_state, jnp.array(0.0)), # do nothing
201 | train_state,
202 | _rng,
203 | )
204 |
205 | # update target network
206 | train_state = jax.lax.cond(
207 | train_state.timesteps % config["TARGET_UPDATE_INTERVAL"] == 0,
208 | lambda train_state: train_state.replace(
209 | target_network_params=optax.incremental_update(
210 | train_state.params,
211 | train_state.target_network_params,
212 | config["TAU"],
213 | )
214 | ),
215 | lambda train_state: train_state,
216 | operand=train_state,
217 | )
218 |
219 | metrics = {
220 | "timesteps": train_state.timesteps,
221 | "updates": train_state.n_updates,
222 | "loss": loss.mean(),
223 | "returns": info["returned_episode_returns"].mean(),
224 | }
225 |
226 | # report on wandb if required
227 | if config.get("WANDB_MODE", "disabled") == "online":
228 |
229 | def callback(metrics):
230 | if metrics["timesteps"] % 100 == 0:
231 | wandb.log(metrics)
232 |
233 | jax.debug.callback(callback, metrics)
234 |
235 | runner_state = (train_state, buffer_state, env_state, obs, rng)
236 |
237 | return runner_state, metrics
238 |
239 | # train
240 | rng, _rng = jax.random.split(rng)
241 | runner_state = (train_state, buffer_state, env_state, init_obs, _rng)
242 |
243 | runner_state, metrics = jax.lax.scan(
244 | _update_step, runner_state, None, config["NUM_UPDATES"]
245 | )
246 | return {"runner_state": runner_state, "metrics": metrics}
247 |
248 | return train
249 |
250 |
251 | def main():
252 |
253 | config = {
254 | "NUM_ENVS": 10,
255 | "BUFFER_SIZE": 10000,
256 | "BUFFER_BATCH_SIZE": 128,
257 | "TOTAL_TIMESTEPS": 5e5,
258 | "EPSILON_START": 1.0,
259 | "EPSILON_FINISH": 0.05,
260 | "EPSILON_ANNEAL_TIME": 25e4,
261 | "TARGET_UPDATE_INTERVAL": 500,
262 | "LR": 2.5e-4,
263 | "LEARNING_STARTS": 10000,
264 | "TRAINING_INTERVAL": 10,
265 | "LR_LINEAR_DECAY": False,
266 | "GAMMA": 0.99,
267 | "TAU": 1.0,
268 | "ENV_NAME": "CartPole-v1",
269 | "SEED": 0,
270 | "NUM_SEEDS": 1,
271 | "WANDB_MODE": "disabled", # set to online to activate wandb
272 | "ENTITY": "",
273 | "PROJECT": "",
274 | }
275 |
276 | wandb.init(
277 | entity=config["ENTITY"],
278 | project=config["PROJECT"],
279 | tags=["DQN", config["ENV_NAME"].upper(), f"jax_{jax.__version__}"],
280 | name=f'purejaxrl_dqn_{config["ENV_NAME"]}',
281 | config=config,
282 | mode=config["WANDB_MODE"],
283 | )
284 |
285 | rng = jax.random.PRNGKey(config["SEED"])
286 | rngs = jax.random.split(rng, config["NUM_SEEDS"])
287 | train_vjit = jax.jit(jax.vmap(make_train(config)))
288 | outs = jax.block_until_ready(train_vjit(rngs))
289 |
290 |
291 | if __name__ == "__main__":
292 | main()
293 |
--------------------------------------------------------------------------------
/purejaxrl/experimental/s5/README.md:
--------------------------------------------------------------------------------
1 | # PPO S5
2 |
3 | This is a re-implementation of the architecture from [this paper](https://arxiv.org/abs/2303.03982).
4 |
5 | This is currently a work-in-progress since the code in its current state needs to be cleaned.
6 |
7 | If you use this code in an academic paper, please cite:
8 |
9 |
10 | ```
11 | @article{lu2023structured,
12 | title={Structured State Space Models for In-Context Reinforcement Learning},
13 | author={Lu, Chris and Schroecker, Yannick and Gu, Albert and Parisotto, Emilio and Foerster, Jakob and Singh, Satinder and Behbahani, Feryal},
14 | journal={arXiv preprint arXiv:2303.03982},
15 | year={2023}
16 | }
17 |
18 | @article{smith2022simplified,
19 | title={Simplified state space layers for sequence modeling},
20 | author={Smith, Jimmy TH and Warrington, Andrew and Linderman, Scott W},
21 | journal={arXiv preprint arXiv:2208.04933},
22 | year={2022}
23 | }
24 |
25 | @article{lu2022discovered,
26 | title={Discovered policy optimisation},
27 | author={Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
28 | journal={Advances in Neural Information Processing Systems},
29 | volume={35},
30 | pages={16455--16468},
31 | year={2022}
32 | }
33 | ```
--------------------------------------------------------------------------------
/purejaxrl/experimental/s5/ppo_s5.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 | import numpy as np
5 | import optax
6 | from flax.linen.initializers import constant, orthogonal
7 | from typing import Sequence, NamedTuple, Any, Dict
8 | from flax.training.train_state import TrainState
9 | import distrax
10 | import gymnax
11 | from wrappers import FlattenObservationWrapper, LogWrapper
12 | from gymnax.environments import spaces
13 | from s5 import init_S5SSM, make_DPLR_HiPPO, StackedEncoderModel
14 |
15 | d_model = 256
16 | ssm_size = 256
17 | C_init = "lecun_normal"
18 | discretization = "zoh"
19 | dt_min = 0.001
20 | dt_max = 0.1
21 | n_layers = 4
22 | conj_sym = True
23 | clip_eigs = False
24 | bidirectional = False
25 |
26 | blocks = 1
27 | block_size = int(ssm_size / blocks)
28 |
29 | Lambda, _, B, V, B_orig = make_DPLR_HiPPO(ssm_size)
30 |
31 | block_size = block_size // 2
32 | ssm_size = ssm_size // 2
33 |
34 | Lambda = Lambda[:block_size]
35 | V = V[:, :block_size]
36 |
37 | Vinv = V.conj().T
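# With conj_sym=True the SSM eigenvalues come in conjugate pairs, so only half of the
# HiPPO-initialised spectrum (and the matching columns of V) needs to be kept explicitly.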
38 |
39 |
40 | ssm_init_fn = init_S5SSM(H=d_model,
41 | P=ssm_size,
42 | Lambda_re_init=Lambda.real,
43 | Lambda_im_init=Lambda.imag,
44 | V=V,
45 | Vinv=Vinv,
46 | C_init=C_init,
47 | discretization=discretization,
48 | dt_min=dt_min,
49 | dt_max=dt_max,
50 | conj_sym=conj_sym,
51 | clip_eigs=clip_eigs,
52 | bidirectional=bidirectional)
53 |
54 | class ActorCriticS5(nn.Module):
55 | action_dim: Sequence[int]
56 | config: Dict
57 |
58 | def setup(self):
59 | self.encoder_0 = nn.Dense(128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
60 | self.encoder_1 = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
61 |
62 | self.action_body_0 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))
63 | self.action_body_1 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))
64 | self.action_decoder = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))
65 |
66 | self.value_body_0 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))
67 | self.value_body_1 = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))
68 | self.value_decoder = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))
69 |
70 | self.s5 = StackedEncoderModel(
71 | ssm=ssm_init_fn,
72 | d_model=d_model,
73 | n_layers=n_layers,
74 | activation="half_glu1",
75 | )
76 |
77 | def __call__(self, hidden, x):
78 | obs, dones = x
79 | embedding = self.encoder_0(obs)
80 | embedding = nn.leaky_relu(embedding)
81 | embedding = self.encoder_1(embedding)
82 | embedding = nn.leaky_relu(embedding)
83 |
84 | hidden, embedding = self.s5(hidden, embedding, dones)
85 |
86 | actor_mean = self.action_body_0(embedding)
87 | actor_mean = nn.leaky_relu(actor_mean)
88 | actor_mean = self.action_body_1(actor_mean)
89 | actor_mean = nn.leaky_relu(actor_mean)
90 | actor_mean = self.action_decoder(actor_mean)
91 |
92 | pi = distrax.Categorical(logits=actor_mean)
93 |
94 | critic = self.value_body_0(embedding)
95 | critic = nn.leaky_relu(critic)
96 | critic = self.value_body_1(critic)
97 | critic = nn.leaky_relu(critic)
98 | critic = self.value_decoder(critic)
99 |
100 | return hidden, pi, jnp.squeeze(critic, axis=-1)
101 |
102 | class Transition(NamedTuple):
103 | done: jnp.ndarray
104 | action: jnp.ndarray
105 | value: jnp.ndarray
106 | reward: jnp.ndarray
107 | log_prob: jnp.ndarray
108 | obs: jnp.ndarray
109 | info: jnp.ndarray
110 |
111 |
112 | def make_train(config):
113 | config["NUM_UPDATES"] = (
114 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
115 | )
116 | config["MINIBATCH_SIZE"] = (
117 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
118 | )
119 | env, env_params = gymnax.make(config["ENV_NAME"])
120 | env = FlattenObservationWrapper(env)
121 | env = LogWrapper(env)
122 |
123 | def linear_schedule(count):
124 | frac = (
125 | 1.0
126 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
127 | / config["NUM_UPDATES"]
128 | )
129 | return config["LR"] * frac
130 |
131 | def train(rng):
132 | # INIT NETWORK
133 | network = ActorCriticS5(env.action_space(env_params).n, config=config)
134 | rng, _rng = jax.random.split(rng)
135 | init_x = (
136 | jnp.zeros(
137 | (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
138 | ),
139 | jnp.zeros((1, config["NUM_ENVS"])),
140 | )
141 | init_hstate = StackedEncoderModel.initialize_carry(config["NUM_ENVS"], ssm_size, n_layers)
142 | network_params = network.init(_rng, init_hstate, init_x)
143 | if config["ANNEAL_LR"]:
144 | tx = optax.chain(
145 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
146 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
147 | )
148 | else:
149 | tx = optax.chain(
150 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
151 | optax.adam(config["LR"], eps=1e-5),
152 | )
153 | train_state = TrainState.create(
154 | apply_fn=network.apply,
155 | params=network_params,
156 | tx=tx,
157 | )
158 |
159 | # INIT ENV
160 | rng, _rng = jax.random.split(rng)
161 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
162 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
163 | init_hstate = StackedEncoderModel.initialize_carry(config["NUM_ENVS"], ssm_size, n_layers)
164 |
165 | # TRAIN LOOP
166 | def _update_step(runner_state, unused):
167 | # COLLECT TRAJECTORIES
168 | def _env_step(runner_state, unused):
169 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state
170 | rng, _rng = jax.random.split(rng)
171 |
172 | # SELECT ACTION
173 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
174 | hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
175 | action = pi.sample(seed=_rng)
176 | log_prob = pi.log_prob(action)
177 | value, action, log_prob = (
178 | value.squeeze(0),
179 | action.squeeze(0),
180 | log_prob.squeeze(0),
181 | )
182 |
183 | # STEP ENV
184 | rng, _rng = jax.random.split(rng)
185 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
186 | obsv, env_state, reward, done, info = jax.vmap(
187 | env.step, in_axes=(0, 0, 0, None)
188 | )(rng_step, env_state, action, env_params)
189 | transition = Transition(
190 | last_done, action, value, reward, log_prob, last_obs, info
191 | )
192 | runner_state = (train_state, env_state, obsv, done, hstate, rng)
193 | return runner_state, transition
194 |
195 | initial_hstate = runner_state[-2]
196 | runner_state, traj_batch = jax.lax.scan(
197 | _env_step, runner_state, None, config["NUM_STEPS"]
198 | )
199 |
200 | # CALCULATE ADVANTAGE
201 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state
202 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
203 | _, _, last_val = network.apply(train_state.params, hstate, ac_in)
204 | last_val = last_val.squeeze(0)
205 | def _calculate_gae(traj_batch, last_val, last_done):
206 | def _get_advantages(carry, transition):
207 | gae, next_value, next_done = carry
208 | done, value, reward = transition.done, transition.value, transition.reward
209 | delta = reward + config["GAMMA"] * next_value * (1 - next_done) - value
210 | gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
211 | return (gae, value, done), gae
212 | _, advantages = jax.lax.scan(_get_advantages, (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16)
213 | return advantages, advantages + traj_batch.value
214 | advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
215 |
216 | # UPDATE NETWORK
217 | def _update_epoch(update_state, unused):
218 | def _update_minbatch(train_state, batch_info):
219 | init_hstate, traj_batch, advantages, targets = batch_info
220 |
221 | def _loss_fn(params, init_hstate, traj_batch, gae, targets):
222 | # RERUN NETWORK
223 | _, pi, value = network.apply(
224 | params, init_hstate, (traj_batch.obs, traj_batch.done)
225 | )
226 | log_prob = pi.log_prob(traj_batch.action)
227 |
228 | # CALCULATE VALUE LOSS
229 | value_pred_clipped = traj_batch.value + (
230 | value - traj_batch.value
231 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
232 | value_losses = jnp.square(value - targets)
233 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
234 | value_loss = (
235 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
236 | )
237 |
238 | # CALCULATE ACTOR LOSS
239 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
240 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
241 | loss_actor1 = ratio * gae
242 | loss_actor2 = (
243 | jnp.clip(
244 | ratio,
245 | 1.0 - config["CLIP_EPS"],
246 | 1.0 + config["CLIP_EPS"],
247 | )
248 | * gae
249 | )
250 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
251 | loss_actor = loss_actor.mean()
252 | entropy = pi.entropy().mean()
253 |
254 | total_loss = (
255 | loss_actor
256 | + config["VF_COEF"] * value_loss
257 | - config["ENT_COEF"] * entropy
258 | )
259 | return total_loss, (value_loss, loss_actor, entropy)
260 |
261 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
262 | total_loss, grads = grad_fn(
263 | train_state.params, init_hstate, traj_batch, advantages, targets
264 | )
265 | train_state = train_state.apply_gradients(grads=grads)
266 | return train_state, total_loss
267 |
268 | (
269 | train_state,
270 | init_hstate,
271 | traj_batch,
272 | advantages,
273 | targets,
274 | rng,
275 | ) = update_state
276 |
277 | rng, _rng = jax.random.split(rng)
278 | permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
279 | batch = (init_hstate, traj_batch, advantages, targets)
280 |
281 | shuffled_batch = jax.tree_util.tree_map(
282 | lambda x: jnp.take(x, permutation, axis=1), batch
283 | )
284 |
285 | minibatches = jax.tree_util.tree_map(
286 | lambda x: jnp.swapaxes(
287 | jnp.reshape(
288 | x,
289 | [x.shape[0], config["NUM_MINIBATCHES"], -1]
290 | + list(x.shape[2:]),
291 | ),
292 | 1,
293 | 0,
294 | ),
295 | shuffled_batch,
296 | )
297 |
298 | train_state, total_loss = jax.lax.scan(
299 | _update_minbatch, train_state, minibatches
300 | )
301 | update_state = (
302 | train_state,
303 | init_hstate,
304 | traj_batch,
305 | advantages,
306 | targets,
307 | rng,
308 | )
309 | return update_state, total_loss
310 |
311 |             init_hstate = initial_hstate  # hidden state captured before the rollout, reused to replay full sequences
312 | update_state = (
313 | train_state,
314 | init_hstate,
315 | traj_batch,
316 | advantages,
317 | targets,
318 | rng,
319 | )
320 | update_state, loss_info = jax.lax.scan(
321 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
322 | )
323 | train_state = update_state[0]
324 | metric = traj_batch.info
325 | rng = update_state[-1]
326 | if config.get("DEBUG"):
327 | def callback(info):
328 | return_values = info["returned_episode_returns"][info["returned_episode"]]
329 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
330 | for t in range(len(timesteps)):
331 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
332 | jax.debug.callback(callback, metric)
333 |
334 | runner_state = (train_state, env_state, last_obs, last_done, hstate, rng)
335 | return runner_state, metric
336 |
337 | rng, _rng = jax.random.split(rng)
338 | runner_state = (
339 | train_state,
340 | env_state,
341 | obsv,
342 | jnp.zeros((config["NUM_ENVS"]), dtype=bool),
343 | init_hstate,
344 | _rng,
345 | )
346 | runner_state, metric = jax.lax.scan(
347 | _update_step, runner_state, None, config["NUM_UPDATES"]
348 | )
349 | return {"runner_state": runner_state, "metric": metric}
350 |
351 | return train
352 |
353 |
354 | if __name__ == "__main__":
355 | config = {
356 | "LR": 2.5e-4,
357 | "NUM_ENVS": 4,
358 | "NUM_STEPS": 128,
359 | "TOTAL_TIMESTEPS": 5e5,
360 | "UPDATE_EPOCHS": 4,
361 | "NUM_MINIBATCHES": 4,
362 | "GAMMA": 0.99,
363 | "GAE_LAMBDA": 0.95,
364 | "CLIP_EPS": 0.2,
365 | "ENT_COEF": 0.01,
366 | "VF_COEF": 0.5,
367 | "MAX_GRAD_NORM": 0.5,
368 | "ENV_NAME": "CartPole-v1",
369 | "ANNEAL_LR": True,
370 | "DEBUG": True,
371 | }
372 |
373 | rng = jax.random.PRNGKey(30)
374 | train_jit = jax.jit(make_train(config))
375 | out = train_jit(rng)
376 |
--------------------------------------------------------------------------------
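The recurrent variant above shuffles along the environment axis (axis=1) rather than flattening time and batch as the feedforward PPO scripts do, so every minibatch carries unbroken trajectories that can be replayed through the S5 state from the stored initial hidden state. Below is a minimal shape-checking sketch of that bookkeeping; it is not part of the repo, and the toy sizes are assumptions for illustration rather than the training config.

# Sketch (not repo code): sequence-preserving minibatching as in _update_epoch above.
import jax
import jax.numpy as jnp

num_steps, num_envs, num_minibatches, feat = 5, 8, 4, 3
rng = jax.random.PRNGKey(0)

# Fake time-major rollout: axis 0 is time, axis 1 is the environment index.
traj = jnp.arange(num_steps * num_envs * feat, dtype=jnp.float32).reshape(num_steps, num_envs, feat)

# Permute environments only (axis=1); the time order inside each env is untouched.
permutation = jax.random.permutation(rng, num_envs)
shuffled = jnp.take(traj, permutation, axis=1)

# Split the env axis into minibatches and move the minibatch axis to the front,
# mirroring the reshape/swapaxes pair in _update_epoch.
minibatches = jnp.swapaxes(shuffled.reshape(num_steps, num_minibatches, -1, feat), 1, 0)
assert minibatches.shape == (num_minibatches, num_steps, num_envs // num_minibatches, feat)

# Each minibatch still contains complete trajectories for its subset of envs.
assert jnp.allclose(minibatches[0, :, 0], traj[:, permutation[0]])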
/purejaxrl/experimental/s5/s5.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/lindermanlab/S5"""
2 |
3 | from functools import partial
4 | import jax
5 | import jax.numpy as np
6 | import jax.numpy as jnp
7 | from flax import linen as nn
8 | from jax.nn.initializers import lecun_normal, normal
9 | from jax import random
10 | from jax.numpy.linalg import eigh
11 |
12 | class SequenceLayer(nn.Module):
13 | """ Defines a single S5 layer, with S5 SSM, nonlinearity,
14 | dropout, batch/layer norm, etc.
15 | Args:
16 | ssm (nn.Module): the SSM to be used (i.e. S5 ssm)
17 | dropout (float32): dropout rate
18 | d_model (int32): this is the feature size of the layer inputs and outputs
19 | we usually refer to this size as H
20 | activation (string): Type of activation function to use
21 | training (bool): whether in training mode or not
22 | prenorm (bool): apply prenorm if true or postnorm if false
23 | batchnorm (bool): apply batchnorm if true or layernorm if false
24 | bn_momentum (float32): the batchnorm momentum if batchnorm is used
25 | step_rescale (float32): allows for uniformly changing the timescale parameter,
26 | e.g. after training on a different resolution for
27 | the speech commands benchmark
28 | """
29 | ssm: nn.Module
30 | # dropout: float
31 | d_model: int
32 | activation: str = "gelu"
33 | # training: bool = True
34 | # prenorm: bool = False
35 | # batchnorm: bool = False
36 | # bn_momentum: float = 0.90
37 | step_rescale: float = 1.0
38 |
39 | def setup(self):
40 | """Initializes the ssm, batch/layer norm and dropout
41 | """
42 | self.seq = self.ssm(step_rescale=self.step_rescale)
43 |
44 | if self.activation in ["full_glu"]:
45 | self.out1 = nn.Dense(self.d_model)
46 | self.out2 = nn.Dense(self.d_model)
47 | elif self.activation in ["half_glu1", "half_glu2"]:
48 | self.out2 = nn.Dense(self.d_model)
49 |
50 | # if self.batchnorm:
51 | # self.norm = nn.BatchNorm(use_running_average=not self.training,
52 | # momentum=self.bn_momentum, axis_name='batch')
53 | # else:
54 | # self.norm = nn.LayerNorm()
55 |
56 | # self.drop = nn.Dropout(
57 | # self.dropout,
58 | # broadcast_dims=[0],
59 | # deterministic=not self.training,
60 | # )
61 | self.drop = lambda x: x
62 |
63 | def __call__(self, hidden, x, d):
64 | """
65 | Compute the LxH output of S5 layer given an LxH input.
66 | Args:
67 | x (float32): input sequence (L, d_model)
68 | d (bool): reset signal (L,)
69 | Returns:
70 | output sequence (float32): (L, d_model)
71 | """
72 | skip = x
73 | # if self.prenorm:
74 | # x = self.norm(x)
75 | # hidden, x = self.seq(hidden, x, d)
76 | hidden, x = jax.vmap(self.seq, in_axes=1, out_axes=1)(hidden, x, d)
77 | # hidden = jnp.swapaxes(hidden, 1, 0)
78 |
79 | if self.activation in ["full_glu"]:
80 | x = self.drop(nn.gelu(x))
81 | x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
82 | x = self.drop(x)
83 | elif self.activation in ["half_glu1"]:
84 | x = self.drop(nn.gelu(x))
85 | x = x * jax.nn.sigmoid(self.out2(x))
86 | x = self.drop(x)
87 | elif self.activation in ["half_glu2"]:
88 | # Only apply GELU to the gate input
89 | x1 = self.drop(nn.gelu(x))
90 | x = x * jax.nn.sigmoid(self.out2(x1))
91 | x = self.drop(x)
92 | elif self.activation in ["gelu"]:
93 | x = self.drop(nn.gelu(x))
94 | else:
95 | raise NotImplementedError(
96 | "Activation: {} not implemented".format(self.activation))
97 |
98 | x = skip + x
99 | # if not self.prenorm:
100 | # x = self.norm(x)
101 | return hidden, x
102 |
103 | @staticmethod
104 | def initialize_carry(batch_size, hidden_size):
105 | # Use a dummy key since the default state init fn is just zeros.
106 | # return nn.LSTMCell.initialize_carry(
107 | # jax.random.PRNGKey(0), (batch_size,), hidden_size)
108 | return jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64)
109 |
110 | def log_step_initializer(dt_min=0.001, dt_max=0.1):
111 | """ Initialize the learnable timescale Delta by sampling
112 | uniformly between dt_min and dt_max.
113 | Args:
114 | dt_min (float32): minimum value
115 | dt_max (float32): maximum value
116 | Returns:
117 | init function
118 | """
119 | def init(key, shape):
120 | """ Init function
121 | Args:
122 | key: jax random key
123 | shape tuple: desired shape
124 | Returns:
125 | sampled log_step (float32)
126 | """
127 | return random.uniform(key, shape) * (
128 | np.log(dt_max) - np.log(dt_min)
129 | ) + np.log(dt_min)
130 |
131 | return init
132 |
133 |
134 | def init_log_steps(key, input):
135 | """ Initialize an array of learnable timescale parameters
136 | Args:
137 | key: jax random key
138 | input: tuple containing the array shape H and
139 | dt_min and dt_max
140 | Returns:
141 | initialized array of timescales (float32): (H,)
142 | """
143 | H, dt_min, dt_max = input
144 | log_steps = []
145 | for i in range(H):
146 | key, skey = random.split(key)
147 | log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
148 | log_steps.append(log_step)
149 |
150 | return np.array(log_steps)
151 |
152 |
153 | def init_VinvB(init_fun, rng, shape, Vinv):
154 | """ Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
155 | Note we will parameterize this with two different matrices for complex
156 | numbers.
157 | Args:
158 | init_fun: the initialization function to use, e.g. lecun_normal()
159 | rng: jax random key to be used with init function.
160 | shape (tuple): desired shape (P,H)
161 | Vinv: (complex64) the inverse eigenvectors used for initialization
162 | Returns:
163 | B_tilde (complex64) of shape (P,H,2)
164 | """
165 | B = init_fun(rng, shape)
166 | VinvB = Vinv @ B
167 | VinvB_real = VinvB.real
168 | VinvB_imag = VinvB.imag
169 | return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)
170 |
171 |
172 | def trunc_standard_normal(key, shape):
173 | """ Sample C with a truncated normal distribution with standard deviation 1.
174 | Args:
175 | key: jax random key
176 | shape (tuple): desired shape, of length 3, (H,P,_)
177 | Returns:
178 | sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
179 | """
180 | H, P, _ = shape
181 | Cs = []
182 | for i in range(H):
183 | key, skey = random.split(key)
184 | C = lecun_normal()(skey, shape=(1, P, 2))
185 | Cs.append(C)
186 | return np.array(Cs)[:, 0]
187 |
188 |
189 | def init_CV(init_fun, rng, shape, V):
190 | """ Initialize C_tilde=CV. First sample C. Then compute CV.
191 | Note we will parameterize this with two different matrices for complex
192 | numbers.
193 | Args:
194 | init_fun: the initialization function to use, e.g. lecun_normal()
195 | rng: jax random key to be used with init function.
196 | shape (tuple): desired shape (H,P)
197 | V: (complex64) the eigenvectors used for initialization
198 | Returns:
199 | C_tilde (complex64) of shape (H,P,2)
200 | """
201 | C_ = init_fun(rng, shape)
202 | C = C_[..., 0] + 1j * C_[..., 1]
203 | CV = C @ V
204 | CV_real = CV.real
205 | CV_imag = CV.imag
206 | return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)
207 |
208 |
209 | # Discretization functions
210 | def discretize_bilinear(Lambda, B_tilde, Delta):
211 | """ Discretize a diagonalized, continuous-time linear SSM
212 | using bilinear transform method.
213 | Args:
214 | Lambda (complex64): diagonal state matrix (P,)
215 | B_tilde (complex64): input matrix (P, H)
216 | Delta (float32): discretization step sizes (P,)
217 | Returns:
218 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)
219 | """
220 | Identity = np.ones(Lambda.shape[0])
221 |
222 | BL = 1 / (Identity - (Delta / 2.0) * Lambda)
223 | Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)
224 | B_bar = (BL * Delta)[..., None] * B_tilde
225 | return Lambda_bar, B_bar
226 |
227 |
228 | def discretize_zoh(Lambda, B_tilde, Delta):
229 | """ Discretize a diagonalized, continuous-time linear SSM
230 | using zero-order hold method.
231 | Args:
232 | Lambda (complex64): diagonal state matrix (P,)
233 | B_tilde (complex64): input matrix (P, H)
234 | Delta (float32): discretization step sizes (P,)
235 | Returns:
236 | discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)
237 | """
238 | Identity = np.ones(Lambda.shape[0])
239 | Lambda_bar = np.exp(Lambda * Delta)
240 | B_bar = (1/Lambda * (Lambda_bar-Identity))[..., None] * B_tilde
241 | return Lambda_bar, B_bar
242 |
243 |
244 | # Parallel scan operations
245 | @jax.vmap
246 | def binary_operator(q_i, q_j):
247 | """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
248 | Args:
249 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,)
250 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,)
251 | Returns:
252 | new element ( A_out, Bu_out )
253 | """
254 | A_i, b_i = q_i
255 | A_j, b_j = q_j
256 | return A_j * A_i, A_j * b_i + b_j
257 |
258 | # Parallel scan operations
259 | @jax.vmap
260 | def binary_operator_reset(q_i, q_j):
261 | """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
262 | Args:
263 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,)
264 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,)
265 | Returns:
266 | new element ( A_out, Bu_out )
267 | """
268 | A_i, b_i, c_i = q_i
269 | A_j, b_j, c_j = q_j
270 | return (
271 | (A_j * A_i)*(1 - c_j) + A_j * c_j,
272 | (A_j * b_i + b_j)*(1 - c_j) + b_j * c_j,
273 | c_i * (1 - c_j) + c_j,
274 | )
275 |
276 |
277 |
278 | def apply_ssm(Lambda_bar, B_bar, C_tilde, hidden, input_sequence, resets, conj_sym, bidirectional):
279 | """ Compute the LxH output of discretized SSM given an LxH input.
280 | Args:
281 | Lambda_bar (complex64): discretized diagonal state matrix (P,)
282 | B_bar (complex64): discretized input matrix (P, H)
283 | C_tilde (complex64): output matrix (H, P)
284 | input_sequence (float32): input sequence of features (L, H)
285 |             resets (bool): reset flags marking episode boundaries (L,)
286 | conj_sym (bool): whether conjugate symmetry is enforced
287 | bidirectional (bool): whether bidirectional setup is used,
288 | Note for this case C_tilde will have 2P cols
289 | Returns:
290 |             hidden (complex64), ys (float32): final SSM state (1, P) and the SSM outputs (S5 layer preactivations) (L, H)
291 | """
292 | Lambda_elements = Lambda_bar * jnp.ones((input_sequence.shape[0],
293 | Lambda_bar.shape[0]))
294 | Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)
295 |
296 | Lambda_elements = jnp.concatenate([
297 | jnp.ones((1, Lambda_bar.shape[0])),
298 | Lambda_elements,
299 | ])
300 |
301 | Bu_elements = jnp.concatenate([
302 | hidden,
303 | Bu_elements,
304 | ])
305 |
306 | resets = jnp.concatenate([
307 | jnp.zeros(1),
308 | resets,
309 | ])
310 |
311 |
312 | _, xs, _ = jax.lax.associative_scan(binary_operator_reset, (Lambda_elements, Bu_elements, resets))
313 | xs = xs[1:]
314 |
315 | if conj_sym:
316 | return xs[np.newaxis, -1], jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs)
317 | else:
318 | return xs[np.newaxis, -1], jax.vmap(lambda x: (C_tilde @ x).real)(xs)
319 |
320 |
321 | class S5SSM(nn.Module):
322 | Lambda_re_init: np.DeviceArray
323 | Lambda_im_init: np.DeviceArray
324 | V: np.DeviceArray
325 | Vinv: np.DeviceArray
326 |
327 | H: int
328 | P: int
329 | C_init: str
330 | discretization: str
331 | dt_min: float
332 | dt_max: float
333 | conj_sym: bool = True
334 | clip_eigs: bool = False
335 | bidirectional: bool = False
336 | step_rescale: float = 1.0
337 |
338 | """ The S5 SSM
339 | Args:
340 | Lambda_re_init (complex64): Real part of init diag state matrix (P,)
341 | Lambda_im_init (complex64): Imag part of init diag state matrix (P,)
342 | V (complex64): Eigenvectors used for init (P,P)
343 | Vinv (complex64): Inverse eigenvectors used for init (P,P)
344 | H (int32): Number of features of input seq
345 | P (int32): state size
346 | C_init (string): Specifies How C is initialized
347 | Options: [trunc_standard_normal: sample from truncated standard normal
348 | and then multiply by V, i.e. C_tilde=CV.
349 | lecun_normal: sample from Lecun_normal and then multiply by V.
350 | complex_normal: directly sample a complex valued output matrix
351 | from standard normal, does not multiply by V]
352 | conj_sym (bool): Whether conjugate symmetry is enforced
353 | clip_eigs (bool): Whether to enforce left-half plane condition, i.e.
354 | constrain real part of eigenvalues to be negative.
355 | True recommended for autoregressive task/unbounded sequence lengths
356 | Discussed in https://arxiv.org/pdf/2206.11893.pdf.
357 | bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices
358 | discretization: (string) Specifies discretization method
359 | options: [zoh: zero-order hold method,
360 | bilinear: bilinear transform]
361 | dt_min: (float32): minimum value to draw timescale values from when
362 | initializing log_step
363 | dt_max: (float32): maximum value to draw timescale values from when
364 | initializing log_step
365 | step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. after training
366 | on a different resolution for the speech commands benchmark
367 | """
368 |
369 | def setup(self):
370 | """Initializes parameters once and performs discretization each time
371 | the SSM is applied to a sequence
372 | """
373 |
374 | if self.conj_sym:
375 | # Need to account for case where we actually sample real B and C, and then multiply
376 | # by the half sized Vinv and possibly V
377 | local_P = 2*self.P
378 | else:
379 | local_P = self.P
380 |
381 | # Initialize diagonal state to state matrix Lambda (eigenvalues)
382 | self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,))
383 | self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,))
384 | if self.clip_eigs:
385 | self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
386 | else:
387 | self.Lambda = self.Lambda_re + 1j * self.Lambda_im
388 |
389 | # Initialize input to state (B) matrix
390 | B_init = lecun_normal()
391 | B_shape = (local_P, self.H)
392 | self.B = self.param("B",
393 | lambda rng, shape: init_VinvB(B_init,
394 | rng,
395 | shape,
396 | self.Vinv),
397 | B_shape)
398 | B_tilde = self.B[..., 0] + 1j * self.B[..., 1]
399 |
400 | # Initialize state to output (C) matrix
401 | if self.C_init in ["trunc_standard_normal"]:
402 | C_init = trunc_standard_normal
403 | C_shape = (self.H, local_P, 2)
404 | elif self.C_init in ["lecun_normal"]:
405 | C_init = lecun_normal()
406 | C_shape = (self.H, local_P, 2)
407 | elif self.C_init in ["complex_normal"]:
408 | C_init = normal(stddev=0.5 ** 0.5)
409 | else:
410 | raise NotImplementedError(
411 | "C_init method {} not implemented".format(self.C_init))
412 |
413 | if self.C_init in ["complex_normal"]:
414 | if self.bidirectional:
415 | C = self.param("C", C_init, (self.H, 2 * self.P, 2))
416 | self.C_tilde = C[..., 0] + 1j * C[..., 1]
417 |
418 | else:
419 | C = self.param("C", C_init, (self.H, self.P, 2))
420 | self.C_tilde = C[..., 0] + 1j * C[..., 1]
421 |
422 | else:
423 | if self.bidirectional:
424 | self.C1 = self.param("C1",
425 | lambda rng, shape: init_CV(C_init, rng, shape, self.V),
426 | C_shape)
427 | self.C2 = self.param("C2",
428 | lambda rng, shape: init_CV(C_init, rng, shape, self.V),
429 | C_shape)
430 |
431 | C1 = self.C1[..., 0] + 1j * self.C1[..., 1]
432 | C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
433 | self.C_tilde = np.concatenate((C1, C2), axis=-1)
434 |
435 | else:
436 | self.C = self.param("C",
437 | lambda rng, shape: init_CV(C_init, rng, shape, self.V),
438 | C_shape)
439 |
440 | self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]
441 |
442 | # Initialize feedthrough (D) matrix
443 | self.D = self.param("D", normal(stddev=1.0), (self.H,))
444 |
445 | # Initialize learnable discretization timescale value
446 | self.log_step = self.param("log_step",
447 | init_log_steps,
448 | (self.P, self.dt_min, self.dt_max))
449 | step = self.step_rescale * np.exp(self.log_step[:, 0])
450 |
451 | # Discretize
452 | if self.discretization in ["zoh"]:
453 | self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step)
454 | elif self.discretization in ["bilinear"]:
455 | self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step)
456 | else:
457 | raise NotImplementedError("Discretization method {} not implemented".format(self.discretization))
458 |
459 | def __call__(self, hidden, input_sequence, resets):
460 | """
461 | Compute the LxH output of the S5 SSM given an LxH input sequence
462 | using a parallel scan.
463 | Args:
464 | input_sequence (float32): input sequence (L, H)
465 |              resets (bool): reset flags marking episode boundaries (L,)
466 | Returns:
467 | output sequence (float32): (L, H)
468 | """
469 | hidden, ys = apply_ssm(self.Lambda_bar,
470 | self.B_bar,
471 | self.C_tilde,
472 | hidden,
473 | input_sequence,
474 | resets,
475 | self.conj_sym,
476 | self.bidirectional)
477 | # Add feedthrough matrix output Du;
478 | Du = jax.vmap(lambda u: self.D * u)(input_sequence)
479 | return hidden, ys + Du
480 |
481 |
482 | def init_S5SSM(H,
483 | P,
484 | Lambda_re_init,
485 | Lambda_im_init,
486 | V,
487 | Vinv,
488 | C_init,
489 | discretization,
490 | dt_min,
491 | dt_max,
492 | conj_sym,
493 | clip_eigs,
494 | bidirectional
495 | ):
496 | """Convenience function that will be used to initialize the SSM.
497 | Same arguments as defined in S5SSM above."""
498 | return partial(S5SSM,
499 | H=H,
500 | P=P,
501 | Lambda_re_init=Lambda_re_init,
502 | Lambda_im_init=Lambda_im_init,
503 | V=V,
504 | Vinv=Vinv,
505 | C_init=C_init,
506 | discretization=discretization,
507 | dt_min=dt_min,
508 | dt_max=dt_max,
509 | conj_sym=conj_sym,
510 | clip_eigs=clip_eigs,
511 | bidirectional=bidirectional)
512 |
513 |
514 | def make_HiPPO(N):
515 | """ Create a HiPPO-LegS matrix.
516 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
517 | Args:
518 | N (int32): state size
519 | Returns:
520 | N x N HiPPO LegS matrix
521 | """
522 | P = np.sqrt(1 + 2 * np.arange(N))
523 | A = P[:, np.newaxis] * P[np.newaxis, :]
524 | A = np.tril(A) - np.diag(np.arange(N))
525 | return -A
526 |
527 |
528 | def make_NPLR_HiPPO(N):
529 | """
530 | Makes components needed for NPLR representation of HiPPO-LegS
531 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
532 | Args:
533 | N (int32): state size
534 | Returns:
535 | N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
536 | """
537 | # Make -HiPPO
538 | hippo = make_HiPPO(N)
539 |
540 | # Add in a rank 1 term. Makes it Normal.
541 | P = np.sqrt(np.arange(N) + 0.5)
542 |
543 | # HiPPO also specifies the B matrix
544 | B = np.sqrt(2 * np.arange(N) + 1.0)
545 | return hippo, P, B
546 |
547 |
548 | def make_DPLR_HiPPO(N):
549 | """
550 | Makes components needed for DPLR representation of HiPPO-LegS
551 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
552 | Note, we will only use the diagonal part
553 | Args:
554 | N:
555 | Returns:
556 | eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
557 | eigenvectors V, HiPPO B pre-conjugation
558 | """
559 | A, P, B = make_NPLR_HiPPO(N)
560 |
561 | S = A + P[:, np.newaxis] * P[np.newaxis, :]
562 |
563 | S_diag = np.diagonal(S)
564 | Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
565 |
566 | # Diagonalize S to V \Lambda V^*
567 | Lambda_imag, V = eigh(S * -1j)
568 |
569 | P = V.conj().T @ P
570 | B_orig = B
571 | B = V.conj().T @ B
572 | return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig
573 |
574 | class StackedEncoderModel(nn.Module):
575 | """ Defines a stack of S5 layers to be used as an encoder.
576 | Args:
577 | ssm (nn.Module): the SSM to be used (i.e. S5 ssm)
578 | d_model (int32): this is the feature size of the layer inputs and outputs
579 | we usually refer to this size as H
580 | n_layers (int32): the number of S5 layers to stack
581 | activation (string): Type of activation function to use
582 | dropout (float32): dropout rate
583 | training (bool): whether in training mode or not
584 | prenorm (bool): apply prenorm if true or postnorm if false
585 | batchnorm (bool): apply batchnorm if true or layernorm if false
586 | bn_momentum (float32): the batchnorm momentum if batchnorm is used
587 | step_rescale (float32): allows for uniformly changing the timescale parameter,
588 | e.g. after training on a different resolution for
589 | the speech commands benchmark
590 | """
591 | ssm: nn.Module
592 | d_model: int
593 | n_layers: int
594 | activation: str = "gelu"
595 |
596 | def setup(self):
597 | """
598 | Initializes a linear encoder and the stack of S5 layers.
599 | """
600 | self.layers = [
601 | SequenceLayer(
602 | ssm=self.ssm,
603 | d_model=self.d_model,
604 | activation=self.activation,
605 | )
606 | for _ in range(self.n_layers)
607 | ]
608 |
609 | def __call__(self, hidden, x, d):
610 | """
611 | Compute the LxH output of the stacked encoder given an Lxd_input
612 | input sequence.
613 | Args:
614 | x (float32): input sequence (L, d_input)
615 | Returns:
616 | output sequence (float32): (L, d_model)
617 | """
618 | new_hiddens = []
619 | for i, layer in enumerate(self.layers):
620 | new_h, x = layer(hidden[i], x, d)
621 | new_hiddens.append(new_h)
622 |
623 | return new_hiddens, x
624 |
625 | @staticmethod
626 | def initialize_carry(batch_size, hidden_size, n_layers):
627 | # Use a dummy key since the default state init fn is just zeros.
628 | return [jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64) for _ in range(n_layers)]
--------------------------------------------------------------------------------
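The reset-aware binary operator above is what lets a single associative scan run across episode boundaries: whenever the reset flag is set, the carried state is dropped and the recurrence restarts from the new input. The sketch below is a small self-contained check of that behaviour; the operator is copied from s5.py, and the sizes and inputs are toy assumptions rather than anything used in training.

# Sketch (not repo code): reset-aware parallel scan vs. an explicit sequential recurrence.
import jax
import jax.numpy as jnp

@jax.vmap
def binary_operator_reset(q_i, q_j):
    A_i, b_i, c_i = q_i
    A_j, b_j, c_j = q_j
    return (
        (A_j * A_i) * (1 - c_j) + A_j * c_j,
        (A_j * b_i + b_j) * (1 - c_j) + b_j * c_j,
        c_i * (1 - c_j) + c_j,
    )

L, P = 6, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
Lam = 0.9 * jnp.exp(1j * jax.random.uniform(k1, (P,)))      # stable diagonal state matrix
Bu = jax.random.normal(k2, (L, P)) + 0j                      # per-step inputs B u_t
resets = jnp.array([0.0, 0.0, 1.0, 0.0, 0.0, 1.0])           # episode boundaries

# Parallel form: prepend the carried hidden state exactly as apply_ssm does.
hidden = jnp.zeros((1, P), dtype=jnp.complex64)
A_el = jnp.concatenate([jnp.ones((1, P)), jnp.tile(Lam, (L, 1))])
b_el = jnp.concatenate([hidden, Bu])
c_el = jnp.concatenate([jnp.zeros(1), resets])
_, xs, _ = jax.lax.associative_scan(binary_operator_reset, (A_el, b_el, c_el))
xs = xs[1:]

# Sequential reference: x_t = Lam * x_{t-1} + B u_t, with the state zeroed on reset.
x, ref = hidden[0], []
for t in range(L):
    x = jnp.where(resets[t] > 0, 0.0, 1.0) * Lam * x + Bu[t]
    ref.append(x)
assert jnp.allclose(xs, jnp.stack(ref), atol=1e-5)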
/purejaxrl/experimental/s5/wrappers.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import chex
4 | import numpy as np
5 | from flax import struct
6 | from functools import partial
7 | from typing import Optional, Tuple, Union, Any
8 | from gymnax.environments import environment, spaces
9 | from brax import envs
10 |
11 | class GymnaxWrapper(object):
12 | """Base class for Gymnax wrappers."""
13 |
14 | def __init__(self, env):
15 | self._env = env
16 |
17 | # provide proxy access to regular attributes of wrapped object
18 | def __getattr__(self, name):
19 | return getattr(self._env, name)
20 |
21 | class FlattenObservationWrapper(GymnaxWrapper):
22 | """Flatten the observations of the environment."""
23 |
24 | def __init__(self, env: environment.Environment):
25 | super().__init__(env)
26 |
27 | def observation_space(self, params) -> spaces.Box:
28 | assert isinstance(self._env.observation_space(params), spaces.Box), "Only Box spaces are supported for now."
29 | return spaces.Box(
30 | low=self._env.observation_space(params).low,
31 | high=self._env.observation_space(params).high,
32 | shape=(np.prod(self._env.observation_space(params).shape),),
33 | dtype=self._env.observation_space(params).dtype,
34 | )
35 |
36 | @partial(jax.jit, static_argnums=(0,))
37 | def reset(
38 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
39 | ) -> Tuple[chex.Array, environment.EnvState]:
40 | obs, state = self._env.reset(key, params)
41 | obs = jnp.reshape(obs, (-1,))
42 | return obs, state
43 |
44 | @partial(jax.jit, static_argnums=(0,))
45 | def step(
46 | self,
47 | key: chex.PRNGKey,
48 | state: environment.EnvState,
49 | action: Union[int, float],
50 | params: Optional[environment.EnvParams] = None,
51 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
52 | obs, state, reward, done, info = self._env.step(key, state, action, params)
53 | obs = jnp.reshape(obs, (-1,))
54 | return obs, state, reward, done, info
55 |
56 | @struct.dataclass
57 | class LogEnvState:
58 | env_state: environment.EnvState
59 | episode_returns: float
60 | episode_lengths: int
61 | returned_episode_returns: float
62 | returned_episode_lengths: int
63 | timestep: int
64 |
65 | class LogWrapper(GymnaxWrapper):
66 | """Log the episode returns and lengths."""
67 |
68 | def __init__(self, env: environment.Environment):
69 | super().__init__(env)
70 |
71 | @partial(jax.jit, static_argnums=(0,))
72 | def reset(
73 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
74 | ) -> Tuple[chex.Array, environment.EnvState]:
75 | obs, env_state = self._env.reset(key, params)
76 | state = LogEnvState(env_state, 0, 0, 0, 0, 0)
77 | return obs, state
78 |
79 | @partial(jax.jit, static_argnums=(0,))
80 | def step(
81 | self,
82 | key: chex.PRNGKey,
83 | state: environment.EnvState,
84 | action: Union[int, float],
85 | params: Optional[environment.EnvParams] = None,
86 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
87 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
88 | new_episode_return = state.episode_returns + reward
89 | new_episode_length = state.episode_lengths + 1
90 | state = LogEnvState(
91 | env_state = env_state,
92 | episode_returns = new_episode_return * (1 - done),
93 | episode_lengths = new_episode_length * (1 - done),
94 | returned_episode_returns = state.returned_episode_returns * (1 - done) + new_episode_return * done,
95 | returned_episode_lengths = state.returned_episode_lengths * (1 - done) + new_episode_length * done,
96 | timestep = state.timestep + 1,
97 | )
98 | info["returned_episode_returns"] = state.returned_episode_returns
99 | info["returned_episode_lengths"] = state.returned_episode_lengths
100 | info["timestep"] = state.timestep
101 | info["returned_episode"] = done
102 | return obs, state, reward, done, info
103 |
104 | class BraxGymnaxWrapper:
105 | def __init__(self, env_name, backend="positional"):
106 | env = envs.get_environment(env_name=env_name, backend=backend)
107 | env = envs.wrapper.EpisodeWrapper(env, episode_length=1000, action_repeat=1)
108 | env = envs.wrapper.AutoResetWrapper(env)
109 | self._env = env
110 | self.action_size = env.action_size
111 | self.observation_size = (env.observation_size,)
112 |
113 | def reset(self, key, params=None):
114 | state = self._env.reset(key)
115 | return state.obs, state
116 |
117 | def step(self, key, state, action, params=None):
118 | next_state = self._env.step(state, action)
119 | return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}
120 |
121 | def observation_space(self, params):
122 | return spaces.Box(
123 | low=-jnp.inf,
124 | high=jnp.inf,
125 | shape=(self._env.observation_size,),
126 | )
127 |
128 | def action_space(self, params):
129 | return spaces.Box(
130 | low=-1.0,
131 | high=1.0,
132 | shape=(self._env.action_size,),
133 | )
134 |
135 | class ClipAction(GymnaxWrapper):
136 | def __init__(self, env, low=-1.0, high=1.0):
137 | super().__init__(env)
138 | self.low = low
139 | self.high = high
140 |
141 | def step(self, key, state, action, params=None):
142 | """TODO: In theory the below line should be the way to do this."""
143 | # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high)
144 | action = jnp.clip(action, self.low, self.high)
145 | return self._env.step(key, state, action, params)
146 |
147 | class TransformObservation(GymnaxWrapper):
148 | def __init__(self, env, transform_obs):
149 | super().__init__(env)
150 | self.transform_obs = transform_obs
151 |
152 | def reset(self, key, params=None):
153 | obs, state = self._env.reset(key, params)
154 | return self.transform_obs(obs), state
155 |
156 | def step(self, key, state, action, params=None):
157 | obs, state, reward, done, info = self._env.step(key, state, action, params)
158 | return self.transform_obs(obs), state, reward, done, info
159 |
160 | class TransformReward(GymnaxWrapper):
161 | def __init__(self, env, transform_reward):
162 | super().__init__(env)
163 | self.transform_reward = transform_reward
164 |
165 | def step(self, key, state, action, params=None):
166 | obs, state, reward, done, info = self._env.step(key, state, action, params)
167 | return obs, state, self.transform_reward(reward), done, info
168 |
169 |
170 | class VecEnv(GymnaxWrapper):
171 | def __init__(self, env):
172 | super().__init__(env)
173 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None))
174 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
175 |
176 | @struct.dataclass
177 | class NormalizeVecObsEnvState:
178 | mean: jnp.ndarray
179 | var: jnp.ndarray
180 | count: float
181 | env_state: environment.EnvState
182 |
183 | class NormalizeVecObservation(GymnaxWrapper):
184 | def __init__(self, env):
185 | super().__init__(env)
186 |
187 | def reset(self, key, params=None):
188 | obs, state = self._env.reset(key, params)
189 | state = NormalizeVecObsEnvState(
190 | mean=jnp.zeros_like(obs),
191 | var=jnp.ones_like(obs),
192 | count=1e-4,
193 | env_state=state,
194 | )
195 | batch_mean = jnp.mean(obs, axis=0)
196 | batch_var = jnp.var(obs, axis=0)
197 | batch_count = obs.shape[0]
198 |
199 | delta = batch_mean - state.mean
200 | tot_count = state.count + batch_count
201 |
202 | new_mean = state.mean + delta * batch_count / tot_count
203 | m_a = state.var * state.count
204 | m_b = batch_var * batch_count
205 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
206 | new_var = M2 / tot_count
207 | new_count = tot_count
208 |
209 | state = NormalizeVecObsEnvState(
210 | mean=new_mean,
211 | var=new_var,
212 | count=new_count,
213 | env_state=state.env_state,
214 | )
215 |
216 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state
217 |
218 | def step(self, key, state, action, params=None):
219 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
220 |
221 | batch_mean = jnp.mean(obs, axis=0)
222 | batch_var = jnp.var(obs, axis=0)
223 | batch_count = obs.shape[0]
224 |
225 | delta = batch_mean - state.mean
226 | tot_count = state.count + batch_count
227 |
228 | new_mean = state.mean + delta * batch_count / tot_count
229 | m_a = state.var * state.count
230 | m_b = batch_var * batch_count
231 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
232 | new_var = M2 / tot_count
233 | new_count = tot_count
234 |
235 | state = NormalizeVecObsEnvState(
236 | mean=new_mean,
237 | var=new_var,
238 | count=new_count,
239 | env_state=env_state,
240 | )
241 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state, reward, done, info
242 |
243 |
244 | @struct.dataclass
245 | class NormalizeVecRewEnvState:
246 | mean: jnp.ndarray
247 | var: jnp.ndarray
248 | count: float
249 | return_val: float
250 | env_state: environment.EnvState
251 |
252 | class NormalizeVecReward(GymnaxWrapper):
253 |
254 | def __init__(self, env, gamma):
255 | super().__init__(env)
256 | self.gamma = gamma
257 |
258 | def reset(self, key, params=None):
259 | obs, state = self._env.reset(key, params)
260 | batch_count = obs.shape[0]
261 | state = NormalizeVecRewEnvState(
262 | mean=0.0,
263 | var=1.0,
264 | count=1e-4,
265 | return_val=jnp.zeros((batch_count,)),
266 | env_state=state,
267 | )
268 | return obs, state
269 |
270 | def step(self, key, state, action, params=None):
271 | obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
272 | return_val = (state.return_val * self.gamma * (1 - done) + reward)
273 |
274 | batch_mean = jnp.mean(return_val, axis=0)
275 | batch_var = jnp.var(return_val, axis=0)
276 | batch_count = obs.shape[0]
277 |
278 | delta = batch_mean - state.mean
279 | tot_count = state.count + batch_count
280 |
281 | new_mean = state.mean + delta * batch_count / tot_count
282 | m_a = state.var * state.count
283 | m_b = batch_var * batch_count
284 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
285 | new_var = M2 / tot_count
286 | new_count = tot_count
287 |
288 | state = NormalizeVecRewEnvState(
289 | mean=new_mean,
290 | var=new_var,
291 | count=new_count,
292 | return_val=return_val,
293 | env_state=env_state,
294 | )
295 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info
296 |
--------------------------------------------------------------------------------
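NormalizeVecObservation and NormalizeVecReward maintain running statistics with a parallel (Chan-style) mean/variance update: folding a new batch into the running state reproduces the statistics of all data seen so far. The sketch below checks that property in isolation; it is not repo code, it starts from the exact statistics of the first batch rather than the wrappers' near-zero pseudo-count of 1e-4, and the data are toy assumptions.

# Sketch (not repo code): the running mean/variance update used by the normalization wrappers.
import jax
import jax.numpy as jnp

def update(mean, var, count, batch):
    batch_mean, batch_var, batch_count = jnp.mean(batch, axis=0), jnp.var(batch, axis=0), batch.shape[0]
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    M2 = var * count + batch_var * batch_count + jnp.square(delta) * count * batch_count / tot_count
    return new_mean, M2 / tot_count, tot_count

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(k1, (64, 3)) * 2.0 + 1.0
b = jax.random.normal(k2, (32, 3)) * 0.5 - 3.0

# Start from the statistics of the first batch, then fold in the second.
mean, var, count = jnp.mean(a, axis=0), jnp.var(a, axis=0), a.shape[0]
mean, var, count = update(mean, var, count, b)

both = jnp.concatenate([a, b], axis=0)
assert jnp.allclose(mean, jnp.mean(both, axis=0), atol=1e-4)
assert jnp.allclose(var, jnp.var(both, axis=0), atol=1e-4)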
/purejaxrl/ppo.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 | import numpy as np
5 | import optax
6 | from flax.linen.initializers import constant, orthogonal
7 | from typing import Sequence, NamedTuple, Any
8 | from flax.training.train_state import TrainState
9 | import distrax
10 | import gymnax
11 | from wrappers import LogWrapper, FlattenObservationWrapper
12 |
13 |
14 | class ActorCritic(nn.Module):
15 | action_dim: Sequence[int]
16 | activation: str = "tanh"
17 |
18 | @nn.compact
19 | def __call__(self, x):
20 | if self.activation == "relu":
21 | activation = nn.relu
22 | else:
23 | activation = nn.tanh
24 | actor_mean = nn.Dense(
25 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
26 | )(x)
27 | actor_mean = activation(actor_mean)
28 | actor_mean = nn.Dense(
29 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
30 | )(actor_mean)
31 | actor_mean = activation(actor_mean)
32 | actor_mean = nn.Dense(
33 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
34 | )(actor_mean)
35 | pi = distrax.Categorical(logits=actor_mean)
36 |
37 | critic = nn.Dense(
38 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
39 | )(x)
40 | critic = activation(critic)
41 | critic = nn.Dense(
42 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
43 | )(critic)
44 | critic = activation(critic)
45 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
46 | critic
47 | )
48 |
49 | return pi, jnp.squeeze(critic, axis=-1)
50 |
51 |
52 | class Transition(NamedTuple):
53 | done: jnp.ndarray
54 | action: jnp.ndarray
55 | value: jnp.ndarray
56 | reward: jnp.ndarray
57 | log_prob: jnp.ndarray
58 | obs: jnp.ndarray
59 | info: jnp.ndarray
60 |
61 |
62 | def make_train(config):
63 | config["NUM_UPDATES"] = (
64 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
65 | )
66 | config["MINIBATCH_SIZE"] = (
67 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
68 | )
69 | env, env_params = gymnax.make(config["ENV_NAME"])
70 | env = FlattenObservationWrapper(env)
71 | env = LogWrapper(env)
72 |
73 | def linear_schedule(count):
74 | frac = (
75 | 1.0
76 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
77 | / config["NUM_UPDATES"]
78 | )
79 | return config["LR"] * frac
80 |
81 | def train(rng):
82 | # INIT NETWORK
83 | network = ActorCritic(
84 | env.action_space(env_params).n, activation=config["ACTIVATION"]
85 | )
86 | rng, _rng = jax.random.split(rng)
87 | init_x = jnp.zeros(env.observation_space(env_params).shape)
88 | network_params = network.init(_rng, init_x)
89 | if config["ANNEAL_LR"]:
90 | tx = optax.chain(
91 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
92 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
93 | )
94 | else:
95 | tx = optax.chain(
96 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
97 | optax.adam(config["LR"], eps=1e-5),
98 | )
99 | train_state = TrainState.create(
100 | apply_fn=network.apply,
101 | params=network_params,
102 | tx=tx,
103 | )
104 |
105 | # INIT ENV
106 | rng, _rng = jax.random.split(rng)
107 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
108 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
109 |
110 | # TRAIN LOOP
111 | def _update_step(runner_state, unused):
112 | # COLLECT TRAJECTORIES
113 | def _env_step(runner_state, unused):
114 | train_state, env_state, last_obs, rng = runner_state
115 |
116 | # SELECT ACTION
117 | rng, _rng = jax.random.split(rng)
118 | pi, value = network.apply(train_state.params, last_obs)
119 | action = pi.sample(seed=_rng)
120 | log_prob = pi.log_prob(action)
121 |
122 | # STEP ENV
123 | rng, _rng = jax.random.split(rng)
124 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
125 | obsv, env_state, reward, done, info = jax.vmap(
126 | env.step, in_axes=(0, 0, 0, None)
127 | )(rng_step, env_state, action, env_params)
128 | transition = Transition(
129 | done, action, value, reward, log_prob, last_obs, info
130 | )
131 | runner_state = (train_state, env_state, obsv, rng)
132 | return runner_state, transition
133 |
134 | runner_state, traj_batch = jax.lax.scan(
135 | _env_step, runner_state, None, config["NUM_STEPS"]
136 | )
137 |
138 | # CALCULATE ADVANTAGE
139 | train_state, env_state, last_obs, rng = runner_state
140 | _, last_val = network.apply(train_state.params, last_obs)
141 |
142 | def _calculate_gae(traj_batch, last_val):
143 | def _get_advantages(gae_and_next_value, transition):
144 | gae, next_value = gae_and_next_value
145 | done, value, reward = (
146 | transition.done,
147 | transition.value,
148 | transition.reward,
149 | )
150 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value
151 | gae = (
152 | delta
153 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
154 | )
155 | return (gae, value), gae
156 |
157 | _, advantages = jax.lax.scan(
158 | _get_advantages,
159 | (jnp.zeros_like(last_val), last_val),
160 | traj_batch,
161 | reverse=True,
162 | unroll=16,
163 | )
164 | return advantages, advantages + traj_batch.value
165 |
166 | advantages, targets = _calculate_gae(traj_batch, last_val)
167 |
168 | # UPDATE NETWORK
169 | def _update_epoch(update_state, unused):
170 | def _update_minbatch(train_state, batch_info):
171 | traj_batch, advantages, targets = batch_info
172 |
173 | def _loss_fn(params, traj_batch, gae, targets):
174 | # RERUN NETWORK
175 | pi, value = network.apply(params, traj_batch.obs)
176 | log_prob = pi.log_prob(traj_batch.action)
177 |
178 | # CALCULATE VALUE LOSS
179 | value_pred_clipped = traj_batch.value + (
180 | value - traj_batch.value
181 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
182 | value_losses = jnp.square(value - targets)
183 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
184 | value_loss = (
185 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
186 | )
187 |
188 | # CALCULATE ACTOR LOSS
189 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
190 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
191 | loss_actor1 = ratio * gae
192 | loss_actor2 = (
193 | jnp.clip(
194 | ratio,
195 | 1.0 - config["CLIP_EPS"],
196 | 1.0 + config["CLIP_EPS"],
197 | )
198 | * gae
199 | )
200 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
201 | loss_actor = loss_actor.mean()
202 | entropy = pi.entropy().mean()
203 |
204 | total_loss = (
205 | loss_actor
206 | + config["VF_COEF"] * value_loss
207 | - config["ENT_COEF"] * entropy
208 | )
209 | return total_loss, (value_loss, loss_actor, entropy)
210 |
211 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
212 | total_loss, grads = grad_fn(
213 | train_state.params, traj_batch, advantages, targets
214 | )
215 | train_state = train_state.apply_gradients(grads=grads)
216 | return train_state, total_loss
217 |
218 | train_state, traj_batch, advantages, targets, rng = update_state
219 | rng, _rng = jax.random.split(rng)
220 | # Batching and Shuffling
221 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
222 | assert (
223 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
224 | ), "batch size must be equal to number of steps * number of envs"
225 | permutation = jax.random.permutation(_rng, batch_size)
226 | batch = (traj_batch, advantages, targets)
227 | batch = jax.tree_util.tree_map(
228 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
229 | )
230 | shuffled_batch = jax.tree_util.tree_map(
231 | lambda x: jnp.take(x, permutation, axis=0), batch
232 | )
233 | # Mini-batch Updates
234 | minibatches = jax.tree_util.tree_map(
235 | lambda x: jnp.reshape(
236 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
237 | ),
238 | shuffled_batch,
239 | )
240 | train_state, total_loss = jax.lax.scan(
241 | _update_minbatch, train_state, minibatches
242 | )
243 | update_state = (train_state, traj_batch, advantages, targets, rng)
244 | return update_state, total_loss
245 | # Updating Training State and Metrics:
246 | update_state = (train_state, traj_batch, advantages, targets, rng)
247 | update_state, loss_info = jax.lax.scan(
248 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
249 | )
250 | train_state = update_state[0]
251 | metric = traj_batch.info
252 | rng = update_state[-1]
253 |
254 | # Debugging mode
255 | if config.get("DEBUG"):
256 | def callback(info):
257 | return_values = info["returned_episode_returns"][info["returned_episode"]]
258 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
259 | for t in range(len(timesteps)):
260 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
261 | jax.debug.callback(callback, metric)
262 |
263 | runner_state = (train_state, env_state, last_obs, rng)
264 | return runner_state, metric
265 |
266 | rng, _rng = jax.random.split(rng)
267 | runner_state = (train_state, env_state, obsv, _rng)
268 | runner_state, metric = jax.lax.scan(
269 | _update_step, runner_state, None, config["NUM_UPDATES"]
270 | )
271 | return {"runner_state": runner_state, "metrics": metric}
272 |
273 | return train
274 |
275 |
276 | if __name__ == "__main__":
277 | config = {
278 | "LR": 2.5e-4,
279 | "NUM_ENVS": 4,
280 | "NUM_STEPS": 128,
281 | "TOTAL_TIMESTEPS": 5e5,
282 | "UPDATE_EPOCHS": 4,
283 | "NUM_MINIBATCHES": 4,
284 | "GAMMA": 0.99,
285 | "GAE_LAMBDA": 0.95,
286 | "CLIP_EPS": 0.2,
287 | "ENT_COEF": 0.01,
288 | "VF_COEF": 0.5,
289 | "MAX_GRAD_NORM": 0.5,
290 | "ACTIVATION": "tanh",
291 | "ENV_NAME": "CartPole-v1",
292 | "ANNEAL_LR": True,
293 | "DEBUG": True,
294 | }
295 | rng = jax.random.PRNGKey(30)
296 | train_jit = jax.jit(make_train(config))
297 | out = train_jit(rng)
298 |
--------------------------------------------------------------------------------
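The GAE computation in ppo.py is written as a reverse jax.lax.scan over the trajectory, carrying (gae, next_value) backwards in time. The sketch below is a minimal reference check, not repo code and with toy shapes assumed, that the scanned recursion matches the explicit backward loop delta_t = r_t + gamma * V_{t+1} * (1 - d_t) - V_t and A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}.

# Sketch (not repo code): reverse-scan GAE vs. an explicit backward loop.
import jax
import jax.numpy as jnp

gamma, lam = 0.99, 0.95
T, N = 7, 3  # toy: steps x envs
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
rewards = jax.random.normal(k1, (T, N))
values = jax.random.normal(k2, (T, N))
dones = (jax.random.uniform(k3, (T, N)) < 0.2).astype(jnp.float32)
last_val = jax.random.normal(k4, (N,))

# Reverse-scan form, mirroring _calculate_gae above.
def step(carry, xs):
    gae, next_value = carry
    done, value, reward = xs
    delta = reward + gamma * next_value * (1 - done) - value
    gae = delta + gamma * lam * (1 - done) * gae
    return (gae, value), gae

_, adv_scan = jax.lax.scan(step, (jnp.zeros(N), last_val), (dones, values, rewards), reverse=True)

# Explicit backward loop over the same data.
adv_loop = [None] * T
gae, next_value = jnp.zeros(N), last_val
for t in reversed(range(T)):
    delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
    gae = delta + gamma * lam * (1 - dones[t]) * gae
    adv_loop[t] = gae
    next_value = values[t]
assert jnp.allclose(adv_scan, jnp.stack(adv_loop), atol=1e-5)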
/purejaxrl/ppo_continuous_action.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 | import numpy as np
5 | import optax
6 | from flax.linen.initializers import constant, orthogonal
7 | from typing import Sequence, NamedTuple, Any
8 | from flax.training.train_state import TrainState
9 | import distrax
10 | from wrappers import (
11 | LogWrapper,
12 | BraxGymnaxWrapper,
13 | VecEnv,
14 | NormalizeVecObservation,
15 | NormalizeVecReward,
16 | ClipAction,
17 | )
18 |
19 |
20 | class ActorCritic(nn.Module):
21 | action_dim: Sequence[int]
22 | activation: str = "tanh"
23 |
24 | @nn.compact
25 | def __call__(self, x):
26 | if self.activation == "relu":
27 | activation = nn.relu
28 | else:
29 | activation = nn.tanh
30 | actor_mean = nn.Dense(
31 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
32 | )(x)
33 | actor_mean = activation(actor_mean)
34 | actor_mean = nn.Dense(
35 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
36 | )(actor_mean)
37 | actor_mean = activation(actor_mean)
38 | actor_mean = nn.Dense(
39 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
40 | )(actor_mean)
41 | actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
42 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
43 |
44 | critic = nn.Dense(
45 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
46 | )(x)
47 | critic = activation(critic)
48 | critic = nn.Dense(
49 | 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
50 | )(critic)
51 | critic = activation(critic)
52 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
53 | critic
54 | )
55 |
56 | return pi, jnp.squeeze(critic, axis=-1)
57 |
58 |
59 | class Transition(NamedTuple):
60 | done: jnp.ndarray
61 | action: jnp.ndarray
62 | value: jnp.ndarray
63 | reward: jnp.ndarray
64 | log_prob: jnp.ndarray
65 | obs: jnp.ndarray
66 | info: jnp.ndarray
67 |
68 |
69 | def make_train(config):
70 | config["NUM_UPDATES"] = (
71 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
72 | )
73 | config["MINIBATCH_SIZE"] = (
74 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
75 | )
76 | env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
77 | env = LogWrapper(env)
78 | env = ClipAction(env)
79 | env = VecEnv(env)
80 | if config["NORMALIZE_ENV"]:
81 | env = NormalizeVecObservation(env)
82 | env = NormalizeVecReward(env, config["GAMMA"])
83 |
84 | def linear_schedule(count):
85 | frac = (
86 | 1.0
87 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
88 | / config["NUM_UPDATES"]
89 | )
90 | return config["LR"] * frac
91 |
92 | def train(rng):
93 | # INIT NETWORK
94 | network = ActorCritic(
95 | env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
96 | )
97 | rng, _rng = jax.random.split(rng)
98 | init_x = jnp.zeros(env.observation_space(env_params).shape)
99 | network_params = network.init(_rng, init_x)
100 | if config["ANNEAL_LR"]:
101 | tx = optax.chain(
102 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
103 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
104 | )
105 | else:
106 | tx = optax.chain(
107 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
108 | optax.adam(config["LR"], eps=1e-5),
109 | )
110 | train_state = TrainState.create(
111 | apply_fn=network.apply,
112 | params=network_params,
113 | tx=tx,
114 | )
115 |
116 | # INIT ENV
117 | rng, _rng = jax.random.split(rng)
118 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
119 | obsv, env_state = env.reset(reset_rng, env_params)
120 |
121 | # TRAIN LOOP
122 | def _update_step(runner_state, unused):
123 | # COLLECT TRAJECTORIES
124 | def _env_step(runner_state, unused):
125 | train_state, env_state, last_obs, rng = runner_state
126 |
127 | # SELECT ACTION
128 | rng, _rng = jax.random.split(rng)
129 | pi, value = network.apply(train_state.params, last_obs)
130 | action = pi.sample(seed=_rng)
131 | log_prob = pi.log_prob(action)
132 |
133 | # STEP ENV
134 | rng, _rng = jax.random.split(rng)
135 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
136 | obsv, env_state, reward, done, info = env.step(
137 | rng_step, env_state, action, env_params
138 | )
139 | transition = Transition(
140 | done, action, value, reward, log_prob, last_obs, info
141 | )
142 | runner_state = (train_state, env_state, obsv, rng)
143 | return runner_state, transition
144 |
145 | runner_state, traj_batch = jax.lax.scan(
146 | _env_step, runner_state, None, config["NUM_STEPS"]
147 | )
148 |
149 | # CALCULATE ADVANTAGE
150 | train_state, env_state, last_obs, rng = runner_state
151 | _, last_val = network.apply(train_state.params, last_obs)
152 |
153 | def _calculate_gae(traj_batch, last_val):
154 | def _get_advantages(gae_and_next_value, transition):
155 | gae, next_value = gae_and_next_value
156 | done, value, reward = (
157 | transition.done,
158 | transition.value,
159 | transition.reward,
160 | )
161 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value
162 | gae = (
163 | delta
164 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
165 | )
166 | return (gae, value), gae
167 |
168 | _, advantages = jax.lax.scan(
169 | _get_advantages,
170 | (jnp.zeros_like(last_val), last_val),
171 | traj_batch,
172 | reverse=True,
173 | unroll=16,
174 | )
175 | return advantages, advantages + traj_batch.value
176 |
177 | advantages, targets = _calculate_gae(traj_batch, last_val)
178 |
179 | # UPDATE NETWORK
180 | def _update_epoch(update_state, unused):
181 | def _update_minbatch(train_state, batch_info):
182 | traj_batch, advantages, targets = batch_info
183 |
184 | def _loss_fn(params, traj_batch, gae, targets):
185 | # RERUN NETWORK
186 | pi, value = network.apply(params, traj_batch.obs)
187 | log_prob = pi.log_prob(traj_batch.action)
188 |
189 | # CALCULATE VALUE LOSS
190 | value_pred_clipped = traj_batch.value + (
191 | value - traj_batch.value
192 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
193 | value_losses = jnp.square(value - targets)
194 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
195 | value_loss = (
196 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
197 | )
198 |
199 | # CALCULATE ACTOR LOSS
200 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
201 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
202 | loss_actor1 = ratio * gae
203 | loss_actor2 = (
204 | jnp.clip(
205 | ratio,
206 | 1.0 - config["CLIP_EPS"],
207 | 1.0 + config["CLIP_EPS"],
208 | )
209 | * gae
210 | )
211 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
212 | loss_actor = loss_actor.mean()
213 | entropy = pi.entropy().mean()
214 |
215 | total_loss = (
216 | loss_actor
217 | + config["VF_COEF"] * value_loss
218 | - config["ENT_COEF"] * entropy
219 | )
220 | return total_loss, (value_loss, loss_actor, entropy)
221 |
222 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
223 | total_loss, grads = grad_fn(
224 | train_state.params, traj_batch, advantages, targets
225 | )
226 | train_state = train_state.apply_gradients(grads=grads)
227 | return train_state, total_loss
228 |
229 | train_state, traj_batch, advantages, targets, rng = update_state
230 | rng, _rng = jax.random.split(rng)
231 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
232 | assert (
233 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
234 | ), "batch size must be equal to number of steps * number of envs"
235 | permutation = jax.random.permutation(_rng, batch_size)
236 | batch = (traj_batch, advantages, targets)
237 | batch = jax.tree_util.tree_map(
238 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
239 | )
240 | shuffled_batch = jax.tree_util.tree_map(
241 | lambda x: jnp.take(x, permutation, axis=0), batch
242 | )
243 | minibatches = jax.tree_util.tree_map(
244 | lambda x: jnp.reshape(
245 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
246 | ),
247 | shuffled_batch,
248 | )
249 | train_state, total_loss = jax.lax.scan(
250 | _update_minbatch, train_state, minibatches
251 | )
252 | update_state = (train_state, traj_batch, advantages, targets, rng)
253 | return update_state, total_loss
254 |
255 | update_state = (train_state, traj_batch, advantages, targets, rng)
256 | update_state, loss_info = jax.lax.scan(
257 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
258 | )
259 | train_state = update_state[0]
260 | metric = traj_batch.info
261 | rng = update_state[-1]
262 | if config.get("DEBUG"):
263 |
264 | def callback(info):
265 | return_values = info["returned_episode_returns"][
266 | info["returned_episode"]
267 | ]
268 | timesteps = (
269 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
270 | )
271 | for t in range(len(timesteps)):
272 | print(
273 | f"global step={timesteps[t]}, episodic return={return_values[t]}"
274 | )
275 |
276 | jax.debug.callback(callback, metric)
277 |
278 | runner_state = (train_state, env_state, last_obs, rng)
279 | return runner_state, metric
280 |
281 | rng, _rng = jax.random.split(rng)
282 | runner_state = (train_state, env_state, obsv, _rng)
283 | runner_state, metric = jax.lax.scan(
284 | _update_step, runner_state, None, config["NUM_UPDATES"]
285 | )
286 | return {"runner_state": runner_state, "metrics": metric}
287 |
288 | return train
289 |
290 |
291 | if __name__ == "__main__":
292 | config = {
293 | "LR": 3e-4,
294 | "NUM_ENVS": 2048,
295 | "NUM_STEPS": 10,
296 | "TOTAL_TIMESTEPS": 5e7,
297 | "UPDATE_EPOCHS": 4,
298 | "NUM_MINIBATCHES": 32,
299 | "GAMMA": 0.99,
300 | "GAE_LAMBDA": 0.95,
301 | "CLIP_EPS": 0.2,
302 | "ENT_COEF": 0.0,
303 | "VF_COEF": 0.5,
304 | "MAX_GRAD_NORM": 0.5,
305 | "ACTIVATION": "tanh",
306 | "ENV_NAME": "hopper",
307 | "ANNEAL_LR": False,
308 | "NORMALIZE_ENV": True,
309 | "DEBUG": True,
310 | }
311 | rng = jax.random.PRNGKey(30)
312 | train_jit = jax.jit(make_train(config))
313 | out = train_jit(rng)
314 |
--------------------------------------------------------------------------------
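Because make_train returns a pure function of the PRNG key, several seeds can be trained in parallel by vmapping it. The snippet below is a usage sketch rather than repo code; it assumes the config dict above with "DEBUG" set to False (so the print callback is not replicated across the vmapped runs) and enough device memory for the batched rollouts.

# Sketch (not repo code): vmapping the returned train function over seeds.
import jax

num_seeds = 8
rngs = jax.random.split(jax.random.PRNGKey(42), num_seeds)

# With the full config this is a long compile + run; shrink TOTAL_TIMESTEPS to smoke-test.
train_vjit = jax.jit(jax.vmap(make_train(config)))
outs = train_vjit(rngs)

# Episodic returns logged by LogWrapper, with one leading axis per seed.
returns = outs["metrics"]["returned_episode_returns"]
print(returns.shape)  # (num_seeds, NUM_UPDATES, NUM_STEPS, NUM_ENVS)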
/purejaxrl/ppo_minigrid.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 | import numpy as np
5 | import optax
6 | from flax.linen.initializers import constant, orthogonal
7 | from typing import Sequence, NamedTuple, Any
8 | from flax.training.train_state import TrainState
9 | import distrax
10 | import gymnax
11 | from wrappers import LogWrapper, FlattenObservationWrapper, NavixGymnaxWrapper
12 |
13 |
14 | class ActorCritic(nn.Module):
15 | action_dim: Sequence[int]
16 | activation: str = "tanh"
17 |
18 | @nn.compact
19 | def __call__(self, x):
20 | if self.activation == "relu":
21 | activation = nn.relu
22 | else:
23 | activation = nn.tanh
24 | actor_mean = nn.Dense(
25 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
26 | )(x)
27 | actor_mean = activation(actor_mean)
28 | actor_mean = nn.Dense(
29 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
30 | )(actor_mean)
31 | actor_mean = activation(actor_mean)
32 | actor_mean = nn.Dense(
33 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
34 | )(actor_mean)
35 | pi = distrax.Categorical(logits=actor_mean)
36 |
37 | critic = nn.Dense(
38 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
39 | )(x)
40 | critic = activation(critic)
41 | critic = nn.Dense(
42 | 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
43 | )(critic)
44 | critic = activation(critic)
45 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
46 | critic
47 | )
48 |
49 | return pi, jnp.squeeze(critic, axis=-1)
50 |
51 |
52 | class Transition(NamedTuple):
53 | done: jnp.ndarray
54 | action: jnp.ndarray
55 | value: jnp.ndarray
56 | reward: jnp.ndarray
57 | log_prob: jnp.ndarray
58 | obs: jnp.ndarray
59 | info: jnp.ndarray
60 |
61 |
62 | def make_train(config):
63 | config["NUM_UPDATES"] = (
64 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
65 | )
66 | config["MINIBATCH_SIZE"] = (
67 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
68 | )
69 | env, env_params = NavixGymnaxWrapper(config["ENV_NAME"]), None
70 | env = FlattenObservationWrapper(env)
71 | env = LogWrapper(env)
72 |
73 | def linear_schedule(count):
74 | frac = (
75 | 1.0
76 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
77 | / config["NUM_UPDATES"]
78 | )
79 | return config["LR"] * frac
80 |
81 | def train(rng):
82 | # INIT NETWORK
83 | network = ActorCritic(
84 | env.action_space(env_params).n, activation=config["ACTIVATION"]
85 | )
86 | rng, _rng = jax.random.split(rng)
87 | init_x = jnp.zeros(env.observation_space(env_params).shape)
88 |
89 | network_params = network.init(_rng, init_x)
90 | if config["ANNEAL_LR"]:
91 | tx = optax.chain(
92 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
93 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
94 | )
95 | else:
96 | tx = optax.chain(
97 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
98 | optax.adam(config["LR"], eps=1e-5),
99 | )
100 | train_state = TrainState.create(
101 | apply_fn=network.apply,
102 | params=network_params,
103 | tx=tx,
104 | )
105 |
106 | # INIT ENV
107 | rng, _rng = jax.random.split(rng)
108 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
109 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
110 |
111 | # TRAIN LOOP
112 | def _update_step(runner_state, unused):
113 | # COLLECT TRAJECTORIES
114 | def _env_step(runner_state, unused):
115 | train_state, env_state, last_obs, rng = runner_state
116 |
117 | # SELECT ACTION
118 | rng, _rng = jax.random.split(rng)
119 | pi, value = network.apply(train_state.params, last_obs)
120 | action = pi.sample(seed=_rng)
121 | log_prob = pi.log_prob(action)
122 |
123 | # STEP ENV
124 | rng, _rng = jax.random.split(rng)
125 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
126 | obsv, env_state, reward, done, info = jax.vmap(
127 | env.step, in_axes=(0, 0, 0, None)
128 | )(rng_step, env_state, action, env_params)
129 | transition = Transition(
130 | done, action, value, reward, log_prob, last_obs, info
131 | )
132 | runner_state = (train_state, env_state, obsv, rng)
133 | return runner_state, transition
134 |
135 | runner_state, traj_batch = jax.lax.scan(
136 | _env_step, runner_state, None, config["NUM_STEPS"]
137 | )
138 |
139 | # CALCULATE ADVANTAGE
140 | train_state, env_state, last_obs, rng = runner_state
141 | _, last_val = network.apply(train_state.params, last_obs)
142 |
143 | def _calculate_gae(traj_batch, last_val):
144 | def _get_advantages(gae_and_next_value, transition):
145 | gae, next_value = gae_and_next_value
146 | done, value, reward = (
147 | transition.done,
148 | transition.value,
149 | transition.reward,
150 | )
151 | delta = reward + config["GAMMA"] * next_value * (1 - done) - value
152 | gae = (
153 | delta
154 | + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
155 | )
156 | return (gae, value), gae
157 |
158 | _, advantages = jax.lax.scan(
159 | _get_advantages,
160 | (jnp.zeros_like(last_val), last_val),
161 | traj_batch,
162 | reverse=True,
163 | unroll=16,
164 | )
165 | return advantages, advantages + traj_batch.value
166 |
167 | advantages, targets = _calculate_gae(traj_batch, last_val)
168 |
169 | # UPDATE NETWORK
170 | def _update_epoch(update_state, unused):
171 | def _update_minbatch(train_state, batch_info):
172 | traj_batch, advantages, targets = batch_info
173 |
174 | def _loss_fn(params, traj_batch, gae, targets):
175 | # RERUN NETWORK
176 | pi, value = network.apply(params, traj_batch.obs)
177 | log_prob = pi.log_prob(traj_batch.action)
178 |
179 | # CALCULATE VALUE LOSS
180 | value_pred_clipped = traj_batch.value + (
181 | value - traj_batch.value
182 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
183 | value_losses = jnp.square(value - targets)
184 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
185 | value_loss = (
186 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
187 | )
188 |
189 | # CALCULATE ACTOR LOSS
190 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
191 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
192 | loss_actor1 = ratio * gae
193 | loss_actor2 = (
194 | jnp.clip(
195 | ratio,
196 | 1.0 - config["CLIP_EPS"],
197 | 1.0 + config["CLIP_EPS"],
198 | )
199 | * gae
200 | )
201 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
202 | loss_actor = loss_actor.mean()
203 | entropy = pi.entropy().mean()
204 |
205 | total_loss = (
206 | loss_actor
207 | + config["VF_COEF"] * value_loss
208 | - config["ENT_COEF"] * entropy
209 | )
210 | return total_loss, (value_loss, loss_actor, entropy)
211 |
212 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
213 | total_loss, grads = grad_fn(
214 | train_state.params, traj_batch, advantages, targets
215 | )
216 | train_state = train_state.apply_gradients(grads=grads)
217 | return train_state, total_loss
218 |
219 | train_state, traj_batch, advantages, targets, rng = update_state
220 | rng, _rng = jax.random.split(rng)
221 | # Batching and Shuffling
222 | batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
223 | assert (
224 | batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
225 | ), "batch size must be equal to number of steps * number of envs"
226 | permutation = jax.random.permutation(_rng, batch_size)
227 | batch = (traj_batch, advantages, targets)
228 | batch = jax.tree_util.tree_map(
229 | lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
230 | )
231 | shuffled_batch = jax.tree_util.tree_map(
232 | lambda x: jnp.take(x, permutation, axis=0), batch
233 | )
234 | # Mini-batch Updates
235 | minibatches = jax.tree_util.tree_map(
236 | lambda x: jnp.reshape(
237 | x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
238 | ),
239 | shuffled_batch,
240 | )
241 | train_state, total_loss = jax.lax.scan(
242 | _update_minbatch, train_state, minibatches
243 | )
244 | update_state = (train_state, traj_batch, advantages, targets, rng)
245 | return update_state, total_loss
246 | # Updating Training State and Metrics:
247 | update_state = (train_state, traj_batch, advantages, targets, rng)
248 | update_state, loss_info = jax.lax.scan(
249 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
250 | )
251 | train_state = update_state[0]
252 | metric = traj_batch.info
253 | rng = update_state[-1]
254 |
255 | # Debugging mode
256 | if config.get("DEBUG"):
257 | def callback(info):
258 | return_values = info["returned_episode_returns"][info["returned_episode"]]
259 | timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
260 | for t in range(len(timesteps)):
261 | print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
262 | jax.debug.callback(callback, metric)
263 |
264 | runner_state = (train_state, env_state, last_obs, rng)
265 | return runner_state, metric
266 |
267 | rng, _rng = jax.random.split(rng)
268 | runner_state = (train_state, env_state, obsv, _rng)
269 | runner_state, metric = jax.lax.scan(
270 | _update_step, runner_state, None, config["NUM_UPDATES"]
271 | )
272 | return {"runner_state": runner_state, "metrics": metric}
273 |
274 | return train
275 |
276 |
277 | if __name__ == "__main__":
278 | config = {
279 | "LR": 2.5e-4,
280 | "NUM_ENVS": 16,
281 | "NUM_STEPS": 128,
282 | "TOTAL_TIMESTEPS": 1e6,
283 | "UPDATE_EPOCHS": 1,
284 | "NUM_MINIBATCHES": 8,
285 | "GAMMA": 0.99,
286 | "GAE_LAMBDA": 0.95,
287 | "CLIP_EPS": 0.2,
288 | "ENT_COEF": 0.01,
289 | "VF_COEF": 0.5,
290 | "MAX_GRAD_NORM": 0.5,
291 | "ACTIVATION": "tanh",
292 | "ENV_NAME": "Navix-DoorKey-5x5-v0",
293 | "ANNEAL_LR": True,
294 | "DEBUG": True,
295 | }
296 | rng = jax.random.PRNGKey(30)
297 | train_jit = jax.jit(make_train(config))
298 | out = train_jit(rng)
299 |
--------------------------------------------------------------------------------
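Usage note: optax evaluates the learning-rate schedule in the file above once per gradient update, and each _update_step performs NUM_MINIBATCHES * UPDATE_EPOCHS of those, so the rate decays linearly from LR towards zero over NUM_UPDATES. A small sketch with the config values above (recomputed here, purely illustrative):

    TOTAL_TIMESTEPS, NUM_STEPS, NUM_ENVS = 1e6, 128, 16
    NUM_MINIBATCHES, UPDATE_EPOCHS, LR = 8, 1, 2.5e-4

    num_updates = TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS  # 488.0
    steps_per_update = NUM_MINIBATCHES * UPDATE_EPOCHS      # 8 optimizer steps per update

    def linear_schedule(count):
        return LR * (1.0 - (count // steps_per_update) / num_updates)

    print(linear_schedule(0))                                        # 2.5e-04
    print(linear_schedule(int(num_updates * steps_per_update) - 1))  # ~5.1e-07
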
/purejaxrl/ppo_rnn.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 | import numpy as np
5 | import optax
6 | import time
7 | from flax.linen.initializers import constant, orthogonal
8 | from typing import Sequence, NamedTuple, Any, Dict
9 | from flax.training.train_state import TrainState
10 | import distrax
11 | import gymnax
12 | import functools
13 | from gymnax.environments import spaces
14 | from wrappers import FlattenObservationWrapper, LogWrapper
15 |
16 |
17 | class ScannedRNN(nn.Module):
18 | @functools.partial(
19 | nn.scan,
20 | variable_broadcast="params",
21 | in_axes=0,
22 | out_axes=0,
23 | split_rngs={"params": False},
24 | )
25 | @nn.compact
26 | def __call__(self, carry, x):
27 | """Applies the module."""
28 | rnn_state = carry
29 | ins, resets = x
30 | rnn_state = jnp.where(
31 | resets[:, np.newaxis],
32 | self.initialize_carry(ins.shape[0], ins.shape[1]),
33 | rnn_state,
34 | )
35 |         new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
36 | return new_rnn_state, y
37 |
38 | @staticmethod
39 | def initialize_carry(batch_size, hidden_size):
40 | # Use a dummy key since the default state init fn is just zeros.
41 |         return nn.GRUCell(features=hidden_size).initialize_carry(
42 |             jax.random.PRNGKey(0), (batch_size, hidden_size)
43 |         )
44 |
45 |
46 | class ActorCriticRNN(nn.Module):
47 | action_dim: Sequence[int]
48 | config: Dict
49 |
50 | @nn.compact
51 | def __call__(self, hidden, x):
52 | obs, dones = x
53 | embedding = nn.Dense(
54 | 128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
55 | )(obs)
56 | embedding = nn.relu(embedding)
57 |
58 | rnn_in = (embedding, dones)
59 | hidden, embedding = ScannedRNN()(hidden, rnn_in)
60 |
61 | actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
62 | embedding
63 | )
64 | actor_mean = nn.relu(actor_mean)
65 | actor_mean = nn.Dense(
66 | self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
67 | )(actor_mean)
68 |
69 | pi = distrax.Categorical(logits=actor_mean)
70 |
71 | critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
72 | embedding
73 | )
74 | critic = nn.relu(critic)
75 | critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
76 | critic
77 | )
78 |
79 | return hidden, pi, jnp.squeeze(critic, axis=-1)
80 |
81 |
82 | class Transition(NamedTuple):
83 | done: jnp.ndarray
84 | action: jnp.ndarray
85 | value: jnp.ndarray
86 | reward: jnp.ndarray
87 | log_prob: jnp.ndarray
88 | obs: jnp.ndarray
89 | info: jnp.ndarray
90 |
91 |
92 | def make_train(config):
93 | config["NUM_UPDATES"] = (
94 | config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
95 | )
96 | config["MINIBATCH_SIZE"] = (
97 | config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
98 | )
99 | env, env_params = gymnax.make(config["ENV_NAME"])
100 | env = FlattenObservationWrapper(env)
101 | env = LogWrapper(env)
102 |
103 | def linear_schedule(count):
104 | frac = (
105 | 1.0
106 | - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
107 | / config["NUM_UPDATES"]
108 | )
109 | return config["LR"] * frac
110 |
111 | def train(rng):
112 | # INIT NETWORK
113 | network = ActorCriticRNN(env.action_space(env_params).n, config=config)
114 | rng, _rng = jax.random.split(rng)
115 | init_x = (
116 | jnp.zeros(
117 | (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
118 | ),
119 | jnp.zeros((1, config["NUM_ENVS"])),
120 | )
121 | init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
122 | network_params = network.init(_rng, init_hstate, init_x)
123 | if config["ANNEAL_LR"]:
124 | tx = optax.chain(
125 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
126 | optax.adam(learning_rate=linear_schedule, eps=1e-5),
127 | )
128 | else:
129 | tx = optax.chain(
130 | optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
131 | optax.adam(config["LR"], eps=1e-5),
132 | )
133 | train_state = TrainState.create(
134 | apply_fn=network.apply,
135 | params=network_params,
136 | tx=tx,
137 | )
138 |
139 | # INIT ENV
140 | rng, _rng = jax.random.split(rng)
141 | reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
142 | obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
143 | init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
144 |
145 | # TRAIN LOOP
146 | def _update_step(runner_state, unused):
147 | # COLLECT TRAJECTORIES
148 | def _env_step(runner_state, unused):
149 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state
150 | rng, _rng = jax.random.split(rng)
151 |
152 | # SELECT ACTION
153 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
154 | hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
155 | action = pi.sample(seed=_rng)
156 | log_prob = pi.log_prob(action)
157 | value, action, log_prob = (
158 | value.squeeze(0),
159 | action.squeeze(0),
160 | log_prob.squeeze(0),
161 | )
162 |
163 | # STEP ENV
164 | rng, _rng = jax.random.split(rng)
165 | rng_step = jax.random.split(_rng, config["NUM_ENVS"])
166 | obsv, env_state, reward, done, info = jax.vmap(
167 | env.step, in_axes=(0, 0, 0, None)
168 | )(rng_step, env_state, action, env_params)
169 | transition = Transition(
170 | last_done, action, value, reward, log_prob, last_obs, info
171 | )
172 | runner_state = (train_state, env_state, obsv, done, hstate, rng)
173 | return runner_state, transition
174 |
175 | initial_hstate = runner_state[-2]
176 | runner_state, traj_batch = jax.lax.scan(
177 | _env_step, runner_state, None, config["NUM_STEPS"]
178 | )
179 |
180 | # CALCULATE ADVANTAGE
181 | train_state, env_state, last_obs, last_done, hstate, rng = runner_state
182 | ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
183 | _, _, last_val = network.apply(train_state.params, hstate, ac_in)
184 | last_val = last_val.squeeze(0)
185 | def _calculate_gae(traj_batch, last_val, last_done):
186 | def _get_advantages(carry, transition):
187 | gae, next_value, next_done = carry
188 | done, value, reward = transition.done, transition.value, transition.reward
189 | delta = reward + config["GAMMA"] * next_value * (1 - next_done) - value
190 | gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
191 | return (gae, value, done), gae
192 | _, advantages = jax.lax.scan(_get_advantages, (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16)
193 | return advantages, advantages + traj_batch.value
194 | advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
195 |
196 | # UPDATE NETWORK
197 | def _update_epoch(update_state, unused):
198 | def _update_minbatch(train_state, batch_info):
199 | init_hstate, traj_batch, advantages, targets = batch_info
200 |
201 | def _loss_fn(params, init_hstate, traj_batch, gae, targets):
202 | # RERUN NETWORK
203 | _, pi, value = network.apply(
204 | params, init_hstate[0], (traj_batch.obs, traj_batch.done)
205 | )
206 | log_prob = pi.log_prob(traj_batch.action)
207 |
208 | # CALCULATE VALUE LOSS
209 | value_pred_clipped = traj_batch.value + (
210 | value - traj_batch.value
211 | ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
212 | value_losses = jnp.square(value - targets)
213 | value_losses_clipped = jnp.square(value_pred_clipped - targets)
214 | value_loss = (
215 | 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
216 | )
217 |
218 | # CALCULATE ACTOR LOSS
219 | ratio = jnp.exp(log_prob - traj_batch.log_prob)
220 | gae = (gae - gae.mean()) / (gae.std() + 1e-8)
221 | loss_actor1 = ratio * gae
222 | loss_actor2 = (
223 | jnp.clip(
224 | ratio,
225 | 1.0 - config["CLIP_EPS"],
226 | 1.0 + config["CLIP_EPS"],
227 | )
228 | * gae
229 | )
230 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
231 | loss_actor = loss_actor.mean()
232 | entropy = pi.entropy().mean()
233 |
234 | total_loss = (
235 | loss_actor
236 | + config["VF_COEF"] * value_loss
237 | - config["ENT_COEF"] * entropy
238 | )
239 | return total_loss, (value_loss, loss_actor, entropy)
240 |
241 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
242 | total_loss, grads = grad_fn(
243 | train_state.params, init_hstate, traj_batch, advantages, targets
244 | )
245 | train_state = train_state.apply_gradients(grads=grads)
246 | return train_state, total_loss
247 |
248 | (
249 | train_state,
250 | init_hstate,
251 | traj_batch,
252 | advantages,
253 | targets,
254 | rng,
255 | ) = update_state
256 |
257 | rng, _rng = jax.random.split(rng)
258 | permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
259 | batch = (init_hstate, traj_batch, advantages, targets)
260 |
261 | shuffled_batch = jax.tree_util.tree_map(
262 | lambda x: jnp.take(x, permutation, axis=1), batch
263 | )
264 |
265 | minibatches = jax.tree_util.tree_map(
266 | lambda x: jnp.swapaxes(
267 | jnp.reshape(
268 | x,
269 | [x.shape[0], config["NUM_MINIBATCHES"], -1]
270 | + list(x.shape[2:]),
271 | ),
272 | 1,
273 | 0,
274 | ),
275 | shuffled_batch,
276 | )
277 |
278 | train_state, total_loss = jax.lax.scan(
279 | _update_minbatch, train_state, minibatches
280 | )
281 | update_state = (
282 | train_state,
283 | init_hstate,
284 | traj_batch,
285 | advantages,
286 | targets,
287 | rng,
288 | )
289 | return update_state, total_loss
290 |
291 |             init_hstate = initial_hstate[None, :]  # add time axis -> (T=1, B, H)
292 | update_state = (
293 | train_state,
294 | init_hstate,
295 | traj_batch,
296 | advantages,
297 | targets,
298 | rng,
299 | )
300 | update_state, loss_info = jax.lax.scan(
301 | _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
302 | )
303 | train_state = update_state[0]
304 | metric = traj_batch.info
305 | rng = update_state[-1]
306 | if config.get("DEBUG"):
307 |
308 | def callback(info):
309 | return_values = info["returned_episode_returns"][
310 | info["returned_episode"]
311 | ]
312 | timesteps = (
313 | info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
314 | )
315 | for t in range(len(timesteps)):
316 | print(
317 | f"global step={timesteps[t]}, episodic return={return_values[t]}"
318 | )
319 |
320 | jax.debug.callback(callback, metric)
321 |
322 | runner_state = (train_state, env_state, last_obs, last_done, hstate, rng)
323 | return runner_state, metric
324 |
325 | rng, _rng = jax.random.split(rng)
326 | runner_state = (
327 | train_state,
328 | env_state,
329 | obsv,
330 | jnp.zeros((config["NUM_ENVS"]), dtype=bool),
331 | init_hstate,
332 | _rng,
333 | )
334 | runner_state, metric = jax.lax.scan(
335 | _update_step, runner_state, None, config["NUM_UPDATES"]
336 | )
337 | return {"runner_state": runner_state, "metric": metric}
338 |
339 | return train
340 |
341 |
342 | if __name__ == "__main__":
343 | config = {
344 | "LR": 2.5e-4,
345 | "NUM_ENVS": 4,
346 | "NUM_STEPS": 128,
347 | "TOTAL_TIMESTEPS": 5e5,
348 | "UPDATE_EPOCHS": 4,
349 | "NUM_MINIBATCHES": 4,
350 | "GAMMA": 0.99,
351 | "GAE_LAMBDA": 0.95,
352 | "CLIP_EPS": 0.2,
353 | "ENT_COEF": 0.01,
354 | "VF_COEF": 0.5,
355 | "MAX_GRAD_NORM": 0.5,
356 | "ENV_NAME": "CartPole-v1",
357 | "ANNEAL_LR": True,
358 | "DEBUG": True,
359 | }
360 |
361 | rng = jax.random.PRNGKey(30)
362 | train_jit = jax.jit(make_train(config))
363 | out = train_jit(rng)
364 |
--------------------------------------------------------------------------------
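Usage note: out["metric"] returned by the file above mirrors the info dict produced by LogWrapper, with leading axes (NUM_UPDATES, NUM_STEPS, NUM_ENVS); returned_episode_returns holds, at every step and for every env, the return of the most recently completed episode. A rough learning-curve sketch (variable names illustrative):

    # assumes `out = train_jit(rng)` as in the __main__ block above
    returns = out["metric"]["returned_episode_returns"]  # (NUM_UPDATES, NUM_STEPS, NUM_ENVS)
    curve = returns.mean(axis=(-2, -1))                   # mean over steps and envs, per update
    print("final mean episodic return:", float(curve[-1]))
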
/purejaxrl/wrappers.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import chex
4 | import numpy as np
5 | from flax import struct
6 | from functools import partial
7 | from typing import Optional, Tuple, Union, Any
8 | from gymnax.environments import environment, spaces
9 | from brax import envs
10 | from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
11 | import navix as nx
12 |
13 |
14 | class GymnaxWrapper(object):
15 | """Base class for Gymnax wrappers."""
16 |
17 | def __init__(self, env):
18 | self._env = env
19 |
20 | # provide proxy access to regular attributes of wrapped object
21 | def __getattr__(self, name):
22 | return getattr(self._env, name)
23 |
24 |
25 | class FlattenObservationWrapper(GymnaxWrapper):
26 | """Flatten the observations of the environment."""
27 |
28 | def __init__(self, env: environment.Environment):
29 | super().__init__(env)
30 |
31 | def observation_space(self, params) -> spaces.Box:
32 | assert isinstance(
33 | self._env.observation_space(params), spaces.Box
34 | ), "Only Box spaces are supported for now."
35 | return spaces.Box(
36 | low=self._env.observation_space(params).low,
37 | high=self._env.observation_space(params).high,
38 | shape=(np.prod(self._env.observation_space(params).shape),),
39 | dtype=self._env.observation_space(params).dtype,
40 | )
41 |
42 | @partial(jax.jit, static_argnums=(0,))
43 | def reset(
44 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
45 | ) -> Tuple[chex.Array, environment.EnvState]:
46 | obs, state = self._env.reset(key, params)
47 | obs = jnp.reshape(obs, (-1,))
48 | return obs, state
49 |
50 | @partial(jax.jit, static_argnums=(0,))
51 | def step(
52 | self,
53 | key: chex.PRNGKey,
54 | state: environment.EnvState,
55 | action: Union[int, float],
56 | params: Optional[environment.EnvParams] = None,
57 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
58 | obs, state, reward, done, info = self._env.step(key, state, action, params)
59 | obs = jnp.reshape(obs, (-1,))
60 | return obs, state, reward, done, info
61 |
62 |
63 | @struct.dataclass
64 | class LogEnvState:
65 | env_state: environment.EnvState
66 | episode_returns: float
67 | episode_lengths: int
68 | returned_episode_returns: float
69 | returned_episode_lengths: int
70 | timestep: int
71 |
72 |
73 | class LogWrapper(GymnaxWrapper):
74 | """Log the episode returns and lengths."""
75 |
76 | def __init__(self, env: environment.Environment):
77 | super().__init__(env)
78 |
79 | @partial(jax.jit, static_argnums=(0,))
80 | def reset(
81 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
82 | ) -> Tuple[chex.Array, environment.EnvState]:
83 | obs, env_state = self._env.reset(key, params)
84 | state = LogEnvState(env_state, 0, 0, 0, 0, 0)
85 | return obs, state
86 |
87 | @partial(jax.jit, static_argnums=(0,))
88 | def step(
89 | self,
90 | key: chex.PRNGKey,
91 | state: environment.EnvState,
92 | action: Union[int, float],
93 | params: Optional[environment.EnvParams] = None,
94 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
95 | obs, env_state, reward, done, info = self._env.step(
96 | key, state.env_state, action, params
97 | )
98 | new_episode_return = state.episode_returns + reward
99 | new_episode_length = state.episode_lengths + 1
100 | state = LogEnvState(
101 | env_state=env_state,
102 | episode_returns=new_episode_return * (1 - done),
103 | episode_lengths=new_episode_length * (1 - done),
104 | returned_episode_returns=state.returned_episode_returns * (1 - done)
105 | + new_episode_return * done,
106 | returned_episode_lengths=state.returned_episode_lengths * (1 - done)
107 | + new_episode_length * done,
108 | timestep=state.timestep + 1,
109 | )
110 | info["returned_episode_returns"] = state.returned_episode_returns
111 | info["returned_episode_lengths"] = state.returned_episode_lengths
112 | info["timestep"] = state.timestep
113 | info["returned_episode"] = done
114 | return obs, state, reward, done, info
115 |
116 |
117 | class BraxGymnaxWrapper:
118 | def __init__(self, env_name, backend="positional"):
119 | env = envs.get_environment(env_name=env_name, backend=backend)
120 | env = EpisodeWrapper(env, episode_length=1000, action_repeat=1)
121 | env = AutoResetWrapper(env)
122 | self._env = env
123 | self.action_size = env.action_size
124 | self.observation_size = (env.observation_size,)
125 |
126 | def reset(self, key, params=None):
127 | state = self._env.reset(key)
128 | return state.obs, state
129 |
130 | def step(self, key, state, action, params=None):
131 | next_state = self._env.step(state, action)
132 | return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}
133 |
134 | def observation_space(self, params):
135 | return spaces.Box(
136 | low=-jnp.inf,
137 | high=jnp.inf,
138 | shape=(self._env.observation_size,),
139 | )
140 |
141 | def action_space(self, params):
142 | return spaces.Box(
143 | low=-1.0,
144 | high=1.0,
145 | shape=(self._env.action_size,),
146 | )
147 |
148 | class NavixGymnaxWrapper:
149 | def __init__(self, env_name):
150 | self._env = nx.make(env_name)
151 |
152 | def reset(self, key, params=None):
153 | timestep = self._env.reset(key)
154 | return timestep.observation, timestep
155 |
156 | def step(self, key, state, action, params=None):
157 | timestep = self._env.step(state, action)
158 | return timestep.observation, timestep, timestep.reward, timestep.is_done(), {}
159 |
160 | def observation_space(self, params):
161 | return spaces.Box(
162 | low=self._env.observation_space.minimum,
163 | high=self._env.observation_space.maximum,
164 | shape=(np.prod(self._env.observation_space.shape),),
165 | dtype=self._env.observation_space.dtype,
166 | )
167 |
168 | def action_space(self, params):
169 | return spaces.Discrete(
170 | num_categories=self._env.action_space.maximum.item() + 1,
171 | )
172 |
173 |
174 | class ClipAction(GymnaxWrapper):
175 | def __init__(self, env, low=-1.0, high=1.0):
176 | super().__init__(env)
177 | self.low = low
178 | self.high = high
179 |
180 | def step(self, key, state, action, params=None):
181 | """TODO: In theory the below line should be the way to do this."""
182 | # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high)
183 | action = jnp.clip(action, self.low, self.high)
184 | return self._env.step(key, state, action, params)
185 |
186 |
187 | class TransformObservation(GymnaxWrapper):
188 | def __init__(self, env, transform_obs):
189 | super().__init__(env)
190 | self.transform_obs = transform_obs
191 |
192 | def reset(self, key, params=None):
193 | obs, state = self._env.reset(key, params)
194 | return self.transform_obs(obs), state
195 |
196 | def step(self, key, state, action, params=None):
197 | obs, state, reward, done, info = self._env.step(key, state, action, params)
198 | return self.transform_obs(obs), state, reward, done, info
199 |
200 |
201 | class TransformReward(GymnaxWrapper):
202 | def __init__(self, env, transform_reward):
203 | super().__init__(env)
204 | self.transform_reward = transform_reward
205 |
206 | def step(self, key, state, action, params=None):
207 | obs, state, reward, done, info = self._env.step(key, state, action, params)
208 | return obs, state, self.transform_reward(reward), done, info
209 |
210 |
211 | class VecEnv(GymnaxWrapper):
212 | def __init__(self, env):
213 | super().__init__(env)
214 | self.reset = jax.vmap(self._env.reset, in_axes=(0, None))
215 | self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
216 |
217 |
218 | @struct.dataclass
219 | class NormalizeVecObsEnvState:
220 | mean: jnp.ndarray
221 | var: jnp.ndarray
222 | count: float
223 | env_state: environment.EnvState
224 |
225 |
226 | class NormalizeVecObservation(GymnaxWrapper):
227 | def __init__(self, env):
228 | super().__init__(env)
229 |
230 | def reset(self, key, params=None):
231 | obs, state = self._env.reset(key, params)
232 | state = NormalizeVecObsEnvState(
233 | mean=jnp.zeros_like(obs),
234 | var=jnp.ones_like(obs),
235 | count=1e-4,
236 | env_state=state,
237 | )
238 | batch_mean = jnp.mean(obs, axis=0)
239 | batch_var = jnp.var(obs, axis=0)
240 | batch_count = obs.shape[0]
241 |
242 | delta = batch_mean - state.mean
243 | tot_count = state.count + batch_count
244 |
245 | new_mean = state.mean + delta * batch_count / tot_count
246 | m_a = state.var * state.count
247 | m_b = batch_var * batch_count
248 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
249 | new_var = M2 / tot_count
250 | new_count = tot_count
251 |
252 | state = NormalizeVecObsEnvState(
253 | mean=new_mean,
254 | var=new_var,
255 | count=new_count,
256 | env_state=state.env_state,
257 | )
258 |
259 | return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state
260 |
261 | def step(self, key, state, action, params=None):
262 | obs, env_state, reward, done, info = self._env.step(
263 | key, state.env_state, action, params
264 | )
265 |
266 | batch_mean = jnp.mean(obs, axis=0)
267 | batch_var = jnp.var(obs, axis=0)
268 | batch_count = obs.shape[0]
269 |
270 | delta = batch_mean - state.mean
271 | tot_count = state.count + batch_count
272 |
273 | new_mean = state.mean + delta * batch_count / tot_count
274 | m_a = state.var * state.count
275 | m_b = batch_var * batch_count
276 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
277 | new_var = M2 / tot_count
278 | new_count = tot_count
279 |
280 | state = NormalizeVecObsEnvState(
281 | mean=new_mean,
282 | var=new_var,
283 | count=new_count,
284 | env_state=env_state,
285 | )
286 | return (
287 | (obs - state.mean) / jnp.sqrt(state.var + 1e-8),
288 | state,
289 | reward,
290 | done,
291 | info,
292 | )
293 |
294 |
295 | @struct.dataclass
296 | class NormalizeVecRewEnvState:
297 | mean: jnp.ndarray
298 | var: jnp.ndarray
299 | count: float
300 | return_val: float
301 | env_state: environment.EnvState
302 |
303 |
304 | class NormalizeVecReward(GymnaxWrapper):
305 | def __init__(self, env, gamma):
306 | super().__init__(env)
307 | self.gamma = gamma
308 |
309 | def reset(self, key, params=None):
310 | obs, state = self._env.reset(key, params)
311 | batch_count = obs.shape[0]
312 | state = NormalizeVecRewEnvState(
313 | mean=0.0,
314 | var=1.0,
315 | count=1e-4,
316 | return_val=jnp.zeros((batch_count,)),
317 | env_state=state,
318 | )
319 | return obs, state
320 |
321 | def step(self, key, state, action, params=None):
322 | obs, env_state, reward, done, info = self._env.step(
323 | key, state.env_state, action, params
324 | )
325 | return_val = state.return_val * self.gamma * (1 - done) + reward
326 |
327 | batch_mean = jnp.mean(return_val, axis=0)
328 | batch_var = jnp.var(return_val, axis=0)
329 | batch_count = obs.shape[0]
330 |
331 | delta = batch_mean - state.mean
332 | tot_count = state.count + batch_count
333 |
334 | new_mean = state.mean + delta * batch_count / tot_count
335 | m_a = state.var * state.count
336 | m_b = batch_var * batch_count
337 | M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
338 | new_var = M2 / tot_count
339 | new_count = tot_count
340 |
341 | state = NormalizeVecRewEnvState(
342 | mean=new_mean,
343 | var=new_var,
344 | count=new_count,
345 | return_val=return_val,
346 | env_state=env_state,
347 | )
348 | return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info
349 |
--------------------------------------------------------------------------------
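Usage note: one way these wrappers compose for a Brax continuous-control task, following the code above (VecEnv must sit underneath the two normalizers, since they take batch statistics over axis 0 of already-vectorised observations and returns). Environment name and batch size are illustrative:

    import jax
    from wrappers import (BraxGymnaxWrapper, LogWrapper, ClipAction, VecEnv,
                          NormalizeVecObservation, NormalizeVecReward)

    env = BraxGymnaxWrapper("hopper")          # Brax env behind a gymnax-style interface
    env = LogWrapper(env)                      # episode returns/lengths reported in info
    env = ClipAction(env)                      # clip actions to [-1, 1]
    env = VecEnv(env)                          # vmap reset/step over a batch of rng keys
    env = NormalizeVecObservation(env)         # running observation normalisation
    env = NormalizeVecReward(env, gamma=0.99)  # scale rewards by running return std

    rngs = jax.random.split(jax.random.PRNGKey(0), 4)
    obs, state = env.reset(rngs, None)         # obs: (4, obs_dim), already normalised
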
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax>=0.2.26
2 | jaxlib>=0.1.74
3 | gymnax
4 | evosax
5 | distrax
6 | optax
7 | flax
8 | numpy
9 | brax
10 | wandb
11 | flashbax
12 | navix
--------------------------------------------------------------------------------