├── .github └── FUNDING.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── core ├── __init__.py ├── common.py ├── evaluators │ ├── __init__.py │ ├── alphazero.py │ ├── evaluation_fns.py │ ├── evaluator.py │ └── mcts │ │ ├── __init__.py │ │ ├── action_selection.py │ │ ├── mcts.py │ │ ├── state.py │ │ └── weighted_mcts.py ├── memory │ ├── __init__.py │ └── replay_memory.py ├── networks │ ├── __init__.py │ └── azresnet.py ├── testing │ ├── __init__.py │ ├── tester.py │ ├── two_player_baseline.py │ ├── two_player_tester.py │ └── utils.py ├── training │ ├── __init__.py │ ├── loss_fns.py │ └── train.py ├── trees │ ├── __init__.py │ └── tree.py └── types.py ├── notebooks ├── hello_world.ipynb └── weighted_mcts.ipynb ├── poetry.lock └── pyproject.toml /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [lowrollr] 4 | 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "turbozero: fast + parallel AlphaZero" 3 | abstract: vectorized implementation of AlphaZero/MCTS with training and evaluation utilities 4 | url: "https://github.com/lowrollr/turbozero" 5 | authors: 6 | - family-names: "Marshall" 7 | given-names: "Jacob" 8 | 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *turbozero* 🏎️ 🏎️ 🏎️ 🏎️ 2 | 3 | 📣 If you're looking for the old PyTorch version of turbozero, it's been moved here: [turbozero_torch](https://github.com/lowrollr/turbozero_torch) 📣 4 | 5 | #### *`turbozero`* is a vectorized implementation of [AlphaZero](https://deepmind.google/discover/blog/alphazero-shedding-new-light-on-chess-shogi-and-go/) written in JAX 6 | 7 | It contains: 8 | * Monte Carlo Tree Search with subtree persistence 9 | * Batched Replay Memory 10 | * A complete, customizable training/evaluation loop 11 | 12 | #### *`turbozero`* is *_fast_* and *_parallelized_*: 13 | * every consequential part of the training loop is JIT-compiled 14 | * partitions across multiple GPUs by default when available 🚀 NEW! 🚀 15 | * self-play and evaluation episodes are batched/vmapped with hardware-acceleration in mind 16 | 17 | #### *`turbozero`* is *_extendable_*: 18 | * see an [idea on twitter](https://twitter.com/ptrschmdtnlsn/status/1748800529608888362) for a simple tweak to MCTS? 19 | * [implement it](https://github.com/lowrollr/turbozero/blob/main/core/evaluators/mcts/weighted_mcts.py) then [test it](https://github.com/lowrollr/turbozero/blob/main/notebooks/weighted_mcts.ipynb) by extending core components 20 | 21 | #### *`turbozero`* is *_flexible_*: 22 | * easy to integrate with your custom JAX environment or neural network architecture. 23 | * Use the provided training and evaluation utilities, or pick and choose the components that you need. 24 | 25 | To get started, check out the [Hello World Notebook](https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb) 26 | 27 | ## Installation 28 | `turbozero` uses `poetry` for dependency management; you can install it with: 29 | ``` 30 | pip install poetry==1.7.1 31 | ``` 32 | Then, to install dependencies: 33 | ``` 34 | poetry install 35 | ``` 36 | If you're using a GPU/TPU/etc., after running the previous command you'll need to install the device-specific version of JAX. 37 | 38 | For a GPU w/ CUDA 12: 39 | ``` 40 | poetry source add jax https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 41 | ``` 42 | to point poetry towards JAX CUDA releases, then use 43 | ``` 44 | poetry add jax[cuda12_pip]==0.4.35 45 | ``` 46 | to install the CUDA 12 release for JAX. See https://jax.readthedocs.io/en/latest/installation.html for other devices/cuda versions. 47 | 48 | I have tested this project with CUDA 11 and CUDA 12. 49 | 50 | To launch an IPython kernel, run: 51 | ``` 52 | poetry run python -m ipykernel install --user --name turbozero 53 | ``` 54 | 55 | ## Issues 56 | If you use this project and encounter an issue, error, or undesired behavior, please submit a [GitHub Issue](https://github.com/lowrollr/turbozero/issues) and I will do my best to resolve it as soon as I can. You may also contact me directly via `hello@jacob.land`. 57 | 58 | ## Contributing 59 | Contributions, improvements, and fixes are more than welcome! For now I don't have a formal process for this, other than creating a [Pull Request](https://github.com/lowrollr/turbozero/pulls). For large changes, consider creating an [Issue](https://github.com/lowrollr/turbozero/issues) beforehand. 60 | 61 | If you are interested in contributing but don't know what to work on, please reach out. I have plenty of things you could do.
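
## Quickstart Sketch
The snippet below is a minimal construction sketch, not the canonical API: it assumes only the constructor signatures that appear later in this repository (`core/evaluators/mcts/mcts.py` and `core/evaluators/mcts/action_selection.py`), and the leaf-evaluation function is a hypothetical placeholder. See the Hello World notebook for the real end-to-end workflow.
```python
import jax.numpy as jnp

from core.evaluators.mcts.mcts import MCTS
from core.evaluators.mcts.action_selection import PUCTSelector

# Hypothetical leaf evaluation function: (env_state, params, key) -> (policy_logits, value).
# In practice this would wrap a neural network via core.evaluators.evaluation_fns.make_nn_eval_fn.
def dummy_eval_fn(env_state, params, key):
    return jnp.zeros(7), jnp.array(0.0)

evaluator = MCTS(
    eval_fn=dummy_eval_fn,
    action_selector=PUCTSelector(c=1.0),
    branching_factor=7,    # e.g. 7 legal columns in Connect Four
    max_nodes=400,
    num_iterations=100,
    discount=-1.0,         # negative discount for two-player games
    temperature=1.0,
)
print(evaluator.get_config())
```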
62 | 63 | ## References 64 | Papers/Repos I found helpful. 65 | 66 | Repositories: 67 | * [google-deepmind/mctx](https://github.com/google-deepmind/mctx): Monte Carlo tree search in JAX 68 | * [sotetsuk/pgx](https://github.com/sotetsuk/pgx): Vectorized RL game environments in JAX 69 | * [instadeepai/flashbax](https://github.com/instadeepai/flashbax): Accelerated Replay Buffers in JAX 70 | * [google-deepmind/open_spiel](https://github.com/google-deepmind/open_spiel): RL algorithms 71 | 72 | Papers: 73 | * [Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm](https://arxiv.org/abs/1712.01815) 74 | * [Revisiting Fundamentals of Experience Replay](https://arxiv.org/abs/2007.06700) 75 | 76 | 77 | ## Cite This Work 78 | If you found this work useful, please cite it with: 79 | ``` 80 | @software{turbozero, 81 | author = {Marshall, Jacob}, 82 | title = {{turbozero: fast + parallel AlphaZero}}, 83 | url = {https://github.com/lowrollr/turbozero} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/__init__.py -------------------------------------------------------------------------------- /core/common.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | from typing import Tuple 4 | import chex 5 | from chex import dataclass 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | from core.evaluators.evaluator import EvalOutput, Evaluator 10 | from core.types import EnvInitFn, EnvStepFn, StepMetadata 11 | 12 | def partition( 13 | data: chex.ArrayTree, 14 | num_partitions: int 15 | ) -> chex.ArrayTree: 16 | """Partition each array in a data structure into num_partitions along the first axis. 17 | e.g. partitions an array of shape (N, ...) into (num_partitions, N//num_partitions, ...) 18 | 19 | Args: 20 | - `data`: ArrayTree to partition 21 | - `num_partitions`: number of partitions 22 | 23 | Returns: 24 | - (chex.ArrayTree): partitioned ArrayTree 25 | """ 26 | return jax.tree_map( 27 | lambda x: x.reshape(num_partitions, x.shape[0] // num_partitions, *x.shape[1:]), 28 | data 29 | ) 30 | 31 | 32 | def step_env_and_evaluator( 33 | key: jax.random.PRNGKey, 34 | env_state: chex.ArrayTree, 35 | env_state_metadata: StepMetadata, 36 | eval_state: chex.ArrayTree, 37 | params: chex.ArrayTree, 38 | evaluator: Evaluator, 39 | env_step_fn: EnvStepFn, 40 | env_init_fn: EnvInitFn, 41 | max_steps: int, 42 | reset: bool = True 43 | ) -> Tuple[EvalOutput, chex.ArrayTree, StepMetadata, bool, bool, chex.Array]: 44 | """ 45 | - Evaluates the environment state with the Evaluator and selects an action. 46 | - Performs a step in the environment with the selected action. 47 | - Updates the internal state of the Evaluator. 48 | - Optionally resets the environment and evaluator state if the episode is terminated or truncated. 49 | 50 | Args: 51 | - `key`: rng 52 | - `env_state`: The environment state to evaluate. 53 | - `env_state_metadata`: Metadata associated with the environment state. 54 | - `eval_state`: The internal state of the Evaluator. 55 | - `params`: nn parameters used by the Evaluator. 56 | - `evaluator`: The Evaluator. 57 | - `env_step_fn`: The environment step function. 58 | - `env_init_fn`: The environment initialization function. 
59 | - `max_steps`: The maximum number of environment steps per episode. 60 | - `reset`: Whether to reset the environment and evaluator state if the episode is terminated or truncated. 61 | 62 | Returns: 63 | - (EvalOutput, chex.ArrayTree, StepMetadata, bool, bool, chex.Array) 64 | - `output`: The output of the evaluation. 65 | - `env_state`: The updated environment state. 66 | - `env_state_metadata`: Metadata associated with the updated environment state. 67 | - `terminated`: Whether the episode is terminated. 68 | - `truncated`: Whether the episode is truncated. 69 | - `rewards`: Rewards emitted by the environment. 70 | """ 71 | key, evaluate_key = jax.random.split(key) 72 | # evaluate the environment state 73 | output = evaluator.evaluate( 74 | key=evaluate_key, 75 | eval_state=eval_state, 76 | env_state=env_state, 77 | root_metadata=env_state_metadata, 78 | params=params, 79 | env_step_fn=env_step_fn 80 | ) 81 | # take the selected action 82 | env_state, env_state_metadata = env_step_fn(env_state, output.action) 83 | # check for termination and truncation 84 | terminated = env_state_metadata.terminated 85 | truncated = env_state_metadata.step > max_steps 86 | # reset the environment and evaluator state if the episode is terminated or truncated 87 | # else, update the evaluator state 88 | rewards = env_state_metadata.rewards 89 | eval_state = jax.lax.cond( 90 | terminated | truncated, 91 | evaluator.reset if reset else lambda s: s, 92 | lambda s: evaluator.step(s, output.action), 93 | output.eval_state 94 | ) 95 | # reset the environment if the episode is terminated or truncated 96 | env_state, env_state_metadata = jax.lax.cond( 97 | terminated | truncated, 98 | lambda _: env_init_fn(key) if reset else (env_state, env_state_metadata), 99 | lambda _: (env_state, env_state_metadata), 100 | None 101 | ) 102 | output = output.replace(eval_state=eval_state) 103 | return output, env_state, env_state_metadata, terminated, truncated, rewards 104 | 105 | 106 | @dataclass(frozen=True) 107 | class TwoPlayerGameState: 108 | """Stores the state of a two player game using two different evaluators. 109 | - `key`: rng 110 | - `env_state`: The environment state. 111 | - `env_state_metadata`: Metadata associated with the environment state. 112 | - `p1_eval_state`: The internal state of the first evaluator. 113 | - `p2_eval_state`: The internal state of the second evaluator. 114 | - `p1_value_estimate`: The current state value estimate of the first evaluator. 115 | - `p2_value_estimate`: The current state value estimate of the second evaluator. 116 | - `outcomes`: The outcomes of the game (final rewards) for each player 117 | - `completed`: Whether the game is completed. 118 | """ 119 | key: jax.random.PRNGKey 120 | env_state: chex.ArrayTree 121 | env_state_metadata: StepMetadata 122 | p1_eval_state: chex.ArrayTree 123 | p2_eval_state: chex.ArrayTree 124 | p1_value_estimate: chex.Array 125 | p2_value_estimate: chex.Array 126 | outcomes: float 127 | completed: bool 128 | 129 | 130 | @dataclass(frozen=True) 131 | class GameFrame: 132 | """Stores information necessary for rendering the environment state in a two-player game. 133 | - `env_state`: The environment state. 134 | - `p1_value_estimate`: The current state value estimate of the first evaluator. 135 | - `p2_value_estimate`: The current state value estimate of the second evaluator. 136 | - `completed`: Whether the game is completed. 
137 | - `outcomes`: The outcomes of the game (final rewards) for each player 138 | """ 139 | env_state: chex.ArrayTree 140 | p1_value_estimate: chex.Array 141 | p2_value_estimate: chex.Array 142 | completed: chex.Array 143 | outcomes: chex.Array 144 | 145 | 146 | def two_player_game_step( 147 | state: TwoPlayerGameState, 148 | p1_evaluator: Evaluator, 149 | p2_evaluator: Evaluator, 150 | params: chex.ArrayTree, 151 | env_step_fn: EnvStepFn, 152 | env_init_fn: EnvInitFn, 153 | use_p1: bool, 154 | max_steps: int 155 | ) -> TwoPlayerGameState: 156 | """Make a single step in a two player game. 157 | 158 | Args: 159 | - `state`: The current game state. 160 | - `p1_evaluator`: The first evaluator. 161 | - `p2_evaluator`: The second evaluator. 162 | - `params`: The parameters of the active evaluator. 163 | - `env_step_fn`: The environment step function. 164 | - `env_init_fn`: The environment initialization function. 165 | - `use_p1`: Whether to use the first evaluator. 166 | - `max_steps`: The maximum number of steps per episode. 167 | 168 | Returns: 169 | - (TwoPlayerGameState): The updated game state. 170 | """ 171 | # determine which evaluator to use based on the current player 172 | if use_p1: 173 | active_evaluator = p1_evaluator 174 | other_evaluator = p2_evaluator 175 | active_eval_state = state.p1_eval_state 176 | other_eval_state = state.p2_eval_state 177 | else: 178 | active_evaluator = p2_evaluator 179 | other_evaluator = p1_evaluator 180 | active_eval_state = state.p2_eval_state 181 | other_eval_state = state.p1_eval_state 182 | 183 | # step 184 | step_key, key = jax.random.split(state.key) 185 | output, env_state, env_state_metadata, terminated, truncated, rewards = step_env_and_evaluator( 186 | key = step_key, 187 | env_state = state.env_state, 188 | env_state_metadata = state.env_state_metadata, 189 | eval_state = active_eval_state, 190 | params = params, 191 | evaluator = active_evaluator, 192 | env_step_fn = env_step_fn, 193 | env_init_fn = env_init_fn, 194 | max_steps = max_steps, 195 | reset = False 196 | ) 197 | 198 | 199 | active_eval_state = output.eval_state 200 | active_value_estimate = active_evaluator.get_value(active_eval_state) 201 | active_value_estimate = jax.lax.cond( 202 | terminated | truncated, 203 | lambda a: a, 204 | lambda a: active_evaluator.discount * a, 205 | active_value_estimate 206 | ) 207 | # update the other evaluator 208 | other_eval_state = other_evaluator.step(other_eval_state, output.action) 209 | other_value_estimate = other_evaluator.get_value(other_eval_state) 210 | # update the game state 211 | if use_p1: 212 | p1_eval_state, p2_eval_state = active_eval_state, other_eval_state 213 | p1_value_estimate, p2_value_estimate = active_value_estimate, other_value_estimate 214 | else: 215 | p1_eval_state, p2_eval_state = other_eval_state, active_eval_state 216 | p1_value_estimate, p2_value_estimate = other_value_estimate, active_value_estimate 217 | return state.replace( 218 | key = key, 219 | env_state = env_state, 220 | env_state_metadata = env_state_metadata, 221 | p1_eval_state = p1_eval_state, 222 | p2_eval_state = p2_eval_state, 223 | p1_value_estimate = p1_value_estimate, 224 | p2_value_estimate = p2_value_estimate, 225 | outcomes=jnp.where( 226 | ((terminated | truncated) & ~state.completed)[..., None], 227 | rewards, 228 | state.outcomes 229 | ), 230 | completed = state.completed | terminated | truncated, 231 | ) 232 | 233 | 234 | def two_player_game( 235 | key: jax.random.PRNGKey, 236 | evaluator_1: Evaluator, 237 | evaluator_2: Evaluator, 238 
| params_1: chex.ArrayTree, 239 | params_2: chex.ArrayTree, 240 | env_step_fn: EnvStepFn, 241 | env_init_fn: EnvInitFn, 242 | max_steps: int 243 | ) -> Tuple[chex.Array, TwoPlayerGameState, chex.Array]: 244 | """ 245 | Play a two player game between two evaluators. 246 | 247 | Args: 248 | - `key`: rng 249 | - `evaluator_1`: The first evaluator. 250 | - `evaluator_2`: The second evaluator. 251 | - `params_1`: The parameters of the first evaluator. 252 | - `params_2`: The parameters of the second evaluator. 253 | - `env_step_fn`: The environment step function. 254 | - `env_init_fn`: The environment initialization function. 255 | - `max_steps`: The maximum number of steps per episode. 256 | 257 | Returns: 258 | - (chex.Array, TwoPlayerGameState, chex.Array, chex.Array) 259 | - `outcomes`: The outcomes of the game (final rewards) for each player. 260 | - `frames`: Frames collected from the game (used for rendering) 261 | - `p_ids`: The player ids of the two evaluators. [evaluator_1_id, evaluator_2_id] 262 | """ 263 | # init rng 264 | env_key, turn_key, key = jax.random.split(key, 3) 265 | # init env state 266 | env_state, metadata = env_init_fn(env_key) 267 | # init evaluator states 268 | p1_eval_state = evaluator_1.init(template_embedding=env_state) 269 | p2_eval_state = evaluator_2.init(template_embedding=env_state) 270 | # compile step functions 271 | game_step = partial(two_player_game_step, 272 | p1_evaluator=evaluator_1, 273 | p2_evaluator=evaluator_2, 274 | env_step_fn=env_step_fn, 275 | env_init_fn=env_init_fn, 276 | max_steps=max_steps 277 | ) 278 | step_p1 = partial(game_step, params=params_1, use_p1=True) 279 | step_p2 = partial(game_step, params=params_2, use_p1=False) 280 | 281 | # determine who goes first 282 | first_player = jax.random.randint(turn_key, (), 0, 2) 283 | p1_first = first_player == 0 284 | p1_id, p2_id = jax.lax.cond( 285 | p1_first, 286 | lambda _: (metadata.cur_player_id, 1 - metadata.cur_player_id), 287 | lambda _: (1 - metadata.cur_player_id, metadata.cur_player_id), 288 | None 289 | ) 290 | # init game state 291 | state = TwoPlayerGameState( 292 | key = key, 293 | env_state = env_state, 294 | env_state_metadata = metadata, 295 | p1_eval_state = p1_eval_state, 296 | p2_eval_state = p2_eval_state, 297 | p1_value_estimate = jnp.array(0.0, dtype=jnp.float32), 298 | p2_value_estimate = jnp.array(0.0, dtype=jnp.float32), 299 | outcomes = jnp.zeros((2,), dtype=jnp.float32), 300 | completed = jnp.zeros((), dtype=jnp.bool_) 301 | ) 302 | # make initial render frame 303 | initial_game_frame = GameFrame( 304 | env_state = state.env_state, 305 | p1_value_estimate = state.p1_value_estimate, 306 | p2_value_estimate = state.p2_value_estimate, 307 | completed = state.completed, 308 | outcomes = state.outcomes 309 | ) 310 | 311 | # takes a turn for each player 312 | def step_step(state: TwoPlayerGameState, _) -> TwoPlayerGameState: 313 | # take a turn for the active player 314 | state = jax.lax.cond( 315 | state.completed, 316 | lambda s: s, 317 | lambda s: jax.lax.cond( 318 | p1_first, 319 | step_p1, 320 | step_p2, 321 | s 322 | ), 323 | state 324 | ) 325 | # collect render frame 326 | frame1 = GameFrame( 327 | env_state = state.env_state, 328 | p1_value_estimate = state.p1_value_estimate, 329 | p2_value_estimate = state.p2_value_estimate, 330 | completed = state.completed, 331 | outcomes = state.outcomes 332 | ) 333 | # take a turn for the other player 334 | state = jax.lax.cond( 335 | state.completed, 336 | lambda s: s, 337 | lambda s: jax.lax.cond( 338 | p1_first, 339 | 
step_p2, 340 | step_p1, 341 | s 342 | ), 343 | state 344 | ) 345 | # collect render frame 346 | frame2 = GameFrame( 347 | env_state = state.env_state, 348 | p1_value_estimate = state.p1_value_estimate, 349 | p2_value_estimate = state.p2_value_estimate, 350 | completed = state.completed, 351 | outcomes = state.outcomes 352 | ) 353 | # return game state and render frames 354 | return state, jax.tree_map(lambda x, y: jnp.stack([x, y]), frame1, frame2) 355 | 356 | # play the game 357 | state, frames = jax.lax.scan( 358 | step_step, 359 | state, 360 | xs=jnp.arange(max_steps//2) 361 | ) 362 | # reshape frames 363 | frames = jax.tree_map(lambda x: x.reshape(max_steps, *x.shape[2:]), frames) 364 | # append initial state to front of frames 365 | frames = jax.tree_map(lambda i, x: jnp.concatenate([jnp.expand_dims(i, 0), x]), initial_game_frame, frames) 366 | # return outcome, frames, player ids 367 | return jnp.array([state.outcomes[p1_id], state.outcomes[p2_id]]), frames, jnp.array([p1_id, p2_id]) 368 | -------------------------------------------------------------------------------- /core/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/evaluators/__init__.py -------------------------------------------------------------------------------- /core/evaluators/alphazero.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from core.evaluators.mcts.state import MCTSTree 9 | from core.evaluators.mcts.mcts import MCTS 10 | from core.types import StepMetadata 11 | 12 | 13 | class _AlphaZero: 14 | """AlphaZero-specific logic for MCTS. 15 | Extends MCTS using the `AlphaZero` class, this class serves as a mixin to add AlphaZero-specific logic. 16 | """ 17 | 18 | def __init__(self, 19 | dirichlet_alpha: float = 0.3, 20 | dirichlet_epsilon: float = 0.25, 21 | **kwargs 22 | ): 23 | """ 24 | Args: 25 | - `dirichlet_alpha`: magnitude of Dirichlet noise. 26 | - `dirichlet_epsilon`: proportion of root policy composed of Dirichlet noise. 27 | (see `MCTS` class for additional configuration) 28 | """ 29 | super().__init__(**kwargs) 30 | self.dirichlet_alpha = dirichlet_alpha 31 | self.dirichlet_epsilon = dirichlet_epsilon 32 | 33 | 34 | def get_config(self) -> Dict: 35 | """Returns the configuration of the AlphaZero evaluator. Used for logging.""" 36 | return { 37 | "dirichlet_alpha": self.dirichlet_alpha, 38 | "dirichlet_epsilon": self.dirichlet_epsilon, 39 | **super().get_config() #pylint: disable=no-member 40 | } 41 | 42 | 43 | def update_root(self, key: chex.PRNGKey, tree: MCTSTree, root_embedding: chex.ArrayTree, params: chex.ArrayTree, root_metadata: StepMetadata) -> MCTSTree: 44 | """Populates the root node of the search tree. Adds Dirichlet noise to the root policy. 45 | 46 | Args: 47 | - `key`: rng 48 | - `tree`: The search tree. 49 | - `root_embedding`: root environment state. 50 | - `params`: nn parameters. 51 | - `root_metadata`: metadata of the root environment state 52 | 53 | Returns: 54 | - `tree`: The updated search tree. 
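
Returning to `two_player_game` in `core/common.py` above: since all of its array inputs and outputs are batchable, one natural pattern is to bind the static arguments with `functools.partial` and `jax.vmap` over a batch of RNG keys. The sketch below is illustrative only; the two evaluators, their parameters, and the environment functions are hypothetical placeholders.
```python
from functools import partial

import jax
from core.common import two_player_game

# Hypothetical placeholders: two evaluators, their params, and a JAX environment.
game_fn = partial(
    two_player_game,
    evaluator_1=evaluator_1,
    evaluator_2=evaluator_2,
    params_1=params_1,
    params_2=params_2,
    env_step_fn=my_env_step_fn,
    env_init_fn=my_env_init_fn,
    max_steps=80,
)
keys = jax.random.split(jax.random.PRNGKey(0), 128)  # play 128 games in parallel
outcomes, frames, p_ids = jax.vmap(game_fn)(keys)    # outcomes has shape (128, 2)
```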
55 | """ 56 | # evaluate the root state 57 | root_key, dir_key = jax.random.split(key, 2) 58 | root_policy_logits, root_value = self.eval_fn(root_embedding, params, root_key) #pylint: disable=no-member 59 | root_policy = jax.nn.softmax(root_policy_logits) 60 | 61 | # add Dirichlet noise to the root policy 62 | dirichlet_noise = jax.random.dirichlet( 63 | dir_key, 64 | alpha=jnp.full( 65 | [tree.branching_factor], 66 | fill_value=self.dirichlet_alpha 67 | ) 68 | ) 69 | noisy_policy = ( 70 | ((1-self.dirichlet_epsilon) * root_policy) + 71 | (self.dirichlet_epsilon * dirichlet_noise) 72 | ) 73 | # re-normalize the policy 74 | new_logits = jnp.log(jnp.maximum(noisy_policy, jnp.finfo(noisy_policy).tiny)) 75 | policy = jnp.where(root_metadata.action_mask, new_logits, jnp.finfo(noisy_policy).min) 76 | renorm_policy = jax.nn.softmax(policy) 77 | 78 | # update the root node 79 | root_node = tree.data_at(tree.ROOT_INDEX) 80 | root_node = self.update_root_node(root_node, renorm_policy, root_value, root_embedding) #pylint: disable=no-member 81 | return tree.set_root(root_node) 82 | 83 | 84 | class AlphaZero(MCTS): 85 | """AlphaZero: Monte Carlo Tree Search + Neural Network Leaf Evaluation 86 | - https://arxiv.org/abs/1712.01815 87 | 88 | Most of the work is actually done in the `MCTS` class, which AlphaZero extends. 89 | This class can take an arbitrary MCTS backend, which is why we use a separate class `_AlphaZero` 90 | to handle the AlphaZero-specific logic, then combine them here. 91 | """ 92 | 93 | def __new__(cls, base_type: type = MCTS): 94 | """Creates a new AlphaZero class that extends the given MCTS class.""" 95 | assert issubclass(base_type, MCTS) 96 | cls_type = type("AlphaZero", (_AlphaZero, base_type), {}) 97 | cls_type.__name__ = f'AlphaZero({base_type.__name__})' 98 | return cls_type 99 | -------------------------------------------------------------------------------- /core/evaluators/evaluation_fns.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Callable, Tuple 3 | 4 | import chex 5 | import flax 6 | import jax 7 | 8 | 9 | def make_nn_eval_fn( 10 | nn: flax.linen.Module, 11 | state_to_nn_input_fn: Callable[[chex.ArrayTree], chex.Array] 12 | ) -> Callable[[chex.ArrayTree, chex.ArrayTree, chex.PRNGKey], Tuple[chex.Array, chex.Array]]: 13 | """Creates a leaf evaluation function using a neural network (state, params) -> (policy, value). 14 | 15 | Args: 16 | - `nn`: The neural network module. 17 | - `state_to_nn_input_fn`: A function that converts the state to the input format expected by the neural network. 18 | 19 | Returns: 20 | - `eval_fn`: A function that evaluates the state using the neural network (state, params) -> (policy, value) 21 | """ 22 | 23 | def eval_fn(state, params, *args): 24 | # get the policy and value from the neural network 25 | policy_logits, value = nn.apply(params, state_to_nn_input_fn(state)[None,...], train=False) 26 | # apply softmax to the policy logits 27 | return jax.nn.softmax(policy_logits, axis=-1).squeeze(0), value.squeeze() 28 | 29 | return eval_fn 30 | 31 | 32 | def make_nn_eval_fn_no_params_callable( 33 | nn: Callable[[chex.Array], Tuple[chex.Array, chex.Array]], 34 | state_to_nn_input_fn: Callable[[chex.ArrayTree], chex.Array] 35 | ) -> Callable[[chex.ArrayTree, chex.ArrayTree, chex.PRNGKey], Tuple[chex.Array, chex.Array]]: 36 | """Creates a leaf evaluation function that uses a stateless neural net evaluation function (state) -> (policy, value). 
37 | 38 | Args: 39 | - `nn`: The stateless evaluation function. 40 | - `state_to_nn_input_fn`: A function that converts the state to the input format expected by the neural network 41 | 42 | Returns: 43 | - `eval_fn`: A function that evaluates the state using the neural network (state) -> (policy, value) 44 | """ 45 | 46 | def eval_fn(state, *args): 47 | # get the policy and value from the neural network 48 | policy_logits, value = nn(state_to_nn_input_fn(state)[None,...]) 49 | # apply softmax to the policy logits 50 | return jax.nn.softmax(policy_logits, axis=-1).squeeze(0), value.squeeze() 51 | 52 | return eval_fn 53 | -------------------------------------------------------------------------------- /core/evaluators/evaluator.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | from chex import dataclass 8 | 9 | 10 | @dataclass(frozen=True) 11 | class EvalOutput: 12 | """Output of an evaluation. 13 | - `eval_state`: The updated internal state of the Evaluator. 14 | - `action`: The action to take. 15 | - `policy_weights`: The policy weights assigned to each action. 16 | """ 17 | eval_state: chex.ArrayTree 18 | action: int 19 | policy_weights: chex.Array 20 | 21 | 22 | class Evaluator: 23 | """Base class for Evaluators. 24 | An Evaluator *evaluates* an environment state, and returns an action to take, as well as a 'policy', assigning a weight to each action. 25 | Evaluators may maintain an internal state, which is updated by the `step` method. 26 | """ 27 | 28 | def __init__(self, discount: float, *args, **kwargs): # pylint: disable=unused-argument 29 | """Initializes an Evaluator. 30 | 31 | Args: 32 | - `discount`: The discount factor applied to future rewards/value estimates. 33 | """ 34 | self.discount = discount 35 | 36 | 37 | def init(self, *args, **kwargs) -> chex.ArrayTree: 38 | """Initializes the internal state of the Evaluator.""" 39 | raise NotImplementedError() 40 | 41 | 42 | def init_batched(self, batch_size: int, *args, **kwargs) -> chex.ArrayTree: 43 | """Initializes the internal state of the Evaluator across a batch dimension.""" 44 | tree = self.init(*args, **kwargs) 45 | return jax.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), tree) 46 | 47 | 48 | def reset(self, state: chex.ArrayTree) -> chex.ArrayTree: 49 | """Resets the internal state of the Evaluator.""" 50 | raise NotImplementedError() 51 | 52 | 53 | def evaluate(self, key: chex.PRNGKey, eval_state: chex.ArrayTree, env_state: chex.ArrayTree, **kwargs) -> EvalOutput: 54 | """Evaluates the environment state. 55 | 56 | Args: 57 | - `key`: rng 58 | - `eval_state`: The internal state of the Evaluator. 59 | - `env_state`: The environment state to evaluate. 60 | 61 | Returns: 62 | - `EvalOutput`: The output of the evaluation. 63 | - `eval_state`: The updated internal state of the Evaluator. 64 | - `action`: The action to take. 65 | - `policy_weights`: The policy weights assigned to each action. 66 | """ 67 | raise NotImplementedError() 68 | 69 | 70 | def step(self, state: chex.ArrayTree, action: chex.Array) -> chex.ArrayTree: # pylint: disable=unused-argument 71 | """Updates the internal state of the Evaluator. 72 | 73 | Args: 74 | - `state`: The internal state of the Evaluator. 75 | - `action`: The action taken in the environment. 76 | 77 | Returns: 78 | - (chex.ArrayTree): The updated internal state of the Evaluator. 
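
To make the `(policy_logits, value)` contract of `make_nn_eval_fn` (defined above in `core/evaluators/evaluation_fns.py`) concrete, here is a toy Flax module whose `__call__` accepts the `train` keyword that `apply` is invoked with. This is an illustrative sketch: the module and the `state_to_nn_input_fn` lambda are placeholders standing in for a real network such as the one in `core/networks/azresnet.py` and a real state-to-array conversion.
```python
import flax.linen as nn
import jax.numpy as jnp

from core.evaluators.evaluation_fns import make_nn_eval_fn

class TinyNet(nn.Module):
    """Toy policy/value head matching the (policy_logits, value) contract."""
    num_actions: int

    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.relu(nn.Dense(64)(x))
        policy_logits = nn.Dense(self.num_actions)(x)
        value = nn.tanh(nn.Dense(1)(x))
        return policy_logits, value

net = TinyNet(num_actions=7)
# Assumes the environment state is already a flat float array; a real
# state_to_nn_input_fn would extract and reshape the observation from the state pytree.
eval_fn = make_nn_eval_fn(net, state_to_nn_input_fn=lambda state: state.astype(jnp.float32))
```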
79 | """ 80 | return state 81 | 82 | 83 | def get_value(self, state: chex.ArrayTree) -> chex.Array: 84 | """Extracts the state value estimate (for the current/root environment state) from the internal state of the Evaluator. 85 | 86 | Args: 87 | - `state`: The internal state of the Evaluator. 88 | 89 | Returns: 90 | - `chex.Array`: The value estimate. 91 | """ 92 | raise NotImplementedError() 93 | 94 | 95 | def get_config(self) -> Dict: 96 | """Returns the configuration of the Evaluator. Used for logging.""" 97 | return {'discount': self.discount} 98 | -------------------------------------------------------------------------------- /core/evaluators/mcts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/evaluators/mcts/__init__.py -------------------------------------------------------------------------------- /core/evaluators/mcts/action_selection.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | 7 | from core.evaluators.mcts.state import MCTSTree 8 | 9 | 10 | def normalize_q_values( 11 | q_values: chex.Array, 12 | child_n_values: chex.Array, 13 | parent_q_value: float, 14 | epsilon: float 15 | ) -> chex.Array: 16 | """Normalize Q-values to be in the range [0, 1]. 17 | 18 | Args: 19 | - `q_values`: Q-values to normalize 20 | - `child_n_values`: visit counts of child nodes 21 | - `parent_q_value`: Q-value of the parent node 22 | - `epsilon`: small value to avoid division by zero 23 | 24 | Returns: 25 | - (chex.Array): normalized Q-values 26 | """ 27 | min_value = jnp.minimum(parent_q_value, jnp.min(q_values, axis=-1)) 28 | max_value = jnp.maximum(parent_q_value, jnp.max(q_values, axis=-1)) 29 | completed_by_min = jnp.where(child_n_values > 0, q_values, min_value) 30 | normalized = (completed_by_min - min_value) / ( 31 | jnp.maximum(max_value - min_value, epsilon)) 32 | return normalized 33 | 34 | 35 | class MCTSActionSelector: 36 | """Base class for action selection in MCTS. 37 | 38 | Is callable, selects an action given a search tree state. 39 | """ 40 | 41 | def __init__(self, epsilon: float = 1e-8): 42 | """ 43 | Args: 44 | - `epsilon`: small value to avoid division by zero 45 | """ 46 | self.epsilon = epsilon 47 | 48 | 49 | def __call__(self, tree: MCTSTree, index: int, discount: float) -> int: 50 | """Selects an action given a search tree state. Implemented by subclasses.""" 51 | raise NotImplementedError() 52 | 53 | 54 | def get_config(self) -> Dict: 55 | """Returns the configuration of the action selector. Used for logging.""" 56 | return { 57 | "epsilon": self.epsilon 58 | } 59 | 60 | 61 | class PUCTSelector(MCTSActionSelector): 62 | """PUCT (Polynomial Upper Confidence Trees) action selector. 63 | 64 | This is the algorithm used for action selection within AlphaZero.""" 65 | 66 | def __init__(self, 67 | c: float = 1.0, 68 | epsilon: float = 1e-8, 69 | q_transform = normalize_q_values 70 | ): 71 | """ 72 | Args: 73 | - `c`: exploration constant (larger values encourage exploration) 74 | - `epsilon`: small value to avoid division by zero 75 | - `q_transform`: function applied to q-values before selection 76 | """ 77 | super().__init__(epsilon=epsilon) 78 | self.c = c 79 | self.q_transform = q_transform 80 | 81 | 82 | def get_config(self) -> Dict: 83 | """Returns the configuration of the PUCT action selector. 
Used for logging.""" 84 | return { 85 | "c": self.c, 86 | 'q_transform': self.q_transform.__name__, 87 | **super().get_config() 88 | } 89 | 90 | 91 | def __call__(self, tree: MCTSTree, index: int, discount: float) -> int: 92 | """Selects an action given a search tree state. 93 | 94 | Args: 95 | - `tree`: search tree 96 | - `index`: index of the node in the search tree to select an action to take from 97 | - `discount`: discount factor 98 | 99 | Returns: 100 | - (int): id of action to take 101 | """ 102 | # get child q-values 103 | node = tree.data_at(index) 104 | q_values = tree.get_child_data('q', index) 105 | # apply discount to q-values 106 | discounted_q_values = q_values * discount 107 | # get child visit counts 108 | n_values = tree.get_child_data('n', index) 109 | # normalize/transform q-values 110 | q_values = self.q_transform(discounted_q_values, n_values, node.q, self.epsilon) 111 | # calculate U-values 112 | u_values = self.c * node.p * jnp.sqrt(node.n) / (n_values + 1) 113 | # PUCT = Q-value + U-value 114 | puct_values = q_values + u_values 115 | # select action with highest PUCT value 116 | return puct_values.argmax() 117 | 118 | 119 | class MuZeroPUCTSelector(MCTSActionSelector): 120 | """Implements the variant of PUCT used in MuZero.""" 121 | 122 | def __init__(self, 123 | c1: float = 1.25, 124 | c2: float = 19652, 125 | epsilon: float = 1e-8, 126 | q_transform = normalize_q_values 127 | ): 128 | """ 129 | Args: 130 | - `c1`: 1st exploration constant 131 | - `c2`: 2nd exploration constant 132 | - `epsilon`: small value to avoid division by zero 133 | - `q_transform`: function applied to q-values before selection 134 | """ 135 | super().__init__(epsilon=epsilon) 136 | self.c1 = c1 137 | self.c2 = c2 138 | self.q_transform = q_transform 139 | 140 | 141 | def get_config(self) -> Dict: 142 | """Returns the configuration of the MuZero PUCT action selector. Used for logging.""" 143 | return { 144 | "c1": self.c1, 145 | "c2": self.c2, 146 | "q_transform": self.q_transform.__name__, 147 | **super().get_config() 148 | } 149 | 150 | def __call__(self, tree: MCTSTree, index: int, discount: float) -> int: 151 | """Selects an action given a search tree state. 
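
Continuing with those numbers, the rule implemented in `PUCTSelector.__call__` above adds an exploration bonus driven by the prior policy and visit counts; the prior and parent visit count below are hypothetical.
```python
import jax.numpy as jnp

P = jnp.array([0.5, 0.3, 0.2])     # node.p: prior policy over the three children
n = jnp.array([4, 2, 0])           # child visit counts
q = jnp.array([1.0, 0.0, 0.0])     # normalized Q-values from the previous sketch
N = 7                              # node.n: parent visit count
c = 1.0                            # exploration constant

u = c * P * jnp.sqrt(N) / (n + 1)  # U-value term
print(q + u)                       # ~[1.26, 0.26, 0.53] -> child 0 is selected
```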
152 | 153 | Args: 154 | - `tree`: search tree 155 | - `index`: index of the node in the search tree to select an action to take from 156 | - `discount`: discount factor 157 | 158 | Returns: 159 | - (int): id of action to take 160 | """ 161 | # get child q-values 162 | node = tree.data_at(index) 163 | q_values = tree.get_child_data('q', index) 164 | # apply discount to q-values 165 | discounted_q_values = q_values * discount 166 | # get child visit counts 167 | n_values = tree.get_child_data('n', index) 168 | # normalize/transform q-values 169 | q_values = self.q_transform(discounted_q_values, n_values, node.q, self.epsilon) 170 | # calculate U-values 171 | base_term = node.p * jnp.sqrt(node.n) / (n_values + 1) 172 | log_term = jnp.log((node.n + self.c2 + 1) / self.c2) + self.c1 173 | u_values = base_term * log_term 174 | # PUCT = Q-value + U-value 175 | puct_values = q_values + u_values 176 | # select action with highest PUCT value 177 | return puct_values.argmax() 178 | -------------------------------------------------------------------------------- /core/evaluators/mcts/mcts.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | from typing import Dict, Optional, Tuple 4 | import jax 5 | import chex 6 | import jax.numpy as jnp 7 | from core.evaluators.evaluator import Evaluator 8 | from core.evaluators.mcts.action_selection import MCTSActionSelector 9 | from core.evaluators.mcts.state import BackpropState, MCTSNode, MCTSTree, TraversalState, MCTSOutput 10 | from core.trees.tree import init_tree 11 | from core.types import EnvStepFn, EvalFn, StepMetadata 12 | 13 | class MCTS(Evaluator): 14 | """Batched implementation of Monte Carlo Tree Search (MCTS). 15 | 16 | Not stateful. This class operates on 'MCTSTree' state objects. 17 | 18 | Compatible with `jax.vmap`, `jax.pmap`, `jax.jit`, etc.""" 19 | def __init__(self, 20 | eval_fn: EvalFn, 21 | action_selector: MCTSActionSelector, 22 | branching_factor: int, 23 | max_nodes: int, 24 | num_iterations: int, 25 | discount: float = -1.0, 26 | temperature: float = 1.0, 27 | tiebreak_noise: float = 1e-8, 28 | persist_tree: bool = True 29 | ): 30 | """ 31 | Args: 32 | - `eval_fn`: leaf node evaluation function (env_state -> (policy_logits, value)) 33 | - `action_selector`: action selection function (eval_state -> action) 34 | - `branching_factor`: max number of actions (== children per node) 35 | - `max_nodes`: allocated size of MCTS tree, any additional nodes will not be created, 36 | but values from out-of-bounds leaf nodes will still backpropagate 37 | - `num_iterations`: number of MCTS iterations to perform per evaluate call 38 | - `discount`: discount factor for MCTS (default: -1.0) 39 | - use a negative discount in two-player games (e.g. -1.0) 40 | - use a positive discount in single-player games (e.g.
1.0) 41 | - `temperature`: temperature for root action selection (default: 1.0) 42 | - `tiebreak_noise`: magnitude of noise to add to policy weights for breaking ties (default: 1e-8) 43 | - `persist_tree`: whether to persist search tree state between calls to `evaluate` (default: True) 44 | """ 45 | super().__init__(discount=discount) 46 | self.eval_fn = eval_fn 47 | self.num_iterations = num_iterations 48 | self.branching_factor = branching_factor 49 | self.max_nodes = max_nodes 50 | self.action_selector = action_selector 51 | self.temperature = temperature 52 | self.tiebreak_noise = tiebreak_noise 53 | self.persist_tree = persist_tree 54 | 55 | 56 | def get_config(self) -> Dict: 57 | """returns a config object for checkpoints""" 58 | return { 59 | "eval_fn": self.eval_fn.__name__, 60 | "num_iterations": self.num_iterations, 61 | "branching_factor": self.branching_factor, 62 | "max_nodes": self.max_nodes, 63 | "action_selection_config": self.action_selector.get_config(), 64 | "discount": self.discount, 65 | "temperature": self.temperature, 66 | "tiebreak_noise": self.tiebreak_noise, 67 | "persist_tree": self.persist_tree 68 | } 69 | 70 | 71 | def evaluate(self, #pylint: disable=arguments-differ 72 | key: chex.PRNGKey, 73 | eval_state: MCTSTree, 74 | env_state: chex.ArrayTree, 75 | root_metadata: StepMetadata, 76 | params: chex.ArrayTree, 77 | env_step_fn: EnvStepFn, 78 | **kwargs 79 | ) -> MCTSOutput: 80 | """Performs `self.num_iterations` MCTS iterations on an `MCTSTree`. 81 | Samples an action to take from the root node after search is completed. 82 | 83 | Args: 84 | - `eval_state`: `MCTSTree` to evaluate, could be empty or partially complete 85 | - `env_state`: current environment state 86 | - `root_metadata`: metadata for the root node of the tree 87 | - `params`: parameters to pass to the the leaf evaluation function 88 | - `env_step_fn`: env step fn: (env_state, action) -> (new_env_state, metadata) 89 | 90 | Returns: 91 | - (MCTSOutput): contains new tree state, selected action, root value, and policy weights 92 | """ 93 | # store current state metadata in the root node 94 | key, root_key = jax.random.split(key) 95 | eval_state = self.update_root(root_key, eval_state, env_state, params, root_metadata=root_metadata) 96 | # perform 'num_iterations' iterations of MCTS 97 | iterate = partial(self.iterate, params=params, env_step_fn=env_step_fn) 98 | 99 | iterate_keys = jax.random.split(key, self.num_iterations) 100 | eval_state, _ = jax.lax.scan(lambda state, k: (iterate(k, state), None), eval_state, iterate_keys) 101 | # sample action based on root visit counts 102 | # (also get normalized policy weights for training purposes) 103 | action, policy_weights = self.sample_root_action(key, eval_state) 104 | return MCTSOutput( 105 | eval_state=eval_state, 106 | action=action, 107 | policy_weights=policy_weights 108 | ) 109 | 110 | 111 | def get_value(self, state: MCTSTree) -> chex.Array: 112 | """Returns value estimate of the environment state stored in the root node of the tree. 113 | 114 | Args: 115 | - `state`: MCTSTree to evaluate 116 | 117 | Returns: 118 | - (chex.Array): value estimate of the environment state stored in the root node of the tree 119 | """ 120 | return state.data_at(state.ROOT_INDEX).q 121 | 122 | 123 | def update_root(self, key: chex.PRNGKey, tree: MCTSTree, root_embedding: chex.ArrayTree, 124 | params: chex.ArrayTree, **kwargs) -> MCTSTree: #pylint: disable=unused-argument 125 | """Populates the root node of an MCTSTree. 
126 | 127 | Args: 128 | - `key`: rng 129 | - `tree`: MCTSTree to update 130 | - `root_embedding`: root environment state 131 | - `params`: nn parameters 132 | 133 | Returns: 134 | - (MCTSTree): updated MCTSTree 135 | """ 136 | # evaluate root state 137 | root_policy_logits, root_value = self.eval_fn(root_embedding, params, key) 138 | root_policy = jax.nn.softmax(root_policy_logits) 139 | # update root node 140 | root_node = tree.data_at(tree.ROOT_INDEX) 141 | root_node = self.update_root_node(root_node, root_policy, root_value, root_embedding) 142 | return tree.set_root(root_node) 143 | 144 | 145 | def iterate(self, key: chex.PRNGKey, tree: MCTSTree, params: chex.ArrayTree, env_step_fn: EnvStepFn) -> MCTSTree: 146 | """ Performs one iteration of MCTS. 147 | 1. Traverse to leaf node. 148 | 2. Evaluate Leaf Node 149 | 3. Expand Leaf Node (add to tree) 150 | 4. Backpropagate 151 | 152 | Args: 153 | - `tree`: MCTSTree to evaluate 154 | - `params`: parameters to pass to the the leaf evaluation function 155 | - `env_step_fn`: env step fn: (env_state, action) -> (new_env_state, metadata) 156 | 157 | Returns: 158 | - (MCTSTree): updated MCTSTree 159 | """ 160 | # traverse from root -> leaf 161 | traversal_state = self.traverse(tree) 162 | parent, action = traversal_state.parent, traversal_state.action 163 | # get env state (embedding) for leaf node 164 | embedding = tree.data_at(parent).embedding 165 | new_embedding, metadata = env_step_fn(embedding, action) 166 | player_reward = metadata.rewards[metadata.cur_player_id] 167 | # evaluate leaf node 168 | eval_key, key = jax.random.split(key) 169 | policy_logits, value = self.eval_fn(new_embedding, params, eval_key) 170 | policy_logits = jnp.where(metadata.action_mask, policy_logits, jnp.finfo(policy_logits).min) 171 | policy = jax.nn.softmax(policy_logits) 172 | value = jnp.where(metadata.terminated, player_reward, value) 173 | # add leaf node to tree 174 | node_exists = tree.is_edge(parent, action) 175 | node_idx = tree.edge_map[parent, action] 176 | 177 | node_data = jax.lax.cond( 178 | node_exists, 179 | lambda: self.visit_node(node=tree.data_at(node_idx), value=value, p=policy, terminated=metadata.terminated, embedding=new_embedding), 180 | lambda: self.new_node(policy=policy, value=value, embedding=new_embedding, terminated=metadata.terminated) 181 | ) 182 | 183 | tree = jax.lax.cond( 184 | node_exists, 185 | lambda: tree.update_node(index=node_idx, data = node_data), 186 | lambda: tree.add_node(parent_index=parent, edge_index=action, data=node_data) 187 | ) 188 | # backpropagate 189 | return self.backpropagate(key, tree, parent, value) 190 | 191 | 192 | def traverse(self, tree: MCTSTree) -> TraversalState: 193 | """ Traverse from the root node until an unvisited leaf node is reached. 
194 | 195 | Args: 196 | - `tree`: MCTSTree to evaluate 197 | 198 | Returns: 199 | - (TraversalState): state of the traversal 200 | - `parent`: index of the parent node 201 | - `action`: action to take from the parent node 202 | """ 203 | 204 | # continue while: 205 | # - there is an existing edge corresponding to the chosen action 206 | # - AND the child node connected to that edge is not terminal 207 | def cond_fn(state: TraversalState) -> bool: 208 | return jnp.logical_and( 209 | tree.is_edge(state.parent, state.action), 210 | ~(tree.data_at(tree.edge_map[state.parent, state.action]).terminated) 211 | # TODO: maximum depth 212 | ) 213 | 214 | # each iterration: 215 | # - get the index of the child node connected to the chosen action 216 | # - choose the action to take from the child node 217 | def body_fn(state: TraversalState) -> TraversalState: 218 | node_idx = tree.edge_map[state.parent, state.action] 219 | action = self.action_selector(tree, node_idx, self.discount) 220 | return TraversalState(parent=node_idx, action=action) 221 | 222 | # choose the action to take from the root 223 | root_action = self.action_selector(tree, tree.ROOT_INDEX, self.discount) 224 | # traverse from root to leaf 225 | return jax.lax.while_loop( 226 | cond_fn, body_fn, 227 | TraversalState(parent=tree.ROOT_INDEX, action=root_action) 228 | ) 229 | 230 | 231 | def backpropagate(self, key: chex.PRNGKey, tree: MCTSTree, parent: int, value: float) -> MCTSTree: #pylint: disable=unused-argument 232 | """Backpropagate the value estimate from the leaf node to the root node and update visit counts. 233 | 234 | Args: 235 | - `key`: rng 236 | - `tree`: MCTSTree to evaluate 237 | - `parent`: index of the parent node (in most cases, this is the new node added to the tree this iteration) 238 | - `value`: value estimate of the leaf node 239 | 240 | Returns: 241 | - (MCTSTree): updated search tree 242 | """ 243 | 244 | def body_fn(state: BackpropState) -> Tuple[int, MCTSTree]: 245 | node_idx, value, tree = state.node_idx, state.value, state.tree 246 | # apply discount to value estimate 247 | value *= self.discount 248 | node = tree.data_at(node_idx) 249 | # increment visit count and update value estimate 250 | new_node = self.visit_node(node, value) 251 | tree = tree.update_node(node_idx, new_node) 252 | # go to parent 253 | return BackpropState(node_idx=tree.parents[node_idx], value=value, tree=tree) 254 | 255 | # backpropagate while the node is a valid node 256 | # the root has no parent, so the loop will terminate 257 | # when the parent of the root is visited 258 | state = jax.lax.while_loop( 259 | lambda s: s.node_idx != s.tree.NULL_INDEX, body_fn, 260 | BackpropState(node_idx=parent, value=value, tree=tree) 261 | ) 262 | return state.tree 263 | 264 | 265 | def sample_root_action(self, key: chex.PRNGKey, tree: MCTSTree) -> Tuple[int, chex.Array]: 266 | """Sample an action based on the root visit counts. 
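For intuition, a self-contained sketch of the backpropagation walk above, using plain arrays in place of an `MCTSTree`: the value is discounted at every hop, each visited node's running mean and visit count are updated, and the walk stops at the (assumed) null parent index. All names and values here are illustrative stand-ins.

import jax
import jax.numpy as jnp

NULL_INDEX = -1
parents = jnp.array([NULL_INDEX, 1 - 1, 1])    # node 2 -> node 1 -> node 0 (root)
q = jnp.zeros(3)                               # running value estimates
n = jnp.zeros(3, dtype=jnp.int32)              # visit counts
discount = -1.0                                # e.g. two-player games flip sign each ply

def body(state):
    idx, value, q, n = state
    value = value * discount
    # incremental mean update, as in visit_node
    q = q.at[idx].set((q[idx] * n[idx] + value) / (n[idx] + 1))
    n = n.at[idx].add(1)
    return parents[idx], value, q, n

idx, value, q, n = jax.lax.while_loop(
    lambda s: s[0] != NULL_INDEX, body, (jnp.int32(2), jnp.float32(1.0), q, n)
)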
267 | 268 | Args: 269 | - `key`: rng 270 | - `tree`: MCTSTree to evaluate 271 | 272 | Returns: 273 | - (Tuple[int, chex.Array]): sampled action, normalized policy weights 274 | """ 275 | # get root visit counts 276 | action_visits = tree.get_child_data('n', tree.ROOT_INDEX) 277 | # normalize visit counts to get policy weights 278 | total_visits = action_visits.sum(axis=-1) 279 | policy_weights = action_visits / jnp.maximum(total_visits, 1) 280 | policy_weights = jnp.where(total_visits > 0, policy_weights, 1 / self.branching_factor) 281 | 282 | # zero temperature == argmax 283 | if self.temperature == 0: 284 | # break ties by adding small amount of noise 285 | noise = jax.random.uniform(key, shape=policy_weights.shape, maxval=self.tiebreak_noise) 286 | noisy_policy_weights = policy_weights + noise 287 | return jnp.argmax(noisy_policy_weights), policy_weights 288 | 289 | # apply temperature 290 | policy_weights_t = policy_weights ** (1/self.temperature) 291 | # re-normalize 292 | policy_weights_t /= policy_weights_t.sum() 293 | # sample action 294 | action = jax.random.choice(key, policy_weights_t.shape[-1], p=policy_weights_t) 295 | # return original policy weights (we train on the policy before temperature is applied) 296 | return action, policy_weights 297 | 298 | 299 | @staticmethod 300 | def visit_node( 301 | node: MCTSNode, 302 | value: float, 303 | p: Optional[chex.Array] = None, 304 | terminated: Optional[bool] = None, 305 | embedding: Optional[chex.ArrayTree] = None 306 | ) -> MCTSNode: 307 | """ Update the visit counts and value estimate of a node. 308 | 309 | Args: 310 | - `node`: MCTSNode to update 311 | - `value`: value estimate to update the node with 312 | 313 | ( we could optionally overwrite the following: ) 314 | - `p`: policy weights to update the node with 315 | - `terminated`: whether the node is terminal 316 | - `embedding`: embedding to update the node with 317 | 318 | Returns: 319 | - (MCTSNode): updated MCTSNode 320 | """ 321 | # update running value estimate 322 | q_value = ((node.q * node.n) + value) / (node.n + 1) 323 | # update other node attributes 324 | if p is None: 325 | p = node.p 326 | if terminated is None: 327 | terminated = node.terminated 328 | if embedding is None: 329 | embedding = node.embedding 330 | return node.replace( 331 | n=node.n + 1, # increment visit count 332 | q=q_value, 333 | p=p, 334 | terminated=terminated, 335 | embedding=embedding 336 | ) 337 | 338 | 339 | @staticmethod 340 | def new_node(policy: chex.Array, value: float, embedding: chex.ArrayTree, terminated: bool) -> MCTSNode: 341 | """Create a new MCTSNode. 342 | 343 | Args: 344 | - `policy`: policy weights 345 | - `value`: value estimate 346 | - `embedding`: environment state embedding 347 | - 'embedding' because in some MCTS use-cases, e.g. MuZero, we store an embedding of the state 348 | rather than the state itself. In AlphaZero, this is just the entire environment state. 349 | - `terminated`: whether the state is terminal 350 | 351 | Returns: 352 | - (MCTSNode): initialized MCTSNode 353 | """ 354 | return MCTSNode( 355 | n=jnp.array(1, dtype=jnp.int32), # init visit count to 1 356 | p=policy, 357 | q=jnp.array(value, dtype=jnp.float32), 358 | terminated=jnp.array(terminated, dtype=jnp.bool_), 359 | embedding=embedding 360 | ) 361 | 362 | 363 | @staticmethod 364 | def update_root_node(root_node: MCTSNode, root_policy: chex.Array, root_value: float, root_embedding: chex.ArrayTree) -> MCTSNode: 365 | """Update the root node of the search tree. 
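A minimal sketch of the root-action sampling above, on illustrative visit counts: counts are normalized into policy weights, a temperature is applied for sampling, and the pre-temperature weights are what would be returned as the training target.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
action_visits = jnp.array([10.0, 30.0, 0.0, 60.0])
total = action_visits.sum()
policy_weights = action_visits / jnp.maximum(total, 1)

temperature = 1.0
weights_t = policy_weights ** (1.0 / temperature)
weights_t = weights_t / weights_t.sum()
action = jax.random.choice(key, policy_weights.shape[-1], p=weights_t)
# with temperature == 0, the evaluator instead takes an argmax after adding tiny tiebreak noise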
366 | 367 | Args: 368 | - `root_node`: node to update 369 | - `root_policy`: policy weights 370 | - `root_value`: value estimate 371 | - `root_embedding`: environment state embedding 372 | 373 | Returns: 374 | - (MCTSNode): updated root node 375 | """ 376 | visited = root_node.n > 0 377 | return root_node.replace( 378 | p=root_policy, 379 | # keep old value estimate if the node has already been visited 380 | q=jnp.where(visited, root_node.q, root_value), 381 | # keep old visit count if the node has already been visited 382 | n=jnp.where(visited, root_node.n, 1), 383 | embedding=root_embedding 384 | ) 385 | 386 | 387 | def reset(self, state: MCTSTree) -> MCTSTree: 388 | """Resets the internal state of MCTS. 389 | 390 | Args: 391 | - `state`: evaluator state 392 | 393 | Returns: 394 | - (MCTSTree): reset evaluator state 395 | """ 396 | return state.reset() 397 | 398 | 399 | def step(self, state: MCTSTree, action: int) -> MCTSTree: 400 | """Update the internal state of MCTS after taking an action in the environment. 401 | 402 | Args: 403 | - `state`: evaluator state 404 | - `action`: action taken in the environment 405 | 406 | Returns: 407 | - (MCTSTree): updated evaluator state 408 | """ 409 | 410 | if self.persist_tree: 411 | # get subtree corresponding to action taken if persist_tree is True 412 | return state.get_subtree(action) 413 | # just reset to an empty tree if persist_tree is False 414 | return state.reset() 415 | 416 | 417 | def init(self, template_embedding: chex.ArrayTree, *args, **kwargs) -> MCTSTree: #pylint: disable=arguments-differ 418 | """Initializes the internal state of the MCTS evaluator. 419 | 420 | Args: 421 | - `template_embedding`: template environment state embedding 422 | - not stored, just used to initialize data structures to the correct shape 423 | 424 | Returns: 425 | - (MCTSTree): initialized MCTSTree 426 | """ 427 | return init_tree(self.max_nodes, self.branching_factor, self.new_node( 428 | policy=jnp.zeros((self.branching_factor,)), 429 | value=0.0, 430 | embedding=template_embedding, 431 | terminated=False 432 | )) 433 | -------------------------------------------------------------------------------- /core/evaluators/mcts/state.py: -------------------------------------------------------------------------------- 1 | 2 | import chex 3 | from chex import dataclass 4 | import graphviz 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from core.evaluators.evaluator import EvalOutput 9 | from core.trees.tree import Tree 10 | 11 | 12 | @dataclass(frozen=True) 13 | class MCTSNode: 14 | """Base MCTS node data strucutre. 15 | - `n`: visit count 16 | - `p`: policy vector 17 | - `q`: cumulative value estimate / visit count 18 | - `terminated`: whether the environment state is terminal 19 | - `embedding`: environment state 20 | """ 21 | n: jnp.number 22 | p: chex.Array 23 | q: jnp.number 24 | terminated: jnp.number 25 | embedding: chex.ArrayTree 26 | 27 | @property 28 | def w(self) -> jnp.number: 29 | """cumulative value estimate""" 30 | return self.q * self.n 31 | 32 | 33 | # an MCTSTree is a Tree containing MCTSNodes 34 | MCTSTree = Tree[MCTSNode] 35 | 36 | 37 | @dataclass(frozen=True) 38 | class TraversalState: 39 | """State used during traversal step of MCTS. 40 | - `parent`: parent node index 41 | - `action`: action taken from parent 42 | """ 43 | parent: int 44 | action: int 45 | 46 | 47 | @dataclass(frozen=True) 48 | class BackpropState: 49 | """State used during backpropagation step of MCTS. 
50 | - `node_idx`: current node 51 | - `value`: value to backpropagate 52 | - `tree`: search tree 53 | """ 54 | node_idx: int 55 | value: float 56 | tree: MCTSTree 57 | 58 | 59 | @dataclass(frozen=True) 60 | class MCTSOutput(EvalOutput): 61 | """Output of an MCTS evaluation. See EvalOutput. 62 | - `eval_state`: The updated internal state of the Evaluator. 63 | - `policy_weights`: The policy weights assigned to each action. 64 | """ 65 | eval_state: MCTSTree 66 | policy_weights: chex.Array 67 | 68 | 69 | def tree_to_graph(tree, batch_id=0): 70 | """Converts a search tree to a graphviz graph.""" 71 | graph = graphviz.Digraph() 72 | 73 | def get_child_visits_no_batch(tree, index): 74 | mapping = tree.edge_map[batch_id, index] 75 | child_data = tree.data.n[batch_id, mapping] 76 | return jnp.where( 77 | (mapping == Tree.NULL_INDEX).reshape((-1,) + (1,) * (child_data.ndim - 1)), 78 | 0, 79 | child_data, 80 | ) 81 | 82 | for n_i in range(tree.parents.shape[1]): 83 | node = jax.tree_util.tree_map(lambda x: x[batch_id, n_i], tree.data) 84 | if node.n.item() > 0: 85 | graph.node(str(n_i), str({ 86 | "i": str(n_i), 87 | "n": str(node.n.item()), 88 | "q": f"{node.q.item():.2f}", 89 | "t": str(node.terminated.item()) 90 | })) 91 | 92 | child_visits = get_child_visits_no_batch(tree, n_i) 93 | mapping = tree.edge_map[batch_id, n_i] 94 | for a_i in range(tree.edge_map.shape[2]): 95 | v_a = child_visits[a_i].item() 96 | if v_a > 0: 97 | graph.edge(str(n_i), str(mapping[a_i]), f'{a_i}:{node.p[a_i]:.4f}') 98 | else: 99 | break 100 | 101 | return graph 102 | -------------------------------------------------------------------------------- /core/evaluators/mcts/weighted_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, Tuple 3 | 4 | import chex 5 | from chex import dataclass 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from core.evaluators.mcts.mcts import MCTS 10 | from core.evaluators.mcts.state import BackpropState, MCTSNode, MCTSTree 11 | from core.evaluators.mcts.action_selection import normalize_q_values 12 | 13 | 14 | @dataclass(frozen=True) 15 | class WeightedMCTSNode(MCTSNode): 16 | # Weighted MCTS needs access to the original raw value returned by the leaf evaluation 17 | r: float # raw value 18 | 19 | 20 | class WeightedMCTS(MCTS): 21 | """Weighted MCTS implementation: 22 | - https://twitter.com/ptrschmdtnlsn/status/1748800529608888362 23 | """ 24 | 25 | def __init__(self, q_temperature: float = 1.0, *args, **kwargs): 26 | """Initializes a WeightedMCTS evaluator. 27 | 28 | Args: 29 | - `q_temperature`: temperature to apply to child q-values when backpropagating 30 | """ 31 | super().__init__(*args, **kwargs) 32 | self.q_temperature = q_temperature 33 | 34 | 35 | def get_config(self) -> Dict: 36 | """Returns the configuration of the WeightedMCTS evaluator. Used for logging.""" 37 | return { 38 | "q_temperature": self.q_temperature, 39 | **super().get_config() 40 | } 41 | 42 | 43 | @staticmethod 44 | def new_node(policy: chex.Array, value: float, embedding: chex.ArrayTree, terminated: bool) -> WeightedMCTSNode: 45 | """Create a new WeightedMCTSNode. 
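A small illustrative sketch of the child-masking pattern inside `tree_to_graph` above: edge-map entries equal to the null index point at no real child, so the data gathered at those positions is zeroed before plotting. Values here are illustrative.

import jax.numpy as jnp

NULL_INDEX = -1
mapping = jnp.array([3, NULL_INDEX, 7])     # child index per action
child_visits = jnp.array([5, 999, 2])       # data gathered at those indices
masked = jnp.where(mapping == NULL_INDEX, 0, child_visits)
# masked == [5, 0, 2]; the reshape in the source generalizes this to multi-dimensional child data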
46 | 47 | Args: 48 | - `policy`: policy vector 49 | - `value`: value estimate 50 | - `embedding`: environment state 51 | - `terminated`: whether the environment state is terminal 52 | 53 | Returns: 54 | - (WeightedMCTSNode): new node 55 | """ 56 | return WeightedMCTSNode( 57 | n=jnp.array(1, dtype=jnp.int32), 58 | p=policy, 59 | q=jnp.array(value, dtype=jnp.float32), 60 | r=jnp.array(value, dtype=jnp.float32), 61 | terminated=jnp.array(terminated, dtype=jnp.bool_), 62 | embedding=embedding 63 | ) 64 | 65 | 66 | @staticmethod 67 | def update_root_node(root_node: MCTSNode, root_policy: chex.Array, root_value: float, root_embedding: chex.ArrayTree) -> WeightedMCTSNode: 68 | """ Updates the root node 69 | - if the tree is empty, create a new node 70 | - otherwise, update the existing root node 71 | 72 | Args: 73 | - `root_node`: root node 74 | - `root_policy`: root policy 75 | - `root_value`: root value 76 | - `root_embedding`: root environment state 77 | 78 | Returns: 79 | - (WeightedMCTSNode): updated root node""" 80 | visited = root_node.n > 0 81 | return root_node.replace( 82 | p=root_policy, 83 | q=jnp.where(visited, root_node.q, root_value), 84 | r=jnp.where(visited, root_node.r, root_value), 85 | n=jnp.where(visited, root_node.n, 1), 86 | embedding=root_embedding 87 | ) 88 | 89 | 90 | def backpropagate(self, key: chex.PRNGKey, tree: MCTSTree, parent: int, value: float) -> MCTSTree: 91 | """Backpropagate weighted sums of child q-values and update visit counts. 92 | 93 | Args: 94 | - `key`: rng 95 | - `tree`: The search tree. 96 | - `parent`: index of the parent node (in most cases, this is the new node added to the tree this iteration) 97 | - `value`: expanded node value estimate 98 | 99 | Returns: 100 | - `tree`: updated search tree 101 | """ 102 | def body_fn(state: BackpropState) -> Tuple[int, MCTSTree]: 103 | node_idx, tree = state.node_idx, state.tree 104 | # get node data 105 | node = tree.data_at(node_idx) 106 | # get q values, visit counts of children 107 | child_q_values = tree.get_child_data('q', node_idx) * self.discount 108 | child_n_values = tree.get_child_data('n', node_idx) 109 | 110 | # normalize q-values to [0, 1] 111 | normalized_q_values = normalize_q_values(child_q_values, child_n_values, node.q, jnp.finfo(node.q).eps) 112 | 113 | if self.q_temperature > 0: 114 | # if temperature > 0, apply temperature to q-values 115 | q_values = normalized_q_values ** (1/self.q_temperature) 116 | # mask out unvisited action a[i] so softmax(a)[i] = 0.0 117 | q_values_masked = jnp.where( 118 | child_n_values > 0, normalized_q_values, jnp.finfo(normalized_q_values).min 119 | ) 120 | else: 121 | # if temperature == 0, select max q-value 122 | # apply random noise to break ties amongst nodes w/ same number of visits 123 | noise = jax.random.uniform(key, shape=normalized_q_values.shape, maxval=self.tiebreak_noise) 124 | noisy_q_values = normalized_q_values + noise 125 | 126 | # mask out all values except for max value index 127 | # so softmax output at a[max_index] = 1 and 0 everywhere else 128 | max_vector = jnp.full_like(noisy_q_values, jnp.finfo(noisy_q_values).min) 129 | index_of_max = jnp.argmax(noisy_q_values) 130 | max_vector = max_vector.at[index_of_max].set(1) 131 | q_values = normalized_q_values 132 | q_values_masked = max_vector 133 | 134 | # compute weights 135 | child_weights = jax.nn.softmax(q_values_masked, axis=-1) 136 | # computer weighted sum of q-values 137 | weighted_value = jnp.sum(child_weights * q_values) 138 | # update node with weighted value 139 | node = 
node.replace(q=weighted_value) 140 | # adjust node value to ((weighted_value * node_visits) + raw_value) / (node_visits + 1) 141 | # and increment visit count 142 | node = self.visit_node(node, node.r) 143 | # update search tree 144 | tree = tree.update_node(node_idx, node) 145 | # backprop to parent node 146 | return BackpropState(node_idx=tree.parents[node_idx], value=value, tree=tree) 147 | 148 | state = jax.lax.while_loop( 149 | lambda s: s.node_idx != s.tree.NULL_INDEX, body_fn, 150 | BackpropState(node_idx=parent, value=value, tree=tree) 151 | ) 152 | return state.tree 153 | -------------------------------------------------------------------------------- /core/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/memory/__init__.py -------------------------------------------------------------------------------- /core/memory/replay_memory.py: -------------------------------------------------------------------------------- 1 | 2 | import chex 3 | from chex import dataclass 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | @dataclass(frozen=True) 9 | class BaseExperience: 10 | """Experience data structure. Stores a training sample. 11 | - `reward`: reward for each player in the episode this sample belongs to 12 | - `policy_weights`: policy weights 13 | - `policy_mask`: mask for policy weights (mask out invalid/illegal actions) 14 | - `observation_nn`: observation for neural network input 15 | - `cur_player_id`: current player id 16 | """ 17 | reward: chex.Array 18 | policy_weights: chex.Array 19 | policy_mask: chex.Array 20 | observation_nn: chex.Array 21 | cur_player_id: chex.Array 22 | 23 | 24 | @dataclass(frozen=True) 25 | class ReplayBufferState: 26 | """State of the replay buffer. Stores objects stored in the buffer 27 | and metadata used to determine where to store the next object, as well as 28 | which objects are valid to sample from. 29 | - `next_idx`: index where the next experience will be stored 30 | - `episode_start_idx`: index where the current episode started, samples are placed in order 31 | - `buffer`: buffer of experiences 32 | - `populated`: mask for populated buffer indices 33 | - `has_reward`: mask for buffer indices that have been assigned a reward 34 | - we store samples from in-progress episodes, but don't want to be able to sample them 35 | until the episode is complete 36 | """ 37 | next_idx: int 38 | episode_start_idx: int 39 | buffer: BaseExperience 40 | populated: chex.Array 41 | has_reward: chex.Array 42 | 43 | 44 | class EpisodeReplayBuffer: 45 | """Replay buffer, stores trajectories from episodes for training. 46 | 47 | Compatible with `jax.jit`, `jax.vmap`, and `jax.pmap`.""" 48 | 49 | def __init__(self, 50 | capacity: int, 51 | ): 52 | """ 53 | Args: 54 | - `capacity`: number of experiences to store in the buffer 55 | """ 56 | self.capacity = capacity 57 | 58 | 59 | def get_config(self): 60 | """Returns the configuration of the replay buffer. Used for logging.""" 61 | return { 62 | 'capacity': self.capacity, 63 | } 64 | 65 | 66 | def add_experience(self, state: ReplayBufferState, experience: BaseExperience) -> ReplayBufferState: 67 | """Adds an experience to the replay buffer. 
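A self-contained sketch of the weighted backup above: child q-values are normalized, unvisited children are masked out, and a softmax over the q-values weights the value passed up the tree. The min-max normalization here is only a stand-in for `normalize_q_values`, which is defined in core/evaluators/mcts/action_selection.py.

import jax
import jax.numpy as jnp

child_q = jnp.array([0.2, -0.4, 0.0, 0.9])
child_n = jnp.array([10, 0, 3, 25])
q_temperature = 1.0

# stand-in normalization to [0, 1]
q_norm = (child_q - child_q.min()) / (child_q.max() - child_q.min())
# mask unvisited children so they receive ~0 weight after the softmax
q_masked = jnp.where(child_n > 0, q_norm, jnp.finfo(q_norm.dtype).min)
child_weights = jax.nn.softmax(q_masked)
weighted_value = jnp.sum(child_weights * (q_norm ** (1.0 / q_temperature)))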
68 | 69 | Args: 70 | - `state`: replay buffer state 71 | - `experience`: experience to add 72 | 73 | Returns: 74 | - (ReplayBufferState): updated replay buffer state""" 75 | return state.replace( 76 | buffer = jax.tree_util.tree_map( 77 | lambda x, y: x.at[state.next_idx].set(y), 78 | state.buffer, 79 | experience 80 | ), 81 | next_idx = (state.next_idx + 1) % self.capacity, 82 | populated = state.populated.at[state.next_idx].set(True), 83 | has_reward = state.has_reward.at[state.next_idx].set(False) 84 | ) 85 | 86 | 87 | def assign_rewards(self, state: ReplayBufferState, reward: chex.Array) -> ReplayBufferState: 88 | """ Assign rewards to the current episode. 89 | 90 | Args: 91 | - `state`: replay buffer state 92 | - `reward`: rewards to assign (for each player) 93 | 94 | Returns: 95 | - (ReplayBufferState): updated replay buffer state 96 | """ 97 | return state.replace( 98 | episode_start_idx = state.next_idx, 99 | has_reward = jnp.full_like(state.has_reward, True), 100 | buffer = state.buffer.replace( 101 | reward = jnp.where( 102 | ~state.has_reward[..., None], 103 | reward[None, ...], 104 | state.buffer.reward 105 | ) 106 | ) 107 | ) 108 | 109 | 110 | def truncate(self, 111 | state: ReplayBufferState, 112 | ) -> ReplayBufferState: 113 | """Truncates the replay buffer, removing all experiences from the current episode. 114 | Use this if we want to discard all experiences from the current episode. 115 | 116 | Args: 117 | - `state`: replay buffer state 118 | 119 | Returns: 120 | - (ReplayBufferState): updated replay buffer state 121 | """ 122 | # un-assigned trajectory indices have populated set to False 123 | # so their buffer contents will be overwritten (eventually) 124 | # and cannot be sampled 125 | # so there's no need to overwrite them with zeros here 126 | return state.replace( 127 | next_idx = state.episode_start_idx, 128 | has_reward = jnp.full_like(state.has_reward, True), 129 | populated = jnp.where( 130 | ~state.has_reward, 131 | False, 132 | state.populated 133 | ) 134 | ) 135 | 136 | # assumes input is batched!! (dont vmap/pmap) 137 | def sample(self, 138 | state: ReplayBufferState, 139 | key: jax.random.PRNGKey, 140 | sample_size: int 141 | ) -> chex.ArrayTree: 142 | """Samples experiences from the replay buffer. 143 | 144 | Assumes the buffer has two batch dimensions, so shape = (devices, batch_size, capacity, ...) 145 | Perhaps there is a dimension-agnostic way to do this? 146 | 147 | Samples across all batch dimensions, not per-batch/device. 148 | 149 | Args: 150 | - `state`: replay buffer state 151 | - `key`: rng 152 | - `sample_size`: size of minibatch to sample 153 | 154 | Returns: 155 | - (chex.ArrayTree): minibatch of size (sample_size, ...) 
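A sketch of the reward-assignment broadcast in `assign_rewards` above, with illustrative shapes (a buffer of 4 slots and 2 players): every slot that does not yet have a reward, i.e. the in-progress episode, receives the final per-player reward.

import jax.numpy as jnp

has_reward = jnp.array([True, True, False, False])      # first two slots: an older, finished episode
stored_rewards = jnp.zeros((4, 2))
episode_reward = jnp.array([1.0, -1.0])                 # final reward for each player

new_rewards = jnp.where(
    ~has_reward[..., None],          # (4, 1) mask, broadcast over the player axis
    episode_reward[None, ...],       # (1, 2) reward, broadcast over the slot axis
    stored_rewards
)
# rows 2 and 3 of new_rewards are now [1., -1.]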
156 | """ 157 | masked_weights = jnp.logical_and( 158 | state.populated, 159 | state.has_reward 160 | ).reshape(-1) 161 | 162 | num_partitions = state.populated.shape[0] 163 | num_batches = state.populated.shape[1] 164 | 165 | indices = jax.random.choice( 166 | key, 167 | self.capacity * num_partitions * num_batches, 168 | shape=(sample_size,), 169 | replace=False, 170 | p = masked_weights / masked_weights.sum() 171 | ) 172 | 173 | partition_indices, batch_indices, item_indices = jnp.unravel_index( 174 | indices, 175 | (num_partitions, num_batches, self.capacity) 176 | ) 177 | 178 | sampled_buffer_items = jax.tree_util.tree_map( 179 | lambda x: x[partition_indices, batch_indices, item_indices], 180 | state.buffer 181 | ) 182 | 183 | return sampled_buffer_items 184 | 185 | 186 | def init(self, batch_size: int, template_experience: BaseExperience) -> ReplayBufferState: 187 | """Initializes the replay buffer state. 188 | 189 | Args: 190 | - `batch_size`: number of parallel environments 191 | - `template_experience`: template experience data structure 192 | - just used to determine the shape of the replay buffer data 193 | 194 | Returns: 195 | - (ReplayBufferState): initialized replay buffer state 196 | """ 197 | return ReplayBufferState( 198 | next_idx = jnp.zeros((batch_size,), dtype=jnp.int32), 199 | episode_start_idx = jnp.zeros((batch_size,), dtype=jnp.int32), 200 | buffer = jax.tree_util.tree_map( 201 | lambda x: jnp.zeros((batch_size, self.capacity, *x.shape), dtype=x.dtype), 202 | template_experience 203 | ), 204 | populated = jnp.full((batch_size, self.capacity,), fill_value=False, dtype=jnp.bool_), 205 | has_reward = jnp.full((batch_size, self.capacity,), fill_value=True, dtype=jnp.bool_), 206 | ) 207 | -------------------------------------------------------------------------------- /core/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/networks/__init__.py -------------------------------------------------------------------------------- /core/networks/azresnet.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | 4 | import flax.linen as nn 5 | 6 | 7 | @dataclass 8 | class AZResnetConfig: 9 | """Configuration for AlphaZero ResNet model: 10 | - `policy_head_out_size`: output size of the policy head (number of actions) 11 | - `num_blocks`: number of residual blocks 12 | - `num_channels`: number of channels in each residual block 13 | """ 14 | policy_head_out_size: int 15 | num_blocks: int 16 | num_channels: int 17 | 18 | 19 | class ResidualBlock(nn.Module): 20 | """Residual block for AlphaZero ResNet model. 21 | - `channels`: number of channels""" 22 | channels: int 23 | 24 | @nn.compact 25 | def __call__(self, x, train: bool): 26 | y = nn.Conv(features=self.channels, kernel_size=(3,3), strides=(1,1), padding='SAME', use_bias=False)(x) 27 | y = nn.BatchNorm(use_running_average=not train)(y) 28 | y = nn.relu(y) 29 | y = nn.Conv(features=self.channels, kernel_size=(3,3), strides=(1,1), padding='SAME', use_bias=False)(y) 30 | y = nn.BatchNorm(use_running_average=not train)(y) 31 | return nn.relu(x + y) 32 | 33 | 34 | class AZResnet(nn.Module): 35 | """Implements the AlphaZero ResNet model. 
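A self-contained sketch of the sampling scheme above: flat indices are drawn over (devices, batch, capacity) in proportion to a validity mask, then unraveled back into per-axis indices. Shapes here are illustrative.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_devices, batch_size, capacity = 2, 3, 8
valid = jnp.ones((num_devices, batch_size, capacity), dtype=jnp.bool_)  # sampleable entries
weights = valid.reshape(-1).astype(jnp.float32)

flat_indices = jax.random.choice(
    key,
    num_devices * batch_size * capacity,
    shape=(4,),
    replace=False,
    p=weights / weights.sum(),
)
device_idx, batch_idx, item_idx = jnp.unravel_index(
    flat_indices, (num_devices, batch_size, capacity)
)
# a tree_map over the buffer with x[device_idx, batch_idx, item_idx] then gathers the minibatch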
36 | - `config`: network configuration""" 37 | config: AZResnetConfig 38 | 39 | @nn.compact 40 | def __call__(self, x, train: bool): 41 | # initial conv layer 42 | x = nn.Conv(features=self.config.num_channels, kernel_size=(3,3), strides=(1,1), padding='SAME', use_bias=False)(x) 43 | x = nn.BatchNorm(use_running_average=not train)(x) 44 | x = nn.relu(x) 45 | 46 | # residual blocks 47 | for _ in range(self.config.num_blocks): 48 | x = ResidualBlock(channels=self.config.num_channels)(x, train=train) 49 | 50 | # policy head 51 | policy = nn.Conv(features=2, kernel_size=(1,1), strides=(1,1), padding='SAME', use_bias=False)(x) 52 | policy = nn.BatchNorm(use_running_average=not train)(policy) 53 | policy = nn.relu(policy) 54 | policy = policy.reshape((policy.shape[0], -1)) 55 | policy = nn.Dense(features=self.config.policy_head_out_size)(policy) 56 | 57 | # value head 58 | value = nn.Conv(features=1, kernel_size=(1,1), strides=(1,1), padding='SAME', use_bias=False)(x) 59 | value = nn.BatchNorm(use_running_average=not train)(value) 60 | value = nn.relu(value) 61 | value = value.reshape((value.shape[0], -1)) 62 | value = nn.Dense(features=1)(value) 63 | value = nn.tanh(value) 64 | 65 | return policy, value 66 | -------------------------------------------------------------------------------- /core/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/testing/__init__.py -------------------------------------------------------------------------------- /core/testing/tester.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | from typing import Callable, Dict, Optional, Tuple 4 | 5 | import chex 6 | from chex import dataclass 7 | import jax 8 | 9 | from core.common import partition 10 | from core.evaluators.evaluator import Evaluator 11 | from core.types import EnvInitFn, EnvStepFn 12 | 13 | 14 | @dataclass(frozen=True) 15 | class TestState: 16 | """Base class for TestState.""" 17 | 18 | 19 | class BaseTester: 20 | """Base class for Testers. 21 | A Tester is used to evaluate the performance of an agent in an environment, 22 | in some cases against one or more opponents. 23 | 24 | A Tester may maintain its own internal state. 
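A minimal forward-pass sketch for the AZResnet defined above; the input shape and config values (an 8x8 board with 2 feature planes and 65 actions) are illustrative.

import jax
import jax.numpy as jnp
from core.networks.azresnet import AZResnet, AZResnetConfig

net = AZResnet(config=AZResnetConfig(policy_head_out_size=65, num_blocks=4, num_channels=32))
key = jax.random.PRNGKey(0)
sample_input = jnp.zeros((1, 8, 8, 2))

variables = net.init(key, sample_input, train=False)      # {'params': ..., 'batch_stats': ...}
policy_logits, value = net.apply(variables, sample_input, train=False)

# during training, BatchNorm statistics must be returned as mutable state:
(policy_logits, value), updates = net.apply(
    variables, sample_input, train=True, mutable=['batch_stats']
)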
25 | """ 26 | def __init__(self, num_keys: int, epochs_per_test: int = 1, render_fn: Optional[Callable] = None, 27 | render_dir: str = '/tmp/turbozero/', name: Optional[str] = None): 28 | """ 29 | Args: 30 | - `num_keys`: number of keys to use for tester 31 | - often equal to number of episodes 32 | - provided on initialization to ensure reproducibility, even when a different number of devices is used 33 | - perhaps there is a better way to enforce this 34 | - `epochs_per_test`: number of epochs between each test 35 | - `render_fn`: (optional) function to render frames from a test episode to a .gif 36 | - `render_dir`: directory to save .gifs 37 | - `name`: (optional) name of the tester (used for logging and differentiating between testers) 38 | - defaults to the class name 39 | """ 40 | self.num_keys = num_keys 41 | self.epochs_per_test = epochs_per_test 42 | self.render_fn = render_fn 43 | self.render_dir = render_dir 44 | if name is None: 45 | name = self.__class__.__name__ 46 | self.name = name 47 | 48 | 49 | def init(self, **kwargs) -> TestState: #pylint: disable=unused-argument 50 | """Initializes the internal state of the Tester.""" 51 | return TestState() 52 | 53 | 54 | def check_size_compatibilities(self, num_devices: int) -> None: #pylint: disable=unused-argument 55 | """Checks if tester configuration is compatible with number of devices being utilized.""" 56 | return 57 | 58 | 59 | def split_keys(self, key: chex.PRNGKey, num_devices: int) -> chex.PRNGKey: 60 | """Splits keys across devices. 61 | Args: 62 | - `key`: rng 63 | - `num_devices`: number of devices 64 | 65 | Returns: 66 | - (chex.PRNGKey): keys split across devices 67 | """ 68 | # partition keys across devices (do this here so its reproducible no matter the number of devices used) 69 | keys = jax.random.split(key, self.num_keys) 70 | keys = partition(keys, num_devices) 71 | return keys 72 | 73 | 74 | def run(self, key: chex.PRNGKey, epoch_num: int, max_steps: int, num_devices: int, #pylint: disable=unused-argument 75 | env_step_fn: EnvStepFn, env_init_fn: EnvInitFn, evaluator: Evaluator, state: TestState, 76 | params: chex.ArrayTree, *args) -> Tuple[TestState, Dict, str]: 77 | """Runs the test, if the current epoch is an epoch that should be tested on 78 | 79 | If a render function is provided, saves a .gif of the first episode of the test. 
80 | 81 | Args: 82 | - `key`: rng 83 | - `epoch_num`: current epoch number 84 | - `max_steps`: maximum number of steps per episode 85 | - `num_devices`: number of devices 86 | - `env_step_fn`: environment step function 87 | - `env_init_fn`: environment initialization function 88 | - `evaluator`: evaluator used by agent 89 | - `state`: internal state of the tester 90 | - `params`: nn parameters used by agent 91 | 92 | Returns: 93 | - (TestState, Dict, str) 94 | - updated internal state of the tester 95 | - metrics from the test 96 | - path to .gif of the first episode of the test (if render function provided, otherwise None) 97 | """ 98 | # split keys across devices 99 | keys = self.split_keys(key, num_devices) 100 | 101 | if epoch_num % self.epochs_per_test == 0: 102 | # run test 103 | state, metrics, frames, p_ids = self.test(max_steps, env_step_fn, \ 104 | env_init_fn, evaluator, keys, state, params) 105 | 106 | if self.render_fn is not None: 107 | # render first episode to .gif 108 | # get frames from first episode 109 | frames = jax.tree_map(lambda x: x[0], frames) 110 | # get player ids from first episode 111 | p_ids = p_ids[0] 112 | # get list of frames 113 | frame_list = [jax.device_get(jax.tree_map(lambda x: x[i], frames)) for i in range(max_steps)] 114 | # render frames to .gif 115 | path_to_rendering = self.render_fn(frame_list, p_ids, f"{self.name}_{epoch_num}", self.render_dir) 116 | else: 117 | path_to_rendering = None 118 | return state, metrics, path_to_rendering 119 | 120 | 121 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0, 1, 2, 3, 4)) 122 | def test(self, max_steps: int, env_step_fn: EnvStepFn, env_init_fn: EnvInitFn, evaluator: Evaluator, 123 | keys: chex.PRNGKey, state: TestState, params: chex.ArrayTree) -> Tuple[TestState, Dict, chex.ArrayTree, chex.Array]: 124 | """Run the test implemented by the Tester. Parallelized across devices. 125 | 126 | Implemented by subclasses. 
127 | 128 | Args: 129 | - `max_steps`: maximum number of steps per episode 130 | - `env_step_fn`: environment step function 131 | - `env_init_fn`: environment initialization function 132 | - `evaluator`: evaluator used by agent 133 | - `keys`: rng 134 | - `state`: internal state of the tester 135 | - `params`: nn parameters used by agent 136 | 137 | Returns: 138 | - (TestState, Dict, chex.ArrayTree, chex.Array) 139 | - updated internal state of the tester 140 | - metrics from the test 141 | - frames from the test (used to produce renderings) 142 | - player ids from the test (used to produce renderings) 143 | """ 144 | raise NotImplementedError() 145 | -------------------------------------------------------------------------------- /core/testing/two_player_baseline.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from functools import partial 6 | from typing import Dict, Optional, Tuple 7 | 8 | import chex 9 | import jax 10 | import jax.numpy as jnp 11 | from core.common import GameFrame, two_player_game 12 | from core.evaluators.evaluator import Evaluator 13 | from core.testing.tester import BaseTester, TestState 14 | from core.types import EnvInitFn, EnvStepFn 15 | 16 | 17 | class TwoPlayerBaseline(BaseTester): 18 | """Implements a tester that evaluates an agent against a baseline evaluator in a two-player game.""" 19 | 20 | def __init__(self, num_episodes: int, baseline_evaluator: Evaluator, baseline_params: Optional[chex.ArrayTree] = None, 21 | *args, **kwargs): 22 | """ 23 | Args: 24 | - `num_episodes`: number of episodes to evaluate against the baseline 25 | - `baseline_evaluator`: the baseline evaluator to evaluate against 26 | - `baseline_params`: (optional) the parameters of the baseline evaluator 27 | """ 28 | super().__init__(num_keys=num_episodes, *args, **kwargs) 29 | self.num_episodes = num_episodes 30 | self.baseline_evaluator = baseline_evaluator 31 | if baseline_params is None: 32 | baseline_params = jnp.array([]) 33 | self.baseline_params = baseline_params 34 | 35 | 36 | def check_size_compatibilities(self, num_devices: int) -> None: 37 | """Checks if tester configuration is compatible with number of devices being utilized. 38 | 39 | Args: 40 | - `num_devices`: number of devices 41 | """ 42 | if self.num_episodes % num_devices != 0: 43 | raise ValueError(f"{self.__class__.__name__}: number of episodes ({self.num_episodes}) must be divisible by number of devices ({num_devices})") 44 | 45 | 46 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0, 1, 2, 3, 4)) 47 | def test(self, max_steps: int, env_step_fn: EnvStepFn, env_init_fn: EnvInitFn, evaluator: Evaluator, 48 | keys: chex.PRNGKey, state: TestState, params: chex.ArrayTree) -> Tuple[TestState, Dict, GameFrame, chex.Array]: 49 | """Test the agent against the baseline evaluator in a two-player game. 
50 | 51 | Args: 52 | - `max_steps`: maximum number of steps per episode 53 | - `env_step_fn`: environment step function 54 | - `env_init_fn`: environment initialization function 55 | - `evaluator`: the agent evaluator 56 | - `keys`: rng 57 | - `state`: internal state of the tester 58 | - `params`: nn parameters used by agent 59 | 60 | Returns: 61 | - (TestState, Dict, GameFrame, chex.Array) 62 | - updated internal state of the tester 63 | - metrics from the test 64 | - frames from the first episode of the test 65 | - player ids from the first episode of the test 66 | """ 67 | 68 | game_fn = partial(two_player_game, 69 | evaluator_1 = evaluator, 70 | evaluator_2 = self.baseline_evaluator, 71 | params_1 = params, 72 | params_2 = self.baseline_params, 73 | env_step_fn = env_step_fn, 74 | env_init_fn = env_init_fn, 75 | max_steps = max_steps 76 | ) 77 | 78 | results, frames, p_ids = jax.vmap(game_fn)(keys) 79 | frames = jax.tree_map(lambda x: x[0], frames) 80 | p_ids = p_ids[0] 81 | 82 | avg = results[:, 0].mean() 83 | 84 | metrics = { 85 | f"{self.name}_avg_outcome": avg 86 | } 87 | 88 | return state, metrics, frames, p_ids 89 | -------------------------------------------------------------------------------- /core/testing/two_player_tester.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | from typing import Dict, Tuple 4 | 5 | import chex 6 | from chex import dataclass 7 | import jax 8 | 9 | from core.common import two_player_game 10 | from core.evaluators.evaluator import Evaluator 11 | from core.testing.tester import BaseTester, TestState 12 | from core.types import EnvInitFn, EnvStepFn 13 | 14 | 15 | @dataclass(frozen=True) 16 | class TwoPlayerTestState(TestState): 17 | """Internal state of a TwoPlayerTester. Stores the best parameters found so far. 18 | - `best_params`: best performing parameters 19 | """ 20 | best_params: chex.ArrayTree 21 | 22 | 23 | class TwoPlayerTester(BaseTester): 24 | """Implements a tester that evaluates an agent against the best performing parameters 25 | found so far in a two-player game.""" 26 | def __init__(self, num_episodes: int, *args, **kwargs): 27 | """ 28 | Args: 29 | - `num_episodes`: number of episodes to play in each test 30 | """ 31 | super().__init__(*args, num_keys=num_episodes, **kwargs) 32 | self.num_episodes = num_episodes 33 | 34 | 35 | def init(self, params: chex.ArrayTree, **kwargs) -> TwoPlayerTestState: #pylint: disable=unused-argument 36 | """Initializes the internal state of the TwoPlayerTester. 37 | Args: 38 | - `params`: initial parameters to store as the best performing 39 | - can just be the initial parameters of the agent 40 | """ 41 | return TwoPlayerTestState(best_params=params) 42 | 43 | 44 | def check_size_compatibilities(self, num_devices: int) -> None: 45 | """Checks if tester configuration is compatible with number of devices being utilized. 
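A self-contained sketch of the vmap-over-keys pattern both testers use above: a single-episode function is closed over its fixed arguments with `functools.partial`, then vectorized over one PRNG key per episode. `play_one_episode` is a stand-in, not library code.

from functools import partial
import jax
import jax.numpy as jnp

def play_one_episode(key, max_steps):
    # stand-in for two_player_game: returns a fake per-player outcome (max_steps unused here)
    outcome = jax.random.bernoulli(key).astype(jnp.float32) * 2.0 - 1.0
    return jnp.array([outcome, -outcome])

game_fn = partial(play_one_episode, max_steps=32)
keys = jax.random.split(jax.random.PRNGKey(0), 8)   # one key per episode
results = jax.vmap(game_fn)(keys)                   # shape (8, 2)
avg_outcome = results[:, 0].mean()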
46 | 47 | Args: 48 | - `num_devices`: number of devices 49 | """ 50 | if self.num_episodes % num_devices != 0: 51 | raise ValueError(f"{self.__class__.__name__}: number of episodes ({self.num_episodes}) must be divisible by number of devices ({num_devices})") 52 | 53 | 54 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0, 1, 2, 3, 4)) 55 | def test(self, max_steps: int, env_step_fn: EnvStepFn, env_init_fn: EnvInitFn, evaluator: Evaluator, 56 | keys: chex.PRNGKey, state: TwoPlayerTestState, params: chex.ArrayTree) -> Tuple[TwoPlayerTestState, Dict, chex.ArrayTree, chex.Array]: 57 | """Test the agent against the best performing parameters found so far in a two-player game. 58 | 59 | Args: 60 | - `max_steps`: maximum number of steps per episode 61 | - `env_step_fn`: environment step function 62 | - `env_init_fn`: environment initialization function 63 | - `evaluator`: the agent evaluator 64 | - `keys`: rng 65 | - `state`: internal state of the tester 66 | - `params`: nn parameters used by agent 67 | 68 | Returns: 69 | - (TwoPlayerTestState, Dict, chex.ArrayTree, chex.Array) 70 | - updated internal state of the tester 71 | - metrics from the test 72 | - frames from the test (used for rendering) 73 | - player ids from the test (used for rendering) 74 | """ 75 | 76 | game_fn = partial(two_player_game, 77 | evaluator_1 = evaluator, 78 | evaluator_2 = evaluator, 79 | params_1 = params, 80 | params_2 = state.best_params, 81 | env_step_fn = env_step_fn, 82 | env_init_fn = env_init_fn, 83 | max_steps = max_steps 84 | ) 85 | 86 | results, frames, p_ids = jax.vmap(game_fn)(keys) 87 | frames = jax.tree_map(lambda x: x[0], frames) 88 | p_ids = p_ids[0] 89 | 90 | avg = results[:, 0].mean() 91 | 92 | metrics = { 93 | f"{self.name}_avg_outcome": avg 94 | } 95 | 96 | best_params = jax.lax.cond( 97 | avg > 0.0, 98 | lambda _: params, 99 | lambda _: state.best_params, 100 | None 101 | ) 102 | 103 | return state.replace(best_params=best_params), metrics, frames, p_ids 104 | -------------------------------------------------------------------------------- /core/testing/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import xml.etree.ElementTree as ET 5 | import cairosvg 6 | from PIL import Image 7 | 8 | def render_pgx_2p(frames, p_ids, title, frame_dir, p1_label='Black', p2_label='White', duration=900): 9 | """really messy render function for rendering frames from a 2-player game 10 | from a PGX environment to a .gif""" 11 | digit_length = len(str(len(frames))) 12 | trained_agent_color = p1_label if frames[0].env_state.current_player == p_ids[0] else p2_label 13 | opponent_color = p2_label if trained_agent_color == p1_label else p1_label 14 | agent_win = False 15 | opp_win = False 16 | draw = False 17 | images = [] 18 | for i,frame in enumerate(frames): 19 | env_state = frame.env_state 20 | if frame.completed.item(): 21 | num = '9' * digit_length 22 | env_state.save_svg(f"{frame_dir}/{num}.svg", color_theme='dark') 23 | agent_win = frame.outcomes[p_ids[0]] > frame.outcomes[p_ids[1]] 24 | opp_win = frame.outcomes[p_ids[1]] > frame.outcomes[p_ids[0]] 25 | draw = frame.outcomes[0] == frame.outcomes[1] 26 | 27 | else: 28 | num = str(i).zfill(digit_length) 29 | env_state.save_svg(f"{frame_dir}/{num}.svg", color_theme='dark') 30 | 31 | tree = ET.parse(f"{frame_dir}/{num}.svg") 32 | root = tree.getroot() 33 | 34 | viewBox = root.attrib.get('viewBox', None) 35 | if viewBox: 36 | viewBox = viewBox.split() 37 | viewBox = [float(v) for v in 
viewBox] 38 | original_height = viewBox[3] 39 | else: 40 | original_width = float(root.attrib.get('width', 0)) 41 | original_height = float(root.attrib.get('height', 0)) 42 | 43 | new_height = original_height * 1.2 44 | # Update the viewBox and height attributes 45 | if viewBox: 46 | 47 | viewBox[3] = new_height 48 | root.attrib['viewBox'] = ' '.join(map(str, viewBox)) 49 | root.attrib['height'] = str(new_height) 50 | 51 | # Create a new text element 52 | p1_text = ET.Element('ns0:text', x=str(0.01 * original_width), y=str(original_height * 1.05), fill='white', style='font-family: Arial;') 53 | p2_text = ET.Element('ns0:text', x=str(0.01 * original_width), y=str(original_height * 1.15), fill='white', style='font-family: Arial;') 54 | emoji = "[W]" if agent_win else "[L]" if opp_win else "[D]" if draw else "" 55 | agent_text = f"{emoji} Trained Agent ({trained_agent_color}): {'+' if frame.p1_value_estimate > 0 else ''}{frame.p1_value_estimate:.4f}" 56 | emoji = "[W]" if opp_win else "[L]" if agent_win else "[D]" if draw else "" 57 | opp_text = f"{emoji} Opponent ({opponent_color}): {'+' if frame.p2_value_estimate > 0 else ''}{frame.p2_value_estimate:.4f}" 58 | p1_text.text = agent_text 59 | p2_text.text = opp_text 60 | 61 | new_area = ET.Element('ns0:rect', fill='#1e1e1e', height=str(new_height - original_height), width=str(original_width), x='0', y=str(original_height)) 62 | root.append(new_area) 63 | 64 | root.append(p1_text) 65 | root.append(p2_text) 66 | 67 | tree.write(f"{frame_dir}/{num}.svg", encoding='utf-8', xml_declaration=True) 68 | 69 | cairosvg.svg2png(url=f"{frame_dir}/{num}.svg", write_to=f"{frame_dir}/{num}.png") 70 | images.append(f"{frame_dir}/{num}.png") 71 | if frame.completed.item(): 72 | break 73 | 74 | 75 | images = [Image.open(png) for png in images] 76 | 77 | gif_path = f"{frame_dir}/{title}.gif" 78 | images[0].save(gif_path, save_all=True, append_images=images[1:] + ([images[-1]] * 2), duration=duration, loop=0) 79 | 80 | os.system(f"rm {frame_dir}/*.svg") 81 | os.system(f"rm {frame_dir}/*.png") 82 | return gif_path 83 | -------------------------------------------------------------------------------- /core/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/training/__init__.py -------------------------------------------------------------------------------- /core/training/loss_fns.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Tuple 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | from flax.training.train_state import TrainState 9 | 10 | from core.memory.replay_memory import BaseExperience 11 | 12 | 13 | def az_default_loss_fn(params: chex.ArrayTree, train_state: TrainState, experience: BaseExperience, 14 | l2_reg_lambda: float = 0.0001) -> Tuple[chex.Array, Tuple[chex.ArrayTree, optax.OptState]]: 15 | """ Implements the default AlphaZero loss function. 
16 | 17 | = Policy Loss + Value Loss + L2 Regularization 18 | Policy Loss: Cross-entropy loss between predicted policy and target policy 19 | Value Loss: L2 loss between predicted value and target value 20 | 21 | Args: 22 | - `params`: the parameters of the neural network 23 | - `train_state`: flax TrainState (holds optimizer and other state) 24 | - `experience`: experience sampled from replay buffer 25 | - stores the observation, target policy, target value 26 | - `l2_reg_lambda`: L2 regularization weight (default = 1e-4) 27 | 28 | Returns: 29 | - (loss, (aux_metrics, updates)) 30 | - `loss`: total loss 31 | - `aux_metrics`: auxiliary metrics (policy_loss, value_loss) 32 | - `updates`: optimizer updates 33 | """ 34 | 35 | # get batch_stats if using batch_norm 36 | variables = {'params': params, 'batch_stats': train_state.batch_stats} \ 37 | if hasattr(train_state, 'batch_stats') else {'params': params} 38 | mutables = ['batch_stats'] if hasattr(train_state, 'batch_stats') else [] 39 | 40 | # get predictions 41 | (pred_policy, pred_value), updates = train_state.apply_fn( 42 | variables, 43 | x=experience.observation_nn, 44 | train=True, 45 | mutable=mutables 46 | ) 47 | 48 | # set invalid actions in policy to -inf 49 | pred_policy = jnp.where( 50 | experience.policy_mask, 51 | pred_policy, 52 | jnp.finfo(jnp.float32).min 53 | ) 54 | 55 | # compute policy loss 56 | policy_loss = optax.softmax_cross_entropy(pred_policy, experience.policy_weights).mean() 57 | # select appropriate value from experience.reward 58 | current_player = experience.cur_player_id 59 | target_value = experience.reward[jnp.arange(experience.reward.shape[0]), current_player] 60 | # compute MSE value loss 61 | value_loss = optax.l2_loss(pred_value.squeeze(), target_value).mean() 62 | 63 | # compute L2 regularization 64 | l2_reg = l2_reg_lambda * jax.tree_util.tree_reduce( 65 | lambda x, y: x + y, 66 | jax.tree_map( 67 | lambda x: (x ** 2).sum(), 68 | params 69 | ) 70 | ) 71 | 72 | # total loss 73 | loss = policy_loss + value_loss + l2_reg 74 | aux_metrics = { 75 | 'policy_loss': policy_loss, 76 | 'value_loss': value_loss 77 | } 78 | return loss, (aux_metrics, updates) 79 | -------------------------------------------------------------------------------- /core/training/train.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | import os 4 | import shutil 5 | from typing import Any, List, Optional, Tuple 6 | 7 | import chex 8 | from chex import dataclass 9 | import flax 10 | from flax.training.train_state import TrainState 11 | from flax.training import orbax_utils 12 | import jax 13 | import jax.numpy as jnp 14 | import optax 15 | import orbax.checkpoint as ocp 16 | import wandb 17 | 18 | from core.common import partition, step_env_and_evaluator 19 | from core.evaluators.evaluator import Evaluator 20 | from core.memory.replay_memory import BaseExperience, EpisodeReplayBuffer, ReplayBufferState 21 | from core.testing.tester import BaseTester, TestState 22 | from core.types import DataTransformFn, EnvInitFn, EnvStepFn, ExtractModelParamsFn, LossFn, StateToNNInputFn, StepMetadata 23 | 24 | 25 | @dataclass(frozen=True) 26 | class CollectionState: 27 | """Stores state of self-play episode collection. Persists across generations. 
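A self-contained sketch of the three loss terms combined above, evaluated on toy data: cross-entropy between predicted policy logits and the visit-count target, L2 loss on the value head, and an L2 penalty over the parameters. The shapes and the toy `params` tree are illustrative.

import jax
import jax.numpy as jnp
import optax

pred_policy_logits = jnp.array([[2.0, 0.5, -1.0]])
target_policy = jnp.array([[0.7, 0.3, 0.0]])        # normalized visit counts
pred_value = jnp.array([0.25])
target_value = jnp.array([1.0])
params = {'dense': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros(3)}}
l2_reg_lambda = 1e-4

policy_loss = optax.softmax_cross_entropy(pred_policy_logits, target_policy).mean()
value_loss = optax.l2_loss(pred_value, target_value).mean()
l2_reg = l2_reg_lambda * jax.tree_util.tree_reduce(
    lambda x, y: x + y,
    jax.tree_util.tree_map(lambda x: (x ** 2).sum(), params),
)
loss = policy_loss + value_loss + l2_reg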
28 | - `eval_state`: state of the evaluator 29 | - `env_state`: state of the environment 30 | - `buffer_state`: state of the replay buffer 31 | - `metadata`: metadata of the current environment state 32 | """ 33 | eval_state: chex.ArrayTree 34 | env_state: chex.ArrayTree 35 | buffer_state: ReplayBufferState 36 | metadata: StepMetadata 37 | 38 | @dataclass(frozen=True) 39 | class TrainLoopOutput: 40 | """ 41 | Stores the state of the training loop. 42 | collection_state is included to access replay memory. 43 | - `collection_state`: state of self-play episode collection. 44 | - `train_state`: flax TrainState, holds optimizer state, model params 45 | - `test_states`: states of testers 46 | - `cur_epoch`: current epoch num 47 | """ 48 | collection_state: CollectionState 49 | train_state: TrainState 50 | test_states: List[TestState] 51 | cur_epoch: int 52 | 53 | 54 | class TrainStateWithBS(TrainState): 55 | """Custom flax TrainState to handle BatchNorm""" 56 | batch_stats: chex.ArrayTree 57 | 58 | 59 | def extract_params(state: TrainState) -> chex.ArrayTree: 60 | """Extracts model parameters from TrainState. 61 | 62 | Args: 63 | - `state`: TrainState containing model parameters 64 | 65 | Returns: 66 | - (chex.ArrayTree): model parameters 67 | """ 68 | if hasattr(state, 'batch_stats'): 69 | return {'params': state.params, 'batch_stats': state.batch_stats} 70 | return {'params': state.params} 71 | 72 | 73 | class Trainer: 74 | """Implements a training loop for AlphaZero. 75 | Maintains state across self-play game collection, training, and testing.""" 76 | 77 | def __init__(self, 78 | batch_size: int, 79 | train_batch_size: int, 80 | warmup_steps: int, 81 | collection_steps_per_epoch: int, 82 | train_steps_per_epoch: int, 83 | nn: flax.linen.Module, 84 | loss_fn: LossFn, 85 | optimizer: optax.GradientTransformation, 86 | evaluator: Evaluator, 87 | memory_buffer: EpisodeReplayBuffer, 88 | max_episode_steps: int, 89 | env_step_fn: EnvStepFn, 90 | env_init_fn: EnvInitFn, 91 | state_to_nn_input_fn: StateToNNInputFn, 92 | testers: List[BaseTester], 93 | evaluator_test: Optional[Evaluator] = None, 94 | data_transform_fns: List[DataTransformFn] = [], 95 | extract_model_params_fn: Optional[ExtractModelParamsFn] = extract_params, 96 | wandb_project_name: str = "", 97 | ckpt_dir: str = "/tmp/turbozero_checkpoints", 98 | max_checkpoints: int = 2, 99 | num_devices: Optional[int] = None, 100 | wandb_run: Optional[Any] = None, 101 | extra_wandb_config: Optional[dict] = None 102 | ): 103 | """ 104 | Args: 105 | - `batch_size`: batch size for self-play games 106 | - `train_batch_size`: minibatch size for training steps 107 | - `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. 
108 | - This is used to populate the replay memory with some initial samples 109 | - `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play in each epoch 110 | - `train_steps_per_epoch`: # of training steps to take in each epoch 111 | - `nn`: flax.linen.Module containing configured neural network 112 | - `loss_fn`: loss function for training (see core.training.loss_fns) 113 | - `optimizer`: optax optimizer 114 | - `evaluator`: the `Evaluator` to use during self-play 115 | - `memory_buffer`: replay memory buffer class, used to store self-play experiences 116 | - `max_episode_steps`: maximum number of steps in an episode 117 | - `env_step_fn`: environment step function (env_state, action) -> (new_env_state, metadata) 118 | - `env_init_fn`: environment initialization function (key) -> (env_state, metadata) 119 | - `state_to_nn_input_fn`: function to convert environment state to neural network input 120 | - `testers`: list of testers to evaluate the agent against (see core.testing.tester) 121 | - `evaluator_test`: (optional) evaluator to use during testing. If not provided, `evaluator` is used. 122 | - `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.) 123 | - `extract_model_params_fn`: (optional) function to extract model parameters from TrainState 124 | - `wandb_project_name`: (optional) name of wandb project to log to 125 | - `ckpt_dir`: directory to save checkpoints 126 | - `max_checkpoints`: maximum number of checkpoints to keep 127 | - `num_devices`: (optional) number of devices to use, defaults to jax.local_device_count() 128 | - `wandb_run`: (optional) wandb run object, will continue logging to this run if passed, else a new run is initialized 129 | - `extra_wandb_config`: (optional) extra config to pass to wandb 130 | """ 131 | self.num_devices = num_devices if num_devices is not None else jax.local_device_count() 132 | # environment 133 | self.env_step_fn = env_step_fn 134 | self.env_init_fn = env_init_fn 135 | self.max_episode_steps = max_episode_steps 136 | self.template_env_state = self.make_template_env_state() 137 | # nn 138 | self.state_to_nn_input_fn = state_to_nn_input_fn 139 | self.nn = nn 140 | self.loss_fn = loss_fn 141 | self.optimizer = optimizer 142 | self.extract_model_params_fn = extract_model_params_fn 143 | # selfplay 144 | self.batch_size = batch_size 145 | self.warmup_steps = warmup_steps 146 | self.collection_steps_per_epoch = collection_steps_per_epoch 147 | self.memory_buffer = memory_buffer 148 | self.evaluator_train = evaluator 149 | self.transform_fns = data_transform_fns 150 | self.step_train = partial(step_env_and_evaluator, 151 | evaluator=self.evaluator_train, 152 | env_step_fn=self.env_step_fn, 153 | env_init_fn=self.env_init_fn, 154 | max_steps=self.max_episode_steps 155 | ) 156 | # training 157 | self.train_steps_per_epoch = train_steps_per_epoch 158 | self.train_batch_size = train_batch_size 159 | # testing 160 | self.testers = testers 161 | self.evaluator_test = evaluator_test if evaluator_test is not None else evaluator 162 | self.step_test = partial(step_env_and_evaluator, 163 | evaluator=self.evaluator_test, 164 | env_step_fn=self.env_step_fn, 165 | env_init_fn=self.env_init_fn, 166 | max_steps=self.max_episode_steps 167 | ) 168 | # checkpoints 169 | self.ckpt_dir = ckpt_dir 170 | options = ocp.CheckpointManagerOptions(max_to_keep=max_checkpoints, create=True) 171 | self.checkpoint_manager = ocp.CheckpointManager( 172 | 
ocp.test_utils.erase_and_create_empty(ckpt_dir), options=options) 173 | # wandb 174 | self.wandb_project_name = wandb_project_name 175 | self.use_wandb = wandb_project_name != "" 176 | if self.use_wandb: 177 | if wandb_run is not None: 178 | self.run = wandb_run 179 | else: 180 | self.run = self.init_wandb(wandb_project_name, extra_wandb_config) 181 | else: 182 | self.run = None 183 | # check batch sizes, etc. are compatible with number of devices 184 | self.check_size_compatibilities() 185 | 186 | 187 | def init_wandb(self, project_name: str, extra_wandb_config: Optional[dict]): 188 | """Initializes wandb run. 189 | Args: 190 | - `project_name`: name of wandb project 191 | - `extra_wandb_config`: (optional) extra config to pass to wandb 192 | 193 | Returns: 194 | - (wandb.Run): wandb run 195 | """ 196 | if extra_wandb_config is None: 197 | extra_wandb_config = {} 198 | return wandb.init( 199 | project=project_name, 200 | config={**self.get_config(), **extra_wandb_config} 201 | ) 202 | 203 | 204 | def check_size_compatibilities(self): 205 | """Checks if batch sizes, etc. are compatible with number of devices. 206 | Calls check_size_compatibilities on each tester.""" 207 | 208 | err_fmt = "Batch size must be divisible by the number of devices. Got {b} batch size and {d} devices." 209 | # check train batch size 210 | if self.train_batch_size % self.num_devices != 0: 211 | raise ValueError(err_fmt.format(b=self.train_batch_size, d=self.num_devices)) 212 | # check collection batch size 213 | if self.batch_size % self.num_devices != 0: 214 | raise ValueError(err_fmt.format(b=self.batch_size, d=self.num_devices)) 215 | # check testers 216 | for tester in self.testers: 217 | tester.check_size_compatibilities(self.num_devices) 218 | 219 | 220 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0,)) 221 | def init_train_state(self, key: jax.random.PRNGKey) -> TrainState: 222 | """Initializes the training state (params, optimizer, etc.) partitions across devices. 223 | 224 | Args: 225 | - `key`: rng 226 | 227 | Returns: 228 | - (TrainState): initialized training state 229 | """ 230 | # get template env state 231 | sample_env_state = self.make_template_env_state() 232 | # get sample nn input 233 | sample_obs = self.state_to_nn_input_fn(sample_env_state) 234 | # initialize nn parameters 235 | variables = self.nn.init(key, sample_obs[None, ...], train=False) 236 | params = variables['params'] 237 | # handle batchnorm 238 | if 'batch_stats' in variables: 239 | return TrainStateWithBS.create( 240 | apply_fn=self.nn.apply, 241 | params=params, 242 | tx=self.optimizer, 243 | batch_stats=variables['batch_stats'] 244 | ) 245 | # init TrrainState 246 | return TrainState.create( 247 | apply_fn=self.nn.apply, 248 | params=params, 249 | tx=self.optimizer, 250 | ) 251 | 252 | 253 | def get_config(self): 254 | """Returns a dictionary of the configuration of the trainer. 
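A minimal sketch of the TrainState construction performed by `init_train_state` above, using a toy flax module and an optax optimizer; for BatchNorm networks the same pattern carries `batch_stats` via the custom `TrainStateWithBS`. `TinyNet` is an illustrative stand-in.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training.train_state import TrainState

class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)

net = TinyNet()
variables = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))
train_state = TrainState.create(
    apply_fn=net.apply,
    params=variables['params'],
    tx=optax.adam(1e-3),
)
# for networks with BatchNorm, variables also contains 'batch_stats', which TrainStateWithBS
# keeps alongside the params and optimizer state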
Used for logging/wand.""" 255 | return { 256 | 'batch_size': self.batch_size, 257 | 'train_batch_size': self.train_batch_size, 258 | 'warmup_steps': self.warmup_steps, 259 | 'collection_steps_per_epoch': self.collection_steps_per_epoch, 260 | 'train_steps_per_epoch': self.train_steps_per_epoch, 261 | 'num_devices': self.num_devices, 262 | 'evaluator_train': self.evaluator_train.__class__.__name__, 263 | 'evaluator_train_config': self.evaluator_train.get_config(), 264 | 'evaluator_test': self.evaluator_test.__class__.__name__, 265 | 'evaluator_test_config': self.evaluator_test.get_config(), 266 | 'memory_buffer': self.memory_buffer.__class__.__name__, 267 | 'memory_buffer_config': self.memory_buffer.get_config(), 268 | } 269 | 270 | 271 | def collect(self, 272 | key: jax.random.PRNGKey, 273 | state: CollectionState, 274 | params: chex.ArrayTree 275 | ) -> CollectionState: 276 | """ 277 | - Collects self-play data for a single step. 278 | - Stores experience in replay buffer. 279 | - Resets environment/evaluator if episode is terminated. 280 | 281 | Args: 282 | - `key`: rng 283 | - `state`: current collection state (environment, evaluator, replay buffer) 284 | - `params`: model parameters 285 | 286 | Returns: 287 | - (CollectionState): updated collection state 288 | """ 289 | # step environment and evaluator 290 | eval_output, new_env_state, new_metadata, terminated, truncated, rewards = \ 291 | self.step_train( 292 | key = key, 293 | env_state = state.env_state, 294 | env_state_metadata = state.metadata, 295 | eval_state = state.eval_state, 296 | params = params 297 | ) 298 | 299 | # store experience in replay buffer 300 | buffer_state = self.memory_buffer.add_experience( 301 | state = state.buffer_state, 302 | experience = BaseExperience( 303 | observation_nn=self.state_to_nn_input_fn(state.env_state), 304 | policy_mask=state.metadata.action_mask, 305 | policy_weights=eval_output.policy_weights, 306 | reward=jnp.empty_like(state.metadata.rewards), 307 | cur_player_id=state.metadata.cur_player_id 308 | ) 309 | ) 310 | # apply transforms 311 | for transform_fn in self.transform_fns: 312 | t_policy_mask, t_policy_weights, t_env_state = transform_fn( 313 | state.metadata.action_mask, 314 | eval_output.policy_weights, 315 | state.env_state 316 | ) 317 | buffer_state = self.memory_buffer.add_experience( 318 | state = buffer_state, 319 | experience = BaseExperience( 320 | observation_nn=self.state_to_nn_input_fn(t_env_state), 321 | policy_mask=t_policy_mask, 322 | policy_weights=t_policy_weights, 323 | reward=jnp.empty_like(state.metadata.rewards), 324 | cur_player_id=state.metadata.cur_player_id 325 | ) 326 | ) 327 | # assign rewards to buffer if episode is terminated 328 | buffer_state = jax.lax.cond( 329 | terminated, 330 | lambda s: self.memory_buffer.assign_rewards(s, rewards), 331 | lambda s: s, 332 | buffer_state 333 | ) 334 | # truncate episode experiences in buffer if episode is too long 335 | buffer_state = jax.lax.cond( 336 | truncated, 337 | self.memory_buffer.truncate, 338 | lambda s: s, 339 | buffer_state 340 | ) 341 | # return new collection state 342 | return state.replace( 343 | eval_state=eval_output.eval_state, 344 | env_state=new_env_state, 345 | buffer_state=buffer_state, 346 | metadata=new_metadata 347 | ) 348 | 349 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0, 4)) 350 | def collect_steps(self, 351 | key: chex.PRNGKey, 352 | state: CollectionState, 353 | params: chex.ArrayTree, 354 | num_steps: int 355 | ) -> CollectionState: 356 | """Collects self-play 
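
The conditional buffer updates in `collect` above rely on `jax.lax.cond`, which requires both branches to accept and return pytrees of identical structure, so the "do nothing" branch is simply the identity. A minimal self-contained sketch of that pattern (the dict buffer and `assign_rewards` below are hypothetical stand-ins, not the library's replay memory):

```python
import jax
import jax.numpy as jnp

def assign_rewards(buffer, rewards):
    # hypothetical stand-in for memory_buffer.assign_rewards
    return {**buffer, 'reward': rewards}

buffer = {'reward': jnp.zeros(2)}
terminated = jnp.array(True)

buffer = jax.lax.cond(
    terminated,                                            # traced boolean predicate
    lambda b: assign_rewards(b, jnp.array([1.0, -1.0])),   # branch taken when the episode terminated
    lambda b: b,                                           # identity branch otherwise
    buffer,
)
```
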
data for `num_steps` steps. Mapped across devices. 357 | 358 | Args: 359 | - `key`: rng 360 | - `state`: current collection state 361 | - `params`: model parameters 362 | - `num_steps`: number of self-play steps to collect 363 | 364 | Returns: 365 | - (CollectionState): updated collection state 366 | """ 367 | if num_steps > 0: 368 | collect = partial(self.collect, params=params) 369 | keys = jax.random.split(key, num_steps) 370 | return jax.lax.fori_loop( 371 | 0, num_steps, 372 | lambda i, s: collect(keys[i], s), 373 | state 374 | ) 375 | return state 376 | 377 | 378 | @partial(jax.pmap, axis_name='d', static_broadcasted_argnums=(0,)) 379 | def one_train_step(self, ts: TrainState, batch: BaseExperience) -> Tuple[TrainState, dict]: 380 | """Make a single training step. 381 | 382 | Args: 383 | - `ts`: TrainState 384 | - `batch`: minibatch of experiences 385 | 386 | Returns: 387 | - (TrainState, dict): updated TrainState and metrics 388 | """ 389 | # calculate loss, get gradients 390 | grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True) 391 | (loss, (metrics, updates)), grads = grad_fn(ts.params, ts, batch) 392 | # apply gradients 393 | grads = jax.lax.pmean(grads, axis_name='d') 394 | ts = ts.apply_gradients(grads=grads) 395 | # update batchnorm stats 396 | if hasattr(ts, 'batch_stats'): 397 | ts = ts.replace(batch_stats=jax.lax.pmean(updates['batch_stats'], axis_name='d')) 398 | # return updated train state and metrics 399 | metrics = { 400 | **metrics, 401 | 'loss': loss 402 | } 403 | return ts, metrics 404 | 405 | 406 | def train_steps(self, 407 | key: chex.PRNGKey, 408 | collection_state: CollectionState, 409 | train_state: TrainState, 410 | num_steps: int 411 | ) -> Tuple[CollectionState, TrainState, dict]: 412 | """Performs `num_steps` training steps. 413 | Each step consists of sampling a minibatch from the replay buffer and updating the parameters. 414 | 415 | Args: 416 | - `key`: rng 417 | - `collection_state`: current collection state 418 | - `train_state`: current training state 419 | - `num_steps`: number of training steps to perform 420 | 421 | Returns: 422 | - (CollectionState, TrainState, dict): 423 | - updated collection state 424 | - updated training state 425 | - metrics 426 | """ 427 | # get replay memory buffer 428 | buffer_state = collection_state.buffer_state 429 | 430 | batch_metrics = [] 431 | 432 | for _ in range(num_steps): 433 | step_key, key = jax.random.split(key) 434 | # sample from replay memory 435 | batch = self.memory_buffer.sample(buffer_state, step_key, self.train_batch_size) 436 | # reshape into minibatch 437 | batch = jax.tree_map(lambda x: x.reshape((self.num_devices, -1, *x.shape[1:])), batch) 438 | # make training step 439 | train_state, metrics = self.one_train_step(train_state, batch) 440 | # append metrics from step 441 | if metrics: 442 | batch_metrics.append(metrics) 443 | # take mean of metrics across all training steps 444 | if batch_metrics: 445 | metrics = {k: jnp.stack([m[k] for m in batch_metrics]).mean() for k in batch_metrics[0].keys()} 446 | else: 447 | metrics = {} 448 | # return updated collection state, train state, and metrics 449 | return collection_state, train_state, metrics 450 | 451 | 452 | def log_metrics(self, metrics: dict, epoch: int, step: Optional[int] = None): 453 | """Logs metrics to console and wandb. 
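
As a side note, the `jax.lax.pmean` call in `one_train_step` above is what keeps replicas in sync: each device computes gradients on its own shard, then receives the cross-device mean before applying the update. A tiny standalone demonstration of the collective (not tied to the trainer):

```python
import jax
import jax.numpy as jnp

# each device contributes one "gradient"; pmean returns the average on every device
mean_across_devices = jax.pmap(lambda g: jax.lax.pmean(g, axis_name='d'), axis_name='d')

n = jax.local_device_count()
per_device_grads = jnp.arange(n, dtype=jnp.float32).reshape(n, 1)
print(mean_across_devices(per_device_grads))   # every row holds the same mean value
```
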
454 | 455 | Args: 456 | - `metrics`: dictionary of metrics 457 | - `epoch`: current epoch 458 | - `step`: current step 459 | """ 460 | # log to console 461 | metrics_str = {k: f"{v.item():.4f}" for k, v in metrics.items()} 462 | print(f"Epoch {epoch}: {metrics_str}") 463 | # log to wandb 464 | if self.use_wandb: 465 | wandb.log(metrics, step) 466 | 467 | 468 | def save_checkpoint(self, train_state: TrainState, epoch: int) -> None: 469 | """Saves an orbax checkpoint of the training state. 470 | 471 | Args: 472 | - `train_state`: current training state 473 | - `epoch`: current epoch 474 | """ 475 | # convert pmap-sharded train_state to a single-device one 476 | ckpt = jax.tree_map(lambda x: jax.device_get(x), train_state) 477 | # save checkpoint (async) 478 | self.checkpoint_manager.save(epoch, args=ocp.args.StandardSave(ckpt)) 479 | 480 | 481 | def load_train_state_from_checkpoint(self, path_to_checkpoint: str, epoch: int) -> TrainState: 482 | """Loads a training state from a checkpoint. 483 | 484 | Args: 485 | - `path_to_checkpoint`: path to checkpoint 486 | - `epoch`: epoch to load 487 | 488 | Returns: 489 | - (TrainState): loaded training state 490 | """ 491 | # create dummy TrainState 492 | key = jax.random.PRNGKey(0) 493 | init_key, key = jax.random.split(key) 494 | init_keys = jnp.tile(init_key[None], (self.num_devices, 1)) 495 | dummy_state = self.init_train_state(init_keys) 496 | # load checkpoint 497 | train_state = self.checkpoint_manager.restore( 498 | epoch, 499 | items=dummy_state, 500 | directory=path_to_checkpoint, 501 | # allowing for saved checkpoints on different number of jax devices (unsafe) 502 | restore_kwargs={ 503 | 'strict': False, 504 | } 505 | ) 506 | return train_state 507 | 508 | 509 | def make_template_env_state(self) -> chex.ArrayTree: 510 | """Create a template environment state used for initializing data structures 511 | that hold environment states to the correct shape. 512 | 513 | Returns: 514 | - (chex.ArrayTree): template environment state 515 | """ 516 | env_state, _ = self.env_init_fn(jax.random.PRNGKey(0)) 517 | return env_state 518 | 519 | 520 | def make_template_experience(self) -> BaseExperience: 521 | """Create a template experience used for initializing data structures 522 | that hold experiences to the correct shape. 523 | 524 | Returns: 525 | - (BaseExperience): template experience 526 | """ 527 | env_state, metadata = self.env_init_fn(jax.random.PRNGKey(0)) 528 | return BaseExperience( 529 | observation_nn=self.state_to_nn_input_fn(env_state), 530 | policy_mask=metadata.action_mask, 531 | policy_weights=jnp.zeros_like(metadata.action_mask, dtype=jnp.float32), 532 | reward=jnp.zeros_like(metadata.rewards), 533 | cur_player_id=metadata.cur_player_id 534 | ) 535 | 536 | 537 | def init_collection_state(self, key: jax.random.PRNGKey, batch_size: int) -> CollectionState: 538 | """Initializes the collection state (see CollectionState). 
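
`init_collection_state` (whose body continues below) batches environment setup by vmapping `env_init_fn` over a batch of split keys, which is also a handy way to sanity-check your own environment functions. A hedged sketch with a made-up `dummy_init` (not a real turbozero environment):

```python
import jax
import jax.numpy as jnp

def dummy_init(key):
    # hypothetical environment: an 8x8 board plus minimal metadata
    board = jax.random.uniform(key, (8, 8))
    metadata = {'terminated': jnp.array(False), 'cur_player_id': jnp.array(0)}
    return board, metadata

keys = jax.random.split(jax.random.PRNGKey(0), 4)    # 4 parallel environments
boards, metadata = jax.vmap(dummy_init)(keys)
print(boards.shape)                                  # (4, 8, 8)
print(metadata['terminated'].shape)                  # (4,)
```
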
539 | 540 | Args: 541 | - `key`: rng 542 | - `batch_size`: number of parallel environments 543 | 544 | Returns: 545 | - (CollectionState): initialized collection state 546 | """ 547 | # make template experience 548 | template_experience = self.make_template_experience() 549 | # init buffer state 550 | buffer_state = self.memory_buffer.init(batch_size, template_experience) 551 | # init env state 552 | env_init_key, key = jax.random.split(key) 553 | env_keys = jax.random.split(env_init_key, batch_size) 554 | env_state, metadata = jax.vmap(self.env_init_fn)(env_keys) 555 | # init evaluator state 556 | eval_state = self.evaluator_train.init_batched(batch_size, template_embedding=self.template_env_state) 557 | # return collection state 558 | return CollectionState( 559 | eval_state=eval_state, 560 | env_state=env_state, 561 | buffer_state=buffer_state, 562 | metadata=metadata 563 | ) 564 | 565 | 566 | def train_loop(self, 567 | seed: int, 568 | num_epochs: int, 569 | eval_every: int = 1, 570 | initial_state: Optional[TrainLoopOutput] = None 571 | ) -> Tuple[CollectionState, TrainState]: 572 | """Runs the training loop for `num_epochs` epochs. Mostly configured by the Trainer's attributes. 573 | - Collects self-play episdoes across a batch of environments. 574 | - Trains the neural network on the collected experiences. 575 | - Tests the agent on a set of Testers, which evaluate the agent's performance. 576 | 577 | Args: 578 | - `seed`: rng seed (int) 579 | - `num_epochs`: number of epochs to run the training loop for 580 | - `eval_every`: number of epochs between evaluations 581 | - `initial_state`: (optional) TrainLoopOutput, used to continue training from a previous state 582 | 583 | Returns: 584 | - (TrainLoopOutput): contains train_state, collection_state, test_states, cur_epoch after training loop 585 | """ 586 | # init rng 587 | key = jax.random.PRNGKey(seed) 588 | 589 | # initialize states 590 | if initial_state: 591 | collection_state = initial_state.collection_state 592 | train_state = initial_state.train_state 593 | tester_states = initial_state.test_states 594 | cur_epoch = initial_state.cur_epoch 595 | else: 596 | cur_epoch = 0 597 | # initialize collection state 598 | init_key, key = jax.random.split(key) 599 | collection_state = partition(self.init_collection_state(init_key, self.batch_size), self.num_devices) 600 | # initialize train state 601 | init_key, key = jax.random.split(key) 602 | init_keys = jnp.tile(init_key[None], (self.num_devices, 1)) 603 | train_state = self.init_train_state(init_keys) 604 | params = self.extract_model_params_fn(train_state) 605 | # initialize tester states 606 | tester_states = [] 607 | for tester in self.testers: 608 | state = jax.pmap(tester.init, axis_name='d')(params=params) 609 | tester_states.append(state) 610 | 611 | # warmup 612 | # populate replay buffer with initial self-play games 613 | collect = jax.vmap(self.collect_steps, in_axes=(1, 1, None, None), out_axes=1) 614 | params = self.extract_model_params_fn(train_state) 615 | collect_key, key = jax.random.split(key) 616 | collect_keys = partition(jax.random.split(collect_key, self.batch_size), self.num_devices) 617 | collection_state = collect(collect_keys, collection_state, params, self.warmup_steps) 618 | 619 | # training loop 620 | while cur_epoch < num_epochs: 621 | # collect self-play games 622 | collect_key, key = jax.random.split(key) 623 | collect_keys = partition(jax.random.split(collect_key, self.batch_size), self.num_devices) 624 | collection_state = collect(collect_keys, 
collection_state, params, self.collection_steps_per_epoch) 625 | # train 626 | train_key, key = jax.random.split(key) 627 | collection_state, train_state, metrics = self.train_steps(train_key, collection_state, train_state, self.train_steps_per_epoch) 628 | # log metrics 629 | collection_steps = self.batch_size * (cur_epoch+1) * self.collection_steps_per_epoch 630 | self.log_metrics(metrics, cur_epoch, step=collection_steps) 631 | 632 | # test 633 | if cur_epoch % eval_every == 0: 634 | params = self.extract_model_params_fn(train_state) 635 | for i, test_state in enumerate(tester_states): 636 | run_key, key = jax.random.split(key) 637 | new_test_state, metrics, rendered = self.testers[i].run( 638 | key=run_key, epoch_num=cur_epoch, max_steps=self.max_episode_steps, num_devices=self.num_devices, 639 | env_step_fn=self.env_step_fn, env_init_fn=self.env_init_fn, evaluator=self.evaluator_test, 640 | state=test_state, params=params) 641 | 642 | metrics = {k: v.mean() for k, v in metrics.items()} 643 | self.log_metrics(metrics, cur_epoch, step=collection_steps) 644 | if rendered and self.run is not None: 645 | self.run.log({f'{self.testers[i].name}_game': wandb.Video(rendered)}, step=collection_steps) 646 | tester_states[i] = new_test_state 647 | # save checkpoint 648 | # make sure previous save task has finished 649 | self.checkpoint_manager.wait_until_finished() 650 | self.save_checkpoint(train_state, cur_epoch) 651 | # next epoch 652 | cur_epoch += 1 653 | 654 | # make sure last save task has finished 655 | self.checkpoint_manager.wait_until_finished() # 656 | # return state so that training can be continued! 657 | return TrainLoopOutput( 658 | collection_state=collection_state, 659 | train_state=train_state, 660 | test_states=tester_states, 661 | cur_epoch=cur_epoch 662 | ) 663 | -------------------------------------------------------------------------------- /core/trees/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero/74cdb9cbd736b396141df4a07db268f28b687f9a/core/trees/__init__.py -------------------------------------------------------------------------------- /core/trees/tree.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | from typing import Tuple, TypeVar, Generic, ClassVar 4 | import chex 5 | from chex import dataclass 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | NodeType = TypeVar('NodeType') 10 | 11 | @dataclass(frozen=True) 12 | class Tree(Generic[NodeType]): 13 | """A generic DAG tree data structure that holds arbitrary structured data within nodes.""" 14 | # N -> max nodes 15 | # F -> branching Factor 16 | next_free_idx: chex.Array # () 17 | parents: chex.Array # (N) 18 | edge_map: chex.Array # (N, F) 19 | data: chex.ArrayTree # structured data with leaves of shape (N, ...) 20 | 21 | NULL_INDEX: ClassVar[int] = -1 22 | NULL_VALUE: ClassVar[int] = 0 23 | ROOT_INDEX: ClassVar[int] = 0 24 | 25 | @property 26 | def capacity(self) -> int: 27 | """the maximum number of nodes that can be stored in the tree.""" 28 | return self.parents.shape[-1] 29 | 30 | 31 | @property 32 | def branching_factor(self) -> int: 33 | """the maximum number of children a node can have.""" 34 | return self.edge_map.shape[-1] 35 | 36 | 37 | def data_at(self, index: int) -> NodeType: 38 | """returns a node's data at a specific index. 39 | 40 | Args: 41 | - `index`: the index of the node to retrieve data from. 
42 | 43 | Returns: 44 | - (NodeType): the data stored at the specified index. 45 | """ 46 | return jax.tree_util.tree_map( 47 | lambda x: x[index], 48 | self.data 49 | ) 50 | 51 | 52 | def check_data_type(self, data: NodeType) -> None: 53 | """checks if the data type matches the tree's data type. 54 | 55 | Args: 56 | - `data`: the data to check. 57 | 58 | Returns: 59 | - None 60 | """ 61 | assert isinstance(data, type(self.data)), \ 62 | f"data type mismatch, tree contains {type(self.data)} data, but got {type(data)} data." 63 | 64 | 65 | def is_edge(self, parent_index: int, edge_index: int) -> bool: 66 | """checks if an edge exists from a parent node along a specific edge. 67 | 68 | Args: 69 | - `parent_index`: the index of the parent node. 70 | - `edge_index`: the index of the edge to check. 71 | 72 | Returns: 73 | - (bool): whether an edge exists from the parent node along the specified edge. 74 | """ 75 | return self.edge_map[parent_index, edge_index] != self.NULL_INDEX 76 | 77 | 78 | def get_child_data(self, x: str, index: int, null_value=None) -> chex.ArrayTree: 79 | """returns a specified data field for all children of a node 80 | 81 | Args: 82 | - `x`: the data field to extract from child nodes 83 | - `index`: the index of the parent node 84 | - `null_value`: the value to use for children that do not exist 85 | 86 | Returns: 87 | - (chex.ArrayTree): the extracted data field for all children of the parent node 88 | """ 89 | assert hasattr(self.data, x), f"field {x} not found in node data." 90 | 91 | if null_value is None: 92 | null_value = self.NULL_VALUE 93 | mapping = self.edge_map[index] 94 | child_data = getattr(self.data, x)[mapping] 95 | 96 | return jnp.where( 97 | (mapping == self.NULL_INDEX).reshape((-1,) + (1,) * (child_data.ndim - 1)), 98 | null_value, child_data) 99 | 100 | 101 | def add_node(self, parent_index: int, edge_index: int, data: NodeType) -> Tree[NodeType]: 102 | """adds a new node to the tree at the next free index, if the tree has capacity left. 103 | Is a no-op if the tree is full. 104 | 105 | Args: 106 | - `parent_index`: the index of the parent node to attach the new node to. 107 | - `edge_index`: the index of the parent edge to attach the new node to. 108 | - `data`: the data to store at the new node. 109 | 110 | Returns: 111 | - (Tree[NodeType]): tree with the new node added. 112 | """ 113 | # check types 114 | self.check_data_type(data) 115 | # if the tree is full, tree.next_free_idx will be out of bounds 116 | in_bounds = self.next_free_idx < self.capacity 117 | # updating data at this index will be a no-op 118 | # e.g. tree.parents.at[tree.next_free_idx].set(parent_index) 119 | # will do nothing 120 | # BUT 121 | # we don't want to modify edge_map to point to this index 122 | # so we set it to NULL_INDEX instead when the tree is full 123 | edge_map_index = jnp.where(in_bounds, self.next_free_idx, self.NULL_INDEX) 124 | # ... 125 | return self.replace( #pylint: disable=no-member 126 | next_free_idx=jnp.where(in_bounds, self.next_free_idx + 1, self.next_free_idx), 127 | parents=self.parents.at[self.next_free_idx].set(parent_index), 128 | edge_map=self.edge_map.at[parent_index, edge_index].set(edge_map_index), 129 | data=jax.tree_map( 130 | lambda x, y: x.at[self.next_free_idx].set(y), 131 | self.data, data) 132 | ) 133 | 134 | 135 | def set_root(self, data: NodeType) -> Tree[NodeType]: 136 | """Sets node data at the root node. 137 | 138 | Args: 139 | - `data`: the data to store at the root node. 
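
The gather-and-mask pattern used by `get_child_data` above is easiest to see in isolation: child node indices come from one row of `edge_map`, and any slot equal to `NULL_INDEX` is replaced with the fill value. A minimal numeric sketch with toy arrays (not a real tree):

```python
import jax.numpy as jnp

NULL_INDEX = -1
edge_map_row = jnp.array([2, NULL_INDEX, 3])   # children of one node, indexed by edge
visit_counts = jnp.array([10, 4, 7, 1])        # a per-node data field, shape (N,)

child_counts = visit_counts[edge_map_row]                              # gather (NULL_INDEX wraps to the last entry)
child_counts = jnp.where(edge_map_row == NULL_INDEX, 0, child_counts)  # mask out missing children
print(child_counts)                                                    # [7 0 1]
```
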
140 | 141 | Returns: 142 | - (Tree[NodeType]): tree with the root node data set. 143 | """ 144 | self.check_data_type(data) 145 | 146 | return self.replace( #pylint: disable=no-member 147 | next_free_idx=jnp.maximum(self.next_free_idx, 1), 148 | data=jax.tree_map( 149 | lambda x, y: x.at[self.ROOT_INDEX].set(y), 150 | self.data, data)) 151 | 152 | 153 | def update_node(self, index: int, data: NodeType) -> Tree: 154 | """updates the data of a node at a specific index. 155 | 156 | Args: 157 | - `index`: the index of the node to update. 158 | - `data`: the new data to store at the specified index. 159 | 160 | Returns: 161 | - (Tree): tree with the node data updated. 162 | """ 163 | return self.replace( #pylint: disable=no-member 164 | data=jax.tree_util.tree_map( 165 | lambda x, y: x.at[index].set(y), 166 | self.data, data)) 167 | 168 | 169 | def _get_translation(self, child_index: int) -> Tuple[chex.Array, chex.Array, chex.Array]: 170 | """Extracts mapping of node_idxs in a particular root subtree (with root at `child_index`) to collapsed indices. 171 | 172 | Args: 173 | - `child_index`: the index of the child node to use as the root of the subtree. 174 | 175 | Returns: 176 | - (Tuple[chex.Array, chex.Array, chex.Array]): 177 | - old_subtree_idxs: the indices of the nodes in the subtree rooted at `child_index`. 178 | - translation: the mapping from the old indices to the new indices after collapsing the subtree. 179 | - erase_idxs: the indices of the nodes that will be erased after collapsing the subtree. 180 | """ 181 | # initialize each node as its own subtree 182 | subtrees = jnp.arange(self.capacity) 183 | 184 | def propagate(_, subtrees): 185 | # propagates parent subtrees to children 186 | parents_subtrees = jnp.where( 187 | self.parents != self.NULL_INDEX, 188 | subtrees[self.parents], 189 | 0 190 | ) 191 | return jnp.where( 192 | jnp.greater(parents_subtrees, 0), 193 | parents_subtrees, 194 | subtrees 195 | ) 196 | 197 | # propagate parent subtrees to children, until all nodes are assigned to one of the root subtrees 198 | subtrees = jax.lax.fori_loop(0, self.capacity-1, propagate, subtrees) 199 | 200 | # get idx of subtree 201 | subtree_idx = self.edge_map[self.ROOT_INDEX, child_index] 202 | # get nodes that are part of the subtree 203 | nodes_to_retain = subtrees == subtree_idx 204 | slots_aranged = jnp.arange(self.capacity) 205 | old_subtree_idxs = nodes_to_retain * slots_aranged 206 | # get translation of old indices to new indices (collapsed) 207 | cumsum = jnp.cumsum(nodes_to_retain) 208 | new_next_node_index = cumsum[-1] 209 | translation = jnp.where( 210 | nodes_to_retain, 211 | nodes_to_retain * (cumsum-1), 212 | self.NULL_INDEX 213 | ) 214 | # get indices of nodes that will be erased 215 | erase_idxs = slots_aranged >= new_next_node_index 216 | 217 | return old_subtree_idxs, translation, erase_idxs 218 | 219 | 220 | def get_subtree(self, subtree_index: int) -> Tree: 221 | """Extracts a subtree rooted at a specific node index. 222 | Collapses subtree into a new tree with the root node at index 0, and children in subsequent indices. 223 | 224 | Args: 225 | - `subtree_index`: the edge index (from the root node) of the node to use as the new root. 226 | 227 | Returns: 228 | - (Tree): the subtree rooted at the specified node index. 
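
The subtree-labelling loop inside `_get_translation` above is easier to follow on a tiny example: start with every node labelled by its own index, then repeatedly overwrite each label with its parent's label until every node carries the index of the root child it descends from. A standalone sketch of just that propagation step:

```python
import jax
import jax.numpy as jnp

NULL_INDEX = -1
parents = jnp.array([NULL_INDEX, 0, 0, 1, 3])   # root 0; children 1 and 2; node 3 under 1; node 4 under 3
subtrees = jnp.arange(parents.shape[0])

def propagate(_, labels):
    parent_labels = jnp.where(parents != NULL_INDEX, labels[parents], 0)
    return jnp.where(parent_labels > 0, parent_labels, labels)

labels = jax.lax.fori_loop(0, parents.shape[0] - 1, propagate, subtrees)
print(labels)   # [0 1 2 1 1] -> nodes 3 and 4 belong to the subtree rooted at node 1
```
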
229 | """ 230 | # get subtree translation 231 | old_subtree_idxs, translation, erase_idxs = self._get_translation(subtree_index) 232 | 233 | new_next_node_index = translation.max(axis=-1) + 1 234 | 235 | def translate(x, null_value=self.NULL_VALUE): 236 | return jnp.where( 237 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)), 238 | jnp.full_like(x, null_value, dtype=x.dtype), 239 | # cases where translation == -1 will set last index 240 | # but since we are at least removing the root node 241 | # (and making one of its children the new root) 242 | # the last index will always be freed 243 | # and overwritten with zeros 244 | x.at[translation].set(x[old_subtree_idxs]), 245 | ) 246 | 247 | def translate_idx(x, null_value=self.NULL_INDEX): 248 | return jnp.where( 249 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)), 250 | null_value, 251 | # in this case we need to explicitly check for index 252 | # mappings to UNVISITED, since otherwise thsese will 253 | # map to the value of the last index of the translation 254 | x.at[translation].set(jnp.where( 255 | x == null_value, 256 | jnp.full_like(x, null_value, dtype=x.dtype), 257 | translation[x]))) 258 | 259 | def translate_pytree(x, null_value=self.NULL_VALUE): 260 | return jax.tree_map( 261 | lambda t: translate(t, null_value=null_value), x) 262 | 263 | # extract subtree using translation functions 264 | return self.replace( #pylint: disable=no-member 265 | next_free_idx=new_next_node_index, 266 | parents=translate_idx(self.parents), 267 | edge_map=translate_idx(self.edge_map), 268 | data=translate_pytree(self.data) 269 | ) 270 | 271 | 272 | def reset(self) -> Tree: 273 | """Resets the tree to its initial state.""" 274 | return self.replace( #pylint: disable=no-member 275 | next_free_idx=0, 276 | parents=jnp.full_like(self.parents, self.NULL_INDEX), 277 | edge_map=jnp.full_like(self.edge_map, self.NULL_INDEX), 278 | data=jax.tree_map(jnp.zeros_like, self.data)) 279 | 280 | 281 | def init_tree(max_nodes: int, branching_factor: int, template_data: NodeType) -> Tree: 282 | """ Initializes a new Tree. 283 | 284 | Args: 285 | - `max_nodes`: the maximum number of nodes the tree can store. 286 | - `branching_factor`: the maximum number of children a node can have. 287 | - `template_data`: template of node data 288 | 289 | Returns: 290 | - (Tree): a new tree with the specified parameters. 291 | """ 292 | return Tree( 293 | next_free_idx=jnp.array(0, dtype=jnp.int32), 294 | parents=jnp.full((max_nodes,), fill_value=Tree.NULL_INDEX, dtype=jnp.int32), 295 | edge_map=jnp.full((max_nodes, branching_factor), fill_value=Tree.NULL_INDEX, dtype=jnp.int32), 296 | data=jax.tree_util.tree_map( 297 | lambda x: jnp.zeros((max_nodes, *x.shape), dtype=x.dtype), 298 | template_data)) 299 | -------------------------------------------------------------------------------- /core/types.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Callable, Tuple 3 | 4 | import chex 5 | from flax.training.train_state import TrainState 6 | import jax 7 | import optax 8 | 9 | from core.memory.replay_memory import BaseExperience 10 | 11 | @chex.dataclass(frozen=True) 12 | class StepMetadata: 13 | """Metadata for a step in the environment. 
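
Putting the pieces of `core/trees/tree.py` above together, here is a hedged usage sketch with a minimal node payload; the `NodeData` dataclass is invented for illustration (real evaluators define richer node data):

```python
import chex
import jax.numpy as jnp
# from core.trees.tree import init_tree   # assuming this import path

@chex.dataclass(frozen=True)
class NodeData:
    value: chex.Array

template = NodeData(value=jnp.array(0.0))
tree = init_tree(max_nodes=8, branching_factor=3, template_data=template)

tree = tree.set_root(NodeData(value=jnp.array(1.0)))                                     # write root data
tree = tree.add_node(parent_index=0, edge_index=2, data=NodeData(value=jnp.array(0.5)))  # attach a child
subtree = tree.get_subtree(2)    # re-root the tree at the child reached via edge 2
```
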
14 | - `rewards`: rewards received by the players 15 | - `action_mask`: mask of valid actions 16 | - `terminated`: whether the environment is terminated 17 | - `cur_player_id`: current player id 18 | - `step`: step number 19 | """ 20 | rewards: chex.Array 21 | action_mask: chex.Array 22 | terminated: bool 23 | cur_player_id: int 24 | step: int 25 | 26 | 27 | EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]] 28 | EnvInitFn = Callable[[jax.random.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]] 29 | DataTransformFn = Callable[[chex.Array, chex.Array, chex.ArrayTree], Tuple[chex.Array, chex.Array, chex.ArrayTree]] 30 | Params = chex.ArrayTree 31 | EvalFn = Callable[[chex.ArrayTree, Params, jax.random.PRNGKey], Tuple[chex.Array, float]] 32 | LossFn = Callable[[chex.ArrayTree, TrainState, BaseExperience], Tuple[chex.Array, Tuple[chex.ArrayTree, optax.OptState]]] 33 | ExtractModelParamsFn = Callable[[TrainState], chex.ArrayTree] 34 | StateToNNInputFn = Callable[[chex.ArrayTree], chex.Array] 35 | -------------------------------------------------------------------------------- /notebooks/hello_world.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Hello World, TurboZero 🏁" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "`turbozero` provides a vectorized implementation of AlphaZero. \n", 15 | "\n", 16 | "In a nutshell, this means we can massively speed up training, by collecting many self-play games and running Monte Carlo Tree Search in parallel across one or more GPUs!\n", 17 | "\n", 18 | "As the user, you just need to provide:\n", 19 | "* environment dynamics functions (step and init) that adhere to the TurboZero spec\n", 20 | "* a conversion function for environment state -> neural net input\n", 21 | "* and a few hyperparameters!\n", 22 | "\n", 23 | "TurboZero takes care of the rest. 😀 " 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Getting Started\n", 31 | "\n", 32 | "Follow the instructions in the repo readme to properly install dependencies and set up your environment." 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Environments\n", 40 | "\n", 41 | "In order to take advantage of the batched implementation of AlphaZero, we need to pair it with a vectorized environment.\n", 42 | "\n", 43 | "Fortunately, there are many great vectorized RL environment libraries, one I like in particular is [pgx](https://github.com/sotetsuk/pgx).\n", 44 | "\n", 45 | "Let's use the 'othello' environment. 
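
To make the `LossFn` alias above concrete: `one_train_step` calls the loss function as `loss_fn(params, train_state, batch)` and expects `(loss, (metrics, updates))` back, where `updates` may carry mutable collections such as batch norm statistics. The sketch below is an assumption-laden illustration, not the bundled `az_default_loss_fn`; in particular it assumes the network's apply function returns a `(policy_logits, value)` pair and uses only the first reward entry.

```python
import jax.numpy as jnp
import optax

def sketch_loss_fn(params, train_state, batch):
    # assumed model interface: apply({'params': ...}, obs, train=False) -> (policy_logits, value)
    policy_logits, value = train_state.apply_fn({'params': params}, batch.observation_nn, train=False)
    policy_loss = optax.softmax_cross_entropy(policy_logits, batch.policy_weights).mean()
    value_loss = optax.l2_loss(value.squeeze(), batch.reward[..., 0]).mean()  # assumed reward layout
    loss = policy_loss + value_loss
    metrics = {'policy_loss': policy_loss, 'value_loss': value_loss}
    updates = {}   # a BatchNorm model would return its new batch_stats here
    return loss, (metrics, updates)
```
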
You can see its documentation here: https://sotets.uk/pgx/othello/" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "import pgx\n", 55 | "\n", 56 | "env = pgx.make('othello')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Environment Dynamics\n", 64 | "\n", 65 | "Turbozero needs to interface with the environment in order to build search trees and collect self-play episodes.\n", 66 | "\n", 67 | "We can define this interface with the following functions:\n", 68 | "* `env_step_fn`: given an environment state and an action, return the new environment state \n", 69 | "```python\n", 70 | " EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]]\n", 71 | "```\n", 72 | "* `env_init_fn`: given a key, initialize and reutrn a new environment state\n", 73 | "```python\n", 74 | " EnvInitFn = Callable[[chex.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]]\n", 75 | "```\n", 76 | "Fortunately, environment libraries implement these for us! We just need to extract a few key pieces of information \n", 77 | "from the environment state so that we can match the TurboZero specification. We store this in a StepMetadata object:" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "\u001b[0;34m@\u001b[0m\u001b[0mchex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataclass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrozen\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", 90 | "\u001b[0;34m\u001b[0m\u001b[0;32mclass\u001b[0m \u001b[0mStepMetadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", 91 | "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"Metadata for a step in the environment.\u001b[0m\n", 92 | "\u001b[0;34m - `rewards`: rewards received by the players\u001b[0m\n", 93 | "\u001b[0;34m - `action_mask`: mask of valid actions\u001b[0m\n", 94 | "\u001b[0;34m - `terminated`: whether the environment is terminated\u001b[0m\n", 95 | "\u001b[0;34m - `cur_player_id`: current player id\u001b[0m\n", 96 | "\u001b[0;34m - `step`: step number\u001b[0m\n", 97 | "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", 98 | "\u001b[0;34m\u001b[0m \u001b[0mrewards\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mchex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m\u001b[0m\n", 99 | "\u001b[0;34m\u001b[0m \u001b[0maction_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mchex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m\u001b[0m\n", 100 | "\u001b[0;34m\u001b[0m \u001b[0mterminated\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m\u001b[0m\n", 101 | "\u001b[0;34m\u001b[0m \u001b[0mcur_player_id\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m\u001b[0m\n", 102 | "\u001b[0;34m\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "from core.types import StepMetadata\n", 108 | "%psource StepMetadata" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "* `rewards` stores the rewards emitted for each player for the given timestep\n", 116 | "* `action_mask` is a mask across all possible actions, where legal actions are set to `True`, and invalid/illegal actions are 
set to `False`\n", 117 | "* `terminated` True if the environment is terminated/completed\n", 118 | "* `cur_player_id`: id of the current player\n", 119 | "* `step`: step number" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "We can define the environment interface for `Othello` as follows:" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 3, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def step_fn(state, action):\n", 136 | " new_state = env.step(state, action)\n", 137 | " return new_state, StepMetadata(\n", 138 | " rewards=new_state.rewards,\n", 139 | " action_mask=new_state.legal_action_mask,\n", 140 | " terminated=new_state.terminated,\n", 141 | " cur_player_id=new_state.current_player,\n", 142 | " step = new_state._step_count\n", 143 | " )\n", 144 | "\n", 145 | "def init_fn(key):\n", 146 | " state = env.init(key)\n", 147 | " return state, StepMetadata(\n", 148 | " rewards=state.rewards,\n", 149 | " action_mask=state.legal_action_mask,\n", 150 | " terminated=state.terminated,\n", 151 | " cur_player_id=state.current_player,\n", 152 | " step = state._step_count\n", 153 | " )" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Pretty easy!" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Neural Network\n", 168 | "\n", 169 | "Next, we'll need to define the architecture of the neural network \n", 170 | "\n", 171 | "A simple implementation of the residual neural network used in the _AlphaZero_ paper is included for your convenience. \n", 172 | "\n", 173 | "You can implement your own architecture using `flax.linen`." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 4, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from core.networks.azresnet import AZResnetConfig, AZResnet\n", 183 | "\n", 184 | "resnet = AZResnet(AZResnetConfig(\n", 185 | " policy_head_out_size=env.num_actions,\n", 186 | " num_blocks=4,\n", 187 | " num_channels=32,\n", 188 | "))" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "We also need a way to convert our environment's state into something our neural network can take as input (i.e. structured data -> Array). `pgx` conveniently includes this in `state.observation`, but for other environments you may need to perform the conversion yourself." 
196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 5, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "def state_to_nn_input(state):\n", 205 | " return state.observation" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "## Evaluator\n", 213 | "\n", 214 | "Next, we can initialize our evaluator, AlphaZero, which takes the following parameters:\n", 215 | "\n", 216 | "* `eval_fn`: function used to evaluate a leaf node (returns a policy and value)\n", 217 | "* `num_iterations`: number of MCTS iterations to run before returning the final policy\n", 218 | "* `max_nodes`: maximum capacity of search tree\n", 219 | "* `branching_factor`: branching factor of search tree == policy_size\n", 220 | "* `action_selector`: the algorithm used to select an action to take at any given search node, choose between:\n", 221 | " * `PUCTSelector`: AlphaZero action selection algorithm\n", 222 | " * `MuZeroPUCTSelector`: MuZero action selection algorithm\n", 223 | " * or write your own! :)\n", 224 | "\n", 225 | "There are also a few other optional parameters, a few of the important ones are:\n", 226 | "* `temperature`: temperature applied to move probabilities prior to sampling (0.0 == argmax, ->inf == completely random sampling). I reccommend setting this to 1.0 for training (default) and 0.0 for evaluation.\n", 227 | "* `dirichlet_alpha`: magnitude of Dirichlet noise to add to root policy (default 0.3). Generally, the more actions are possible in a game, the smaller this value should be. \n", 228 | "* `dirichlet_epsilon`: proportion of root policy composed of Dirichlet noise (default 0.25)\n", 229 | "\n", 230 | "\n", 231 | "We use `make_nn_eval_fn` to create a leaf evaluation function that uses our neural network to generate a policy and a value for the given state. " 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 6, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "from core.evaluators.alphazero import AlphaZero\n", 241 | "from core.evaluators.evaluation_fns import make_nn_eval_fn\n", 242 | "from core.evaluators.mcts.action_selection import PUCTSelector\n", 243 | "from core.evaluators.mcts.mcts import MCTS\n", 244 | "\n", 245 | "# alphazero can take an arbirary search `backend`\n", 246 | "# here we use classic MCTS\n", 247 | "az_evaluator = AlphaZero(MCTS)(\n", 248 | " eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),\n", 249 | " num_iterations = 32,\n", 250 | " max_nodes = 40,\n", 251 | " branching_factor = env.num_actions,\n", 252 | " action_selector = PUCTSelector(),\n", 253 | " temperature = 1.0\n", 254 | ")" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "We also define a separate evaluator with different parameters to use for testing purposes. We'll give this one a larger budget (num_iterations), and set the temperature to zero so it always chooses the most-visited action after search is complete." 
262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 7, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "az_evaluator_test = AlphaZero(MCTS)(\n", 271 | " eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),\n", 272 | " num_iterations = 64,\n", 273 | " max_nodes = 80,\n", 274 | " branching_factor = env.num_actions,\n", 275 | " action_selector = PUCTSelector(),\n", 276 | " temperature = 0.0\n", 277 | ")" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "## Baselines\n", 285 | "\n", 286 | "We can also test our trained model periodically against baselines, in order to gauge improvement.\n", 287 | "\n", 288 | "Conveniently, pgx offers pre-trained baseline models for certain environments. If we want to test against one, we can use `make_nn_eval_fn_no_params_callable`, which just returns an evaluation function that uses the baseline model to evaluate the game state.\n", 289 | "\n", 290 | "We can combine this with the `AlphaZero` evaluator like we did before to create a competeing AlphaZero instance to play against. Other than the eval_fn, it is important to use the same parameters that we give to our (test) evaluator, so as to give a true comparison between the strength of the policy/value estimates." 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 8, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "from core.evaluators.evaluation_fns import make_nn_eval_fn_no_params_callable\n", 300 | "\n", 301 | "model = pgx.make_baseline_model('othello_v0')\n", 302 | "\n", 303 | "baseline_eval_fn = make_nn_eval_fn_no_params_callable(model, state_to_nn_input)\n", 304 | "\n", 305 | "baseline_az = AlphaZero(MCTS)(\n", 306 | " eval_fn = baseline_eval_fn,\n", 307 | " num_iterations = 64,\n", 308 | " max_nodes = 80,\n", 309 | " branching_factor = env.num_actions,\n", 310 | " action_selector = PUCTSelector(),\n", 311 | " temperature = 0.0\n", 312 | ")" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "To see baselines available for other pgx environments, see https://sotets.uk/pgx/api/#pgx.BaselineModelId" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "We can use similar ideas to write a greedy baseline evaluation function, one that doesn't use a neural network at all!\n", 327 | "\n", 328 | "Instead, it simply counts the number of tiles for the active player and compares it to the number of tiles controlled by the other player, so the value is higher for states where the active player controls more tiles than the other player.\n", 329 | "\n", 330 | "Using similar techniques as before, we can create another AlphaZero evaluator to test against." 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 9, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "import jax.numpy as jnp\n", 340 | "\n", 341 | "def greedy_eval(obs):\n", 342 | " value = (obs[...,0].sum() - obs[...,1].sum()) / 64\n", 343 | " return jnp.ones((1,env.num_actions)), jnp.array([value])\n", 344 | "\n", 345 | "greedy_baseline_eval_fn = make_nn_eval_fn_no_params_callable(greedy_eval, state_to_nn_input)\n", 346 | "\n", 347 | "\n", 348 | "greedy_az = AlphaZero(MCTS)(\n", 349 | " eval_fn = greedy_baseline_eval_fn,\n", 350 | " num_iterations = 64,\n", 351 | " max_nodes = 80,\n", 352 | " branching_factor = env.num_actions,\n", 353 | " action_selector = PUCTSelector(),\n", 354 | " temperature = 0.0\n", 355 | ")" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "## Replay Memory Buffer\n", 363 | "\n", 364 | "Next, we'll initialize a replay memory buffer to hold selfplay trajectories that we can sample from during training. This actually just defines an interface, the buffer state itself will be initialized and managed internally.\n", 365 | "\n", 366 | "The replay buffer is batched, it retains a buffer of trajectories across a batch dimension. We specify a `capacity`: the amount of samples stored in a single buffer. The total capacity of the entire replay buffer is then `batch_size * capacity`, where `batch_size` is the number of environments/self-play games being run in parallel." 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 10, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "from core.memory.replay_memory import EpisodeReplayBuffer\n", 376 | "\n", 377 | "replay_memory = EpisodeReplayBuffer(capacity=1000)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "## Data Augmentation (Optional)\n", 385 | "\n", 386 | "During self-play, we allow for any number of custom data augmentation functions, in order to create more training samples. \n", 387 | "\n", 388 | "In RL, it's sometimes common to take advantage of rotations or symmetries in order to generate additional training examples. \n", 389 | "\n", 390 | "In Othello, we could simply consider rotating the board to generate a new training example, we will need to be careful to update our policy as well.\n", 391 | "\n", 392 | "In order to implement a data augmentation function, we must follow `DataTransformFn`:\n", 393 | "```python\n", 394 | "# (policy mask, policy weights, environment state) -> (transformed policy mask, transformed policy weights, transformed environment state)\n", 395 | "DataTransformFn = Callable[[chex.Array, chex.Array, chex.ArrayTree], Tuple[chex.Array, chex.Array, chex.ArrayTree]]\n", 396 | "```\n", 397 | "\n", 398 | "We create rotational transform functions for rotating 90, 180, 270 degrees:" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 11, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "def make_rot_transform_fn(amnt: int):\n", 408 | " def rot_transform_fn(mask, policy, state):\n", 409 | " action_ids = jnp.arange(65) # 65 total actions, but only rotate the first 64! 
(65th is always do nothing action)\n", 410 | " # we only use state.observation, no need to update the rest of the state fields\n", 411 | " new_obs = jnp.rot90(state.observation, amnt, axes=(-3,-2))\n", 412 | " # map action ids to new action ids\n", 413 | " idxs = jnp.arange(64).reshape(8,8) # rotate first 64 actions\n", 414 | " new_idxs = jnp.rot90(idxs, amnt, axes=(0, 1)).flatten()\n", 415 | " action_ids = action_ids.at[:64].set(new_idxs)\n", 416 | " # get new mask and policy\n", 417 | " new_mask = mask[...,action_ids]\n", 418 | " new_policy = policy[...,action_ids]\n", 419 | " return new_mask, new_policy, state.replace(observation=new_obs)\n", 420 | "\n", 421 | " return rot_transform_fn\n", 422 | "\n", 423 | "# make transform fns for rotating 90, 180, 270 degrees\n", 424 | "transforms = [make_rot_transform_fn(i) for i in range(1,4)] " 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "metadata": {}, 430 | "source": [ 431 | "## Rendering\n", 432 | "We can optionally provide a `render_fn` that will record games played by our model against one of the baselines and save it as a `.gif`.\n", 433 | "\n", 434 | "I've included a helper fn that takes care of this:" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "This helper function depends upon cairoSVG, which itself depends upon `cairo`, which you'll need to install on your system.\n", 442 | "\n", 443 | "On Ubuntu, this can be done with:" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "! apt-get update && apt-get -y install libcairo2-dev" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": {}, 458 | "source": [ 459 | "If you're on another OS, consult https://www.cairographics.org/download/ for installation guidelines" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 13, 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "from functools import partial\n", 469 | "from core.testing.utils import render_pgx_2p\n", 470 | "render_fn = partial(render_pgx_2p, p1_label='Black', p2_label='White', duration=900)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "metadata": {}, 476 | "source": [ 477 | "## Trainer Initialization\n", 478 | "Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!\n", 479 | "\n", 480 | "The `Trainer` takes many parameters, so let's walk through them all:\n", 481 | "* `batch_size`: # of parallel environments used to collect self-play games\n", 482 | "* `train_batch_size`: size of minibatch used during training step\n", 483 | "* `warmup_steps`: # of steps (per batch) to collect via self-play prior to entering the training loop. 
This is used to populate the replay memory with some initial samples\n", 484 | "* `collection_steps_per_epoch`: # of steps (per batch) to collect via self-play per epoch\n", 485 | "* `train_steps_per_epoch`: # of train steps per epoch\n", 486 | "* `nn`: neural network (`linen.Module`)\n", 487 | "* `loss_fn`: loss function used for training, we use a provided default loss which implements the loss function used in the `AlphaZero` paper\n", 488 | "* `optimizer`: an `optax` optimizer used for training\n", 489 | "* `evaluator`: the `Evaluator` to use during self-play, we initialized ours using `AlphaZero(MCTS)`\n", 490 | "* `memory_buffer`: the memory buffer used to store samples from self-play games, we initialized ours using `EpisodeReplayBuffer`\n", 491 | "* `max_episode_steps`: maximum number of steps/turns to allow before truncating an episode\n", 492 | "* `env_step_fn`: environment step function (we defined ours above)\n", 493 | "* `env_init_fn`: environment init function (we defined ours above)\n", 494 | "* `state_to_nn_input_fn`: function to convert environment state to nn input (we defined ours above)\n", 495 | "* `testers`: any number of `Tester`s, used to evaluate a given model and take their own parameters. We'll use the two evaluators defined above to initialize two Testers.\n", 496 | "* `evaluator_test`: (Optional) Evaluator used within Testers. By default used `evaluator`, but sometimes you may want to test with a larger MCTS iteration budget for example, or a different move sampling temperature\n", 497 | "* `data_transform_fns`: (optional) list of data transform functions to apply to self-play experiences (e.g. rotation, reflection, etc.)\n", 498 | "* `extract_model_params_fn`: (Optional) in special cases we need to define how to extract all model parameters from a flax `TrainState`. The default function handles BatchNorm, but if another special-case technique applied across batches is used (e.g. Dropout) we would need to define a function to extract the appropriate parameters. You usually won't need to define this!\n", 499 | "* `wandb_project_name`: (Optional) Weights and Biases project name. You will be prompted to login if a name is provided. If a name is provided, a run will be initialized and loss and other metrics will be logged to the given wandb project.\n", 500 | "* `ckpt_dir`: (Optional) directory to store checkpoints in, by default this is set to `/tmp/turbozero_checkpoints`\n", 501 | "* `max_checkpoints`: (Optional) maximum number of most-recent checkpoints to retain (default: 2)\n", 502 | "* `num_devices`: (Optional) number of hardware accelerators (GPUs/TPUs) to use. If not given, all available hardware accelerators are used\n", 503 | "* `wandb_run`: (Optional) continues from an initialized `wandb` run if provided, otherwise a new one is initialized\n", 504 | "* `extra_wandb_config`: (Optional) any extra metadata to store in the `wandb` run config\n", 505 | "\n", 506 | "A training epoch is comprised of M collection steps, followed by N training steps sampling minibatches from replay memory. Optionally, any number of Testers evaluate the current model. At the end of each epoch, a checkpoint is saved.\n", 507 | "\n", 508 | "If you are using one or more GPUs (reccommended), TurboZero by default will run on all your available hardware." 
509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [ 517 | "from functools import partial\n", 518 | "from core.testing.two_player_baseline import TwoPlayerBaseline\n", 519 | "from core.training.loss_fns import az_default_loss_fn\n", 520 | "from core.training.train import Trainer\n", 521 | "import optax\n", 522 | "\n", 523 | "trainer = Trainer(\n", 524 | " batch_size = 1024,\n", 525 | " train_batch_size = 4096,\n", 526 | " warmup_steps = 0,\n", 527 | " collection_steps_per_epoch = 256,\n", 528 | " train_steps_per_epoch = 64,\n", 529 | " nn = resnet,\n", 530 | " loss_fn = partial(az_default_loss_fn, l2_reg_lambda = 0.0),\n", 531 | " optimizer = optax.adam(1e-3),\n", 532 | " evaluator = az_evaluator,\n", 533 | " memory_buffer = replay_memory,\n", 534 | " max_episode_steps = 80,\n", 535 | " env_step_fn = step_fn,\n", 536 | " env_init_fn = init_fn,\n", 537 | " state_to_nn_input_fn=state_to_nn_input,\n", 538 | " testers = [\n", 539 | " TwoPlayerBaseline(num_episodes=128, baseline_evaluator=baseline_az, render_fn=render_fn, render_dir='.', name='pretrained'),\n", 540 | " TwoPlayerBaseline(num_episodes=128, baseline_evaluator=greedy_az, render_fn=render_fn, render_dir='.', name='greedy'),\n", 541 | " ],\n", 542 | " evaluator_test = az_evaluator_test,\n", 543 | " data_transform_fns=transforms\n", 544 | " # wandb_project_name = 'turbozero-othello' \n", 545 | ")" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": {}, 551 | "source": [ 552 | "## Training\n", 553 | "\n", 554 | "Now all that's left to do is to kick off the training loop! We need to pass an initial seed for reproducibility, and the number of epochs to run for!\n", 555 | "\n", 556 | "If you've set up `wandb`, you can track metrics via plots in the run dashboard. Metrics will also be printed to the console. \n", 557 | "\n", 558 | "IMPORTANT: The first epoch will not execute quickly! This is because there is significant overhead in JAX compilation (nearly all of the training loop is JIT-compiled). This will cause the first epoch to run very slowly, as JIT-compiled functions are traced and compiled the first time they are run. Expect epochs after the first to execute much more quickly. Typically, GPU utilization will also be low/zero during this period.\n", 559 | "\n", 560 | "It's also worth mentioning that the hyperparameters in this notebook are just here for example purposes. Regardless of the task, they will need to be tuned according to the characteristics of the environment as well as your available hardware and time/cost constraints." 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 15, 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "output = trainer.train_loop(seed=0, num_epochs=100, eval_every=5)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": {}, 575 | "source": [ 576 | "and GIFs generated will appear in the same directory as this notebook, and also on your `wandb` dashboard." 
577 | ] 578 | } 579 | ], 580 | "metadata": { 581 | "kernelspec": { 582 | "display_name": "turbozero", 583 | "language": "python", 584 | "name": "turbozero" 585 | }, 586 | "language_info": { 587 | "codemirror_mode": { 588 | "name": "ipython", 589 | "version": 3 590 | }, 591 | "file_extension": ".py", 592 | "mimetype": "text/x-python", 593 | "name": "python", 594 | "nbconvert_exporter": "python", 595 | "pygments_lexer": "ipython3", 596 | "version": "3.12.2" 597 | } 598 | }, 599 | "nbformat": 4, 600 | "nbformat_minor": 2 601 | } 602 | -------------------------------------------------------------------------------- /notebooks/weighted_mcts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pgx\n", 10 | "import chex\n", 11 | "import jax\n", 12 | "import jax.numpy as jnp\n", 13 | "import optax\n", 14 | "from functools import partial\n", 15 | "\n", 16 | "from core.memory.replay_memory import EpisodeReplayBuffer\n", 17 | "from core.networks.azresnet import AZResnet, AZResnetConfig\n", 18 | "from core.evaluators.alphazero import AlphaZero\n", 19 | "from core.evaluators.mcts.weighted_mcts import WeightedMCTS\n", 20 | "from core.evaluators.mcts.action_selection import PUCTSelector\n", 21 | "from core.evaluators.evaluation_fns import make_nn_eval_fn\n", 22 | "from core.testing.two_player_tester import TwoPlayerTester\n", 23 | "from core.training.train import Trainer\n", 24 | "from core.training.loss_fns import az_default_loss_fn\n", 25 | "from core.types import StepMetadata" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "This is a demo of AlphaZero using weighted MCTS. \n", 33 | "\n", 34 | "Make sure to set specify a weights and biases project name if you have a wandb account to track metrics!\n", 35 | "\n", 36 | "Hyperparameters are mostly for the purposes of example, do not assume they are correct!" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "Weighted MCTS: https://twitter.com/ptrschmdtnlsn/status/1748800529608888362\n", 44 | "\n", 45 | "Implemented here: https://github.com/lowrollr/turbozero/blob/main/core/evaluators/mcts/weighted_mcts.py\n", 46 | "\n", 47 | "temperature controlled by `q_temperature` (passed to AlphaZero initialization below)\n", 48 | "\n", 49 | "For more on turbozero, see the [README](https://github.com/lowrollr/turbozero) and \n", 50 | "[Hello World notebook](https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb). 
The hello world notebook explains each component we set up in this notebook!\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# get connect 4 environment\n", 60 | "# pgx has lots more to choose from!\n", 61 | "# othello, chess, etc.\n", 62 | "env = pgx.make(\"connect_four\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# define environment dynamics functions\n", 72 | "def step_fn(state, action):\n", 73 | " state = env.step(state, action)\n", 74 | " metadata = StepMetadata(\n", 75 | " rewards = state.rewards,\n", 76 | " terminated = state.terminated,\n", 77 | " action_mask = state.legal_action_mask,\n", 78 | " cur_player_id = state.current_player,\n", 79 | " step = state._step_count\n", 80 | " )\n", 81 | " return state, metadata\n", 82 | "\n", 83 | "def init_fn(key):\n", 84 | " state = env.init(key)\n", 85 | " metadata = StepMetadata(\n", 86 | " rewards = state.rewards,\n", 87 | " terminated = state.terminated,\n", 88 | " action_mask = state.legal_action_mask,\n", 89 | " cur_player_id = state.current_player,\n", 90 | " step=state._step_count\n", 91 | " )\n", 92 | " return state, metadata" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "# define ResNet architecture\n", 102 | "resnet = AZResnet(AZResnetConfig(\n", 103 | " policy_head_out_size=env.num_actions,\n", 104 | " num_blocks=4, # number of residual blocks\n", 105 | " num_channels=16 # channels per block\n", 106 | "))\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# define replay buffer\n", 116 | "# store 300 experiences per batch\n", 117 | "replay_memory = EpisodeReplayBuffer(capacity=300)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "# define conversion fn for environment state to nn input\n", 127 | "def state_to_nn_input(state):\n", 128 | " # pgx does this for us with state.observation!\n", 129 | " return state.observation" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "# define AlphaZero evaluator to use during self-play\n", 139 | "# with weighted MCTS\n", 140 | "alphazero = AlphaZero(WeightedMCTS)(\n", 141 | " eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),\n", 142 | " num_iterations = 100, # number of MCTS iterations\n", 143 | " max_nodes = 200,\n", 144 | " dirichlet_alpha=0.6,\n", 145 | " temperature = 1.0, # MCTS root action sampling temperature\n", 146 | " branching_factor = env.num_actions,\n", 147 | " action_selector = PUCTSelector(),\n", 148 | " q_temperature = 1.0, # temperature applied to child Q values prior to weighted propagation to parent\n", 149 | ")" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 8, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "# define AlphaZero evaluator to use during evaluation games\n", 159 | "alphazero_test = AlphaZero(WeightedMCTS)(\n", 160 | " eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),\n", 161 | " num_iterations = 100,\n", 162 | " max_nodes = 200,\n", 163 | " temperature = 0.0, # set temperature to zero to always sample most visited action after 
search\n", 164 | " branching_factor = env.num_actions,\n", 165 | " action_selector = PUCTSelector(),\n", 166 | " q_temperature = 1.0\n", 167 | ")\n" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# initialize trainer\n", 185 | "# set `wandb_project_name` to log to wandb!!\n", 186 | "trainer = Trainer(\n", 187 | " batch_size = 128, # number of parallel environments to collect self-play games from\n", 188 | " train_batch_size = 512, # training minibatch size\n", 189 | " warmup_steps = 42,\n", 190 | " collection_steps_per_epoch = 42,\n", 191 | " train_steps_per_epoch=(128*42)//512,\n", 192 | " nn = resnet,\n", 193 | " loss_fn = partial(az_default_loss_fn, l2_reg_lambda = 0.0001),\n", 194 | " optimizer = optax.adam(5e-3),\n", 195 | " evaluator = alphazero,\n", 196 | " memory_buffer = replay_memory,\n", 197 | " max_episode_steps=42,\n", 198 | " env_step_fn = step_fn,\n", 199 | " env_init_fn = init_fn,\n", 200 | " state_to_nn_input_fn=state_to_nn_input,\n", 201 | " testers=[TwoPlayerTester(num_episodes=64)],\n", 202 | " evaluator_test = alphazero_test,\n", 203 | " # wandb_project_name='weighted_mcts_test' \n", 204 | ")" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "output = trainer.train_loop(seed=0, num_epochs=20)" 214 | ] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "turbozero-mMa0U6zx-py3.10", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.10.9" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 2 238 | } 239 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "turbozero" 3 | version = "0.1.1" 4 | description = "vectorized alphazero/mcts in JAX" 5 | authors = ["lowrollr <92640744+lowrollr@users.noreply.github.com>"] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | packages = [ 9 | { include = "core" } 10 | ] 11 | 12 | [[tool.poetry.source]] 13 | name = "PyPI" 14 | priority = "primary" 15 | 16 | 17 | [[tool.poetry.source]] 18 | name = "jax" 19 | url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" 20 | priority = "primary" 21 | 22 | [tool.poetry.dependencies] 23 | python = "^3.10" 24 | graphviz = "^0.20.1" 25 | wandb = "^0.18.6" 26 | jax = "0.4.35" 27 | jaxlib = "^0.4.34" 28 | flax = "^0.8.4" 29 | optax = "^0.1.8" 30 | orbax-checkpoint = "^0.10.1" 31 | chex = "^0.1.85" 32 | pgx = "^2.0.1" 33 | dm-haiku = "^0.0.12" 34 | cairosvg = "^2.7.1" 35 | 36 | [tool.poetry.group.dev.dependencies] 37 | ipykernel = "^6.25.1" 38 | 39 | [build-system] 40 | requires = ["poetry-core"] 41 | build-backend = "poetry.core.masonry.api" 42 | 
--------------------------------------------------------------------------------
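The `q_temperature` argument used in the weighted MCTS notebook above deserves a brief illustration. As the notebook's comment notes, it is a temperature applied to child Q-values before they are propagated to the parent as a weighted average. The snippet below is a minimal, illustrative sketch of one way such a temperature-weighted backup could be computed; the helper name `weighted_backup` and the exact weighting scheme (a visit-count-scaled softmax over temperature-scaled Q-values) are assumptions for illustration only, not the formula implemented in core/evaluators/mcts/weighted_mcts.py.

```python
# Illustrative sketch only (hypothetical helper): one plausible temperature-weighted
# backup, NOT the exact formula implemented in core/evaluators/mcts/weighted_mcts.py.
import jax
import jax.numpy as jnp

def weighted_backup(child_qs, child_visits, q_temperature):
    """Aggregate child Q-values into a parent value estimate.

    Visited children are weighted by visit count and by a softmax over their
    temperature-scaled Q-values: as q_temperature -> 0 the strongest child
    dominates, while a large q_temperature approaches a plain visit-weighted average.
    """
    visited = child_visits > 0
    # softmax over temperature-scaled Q-values, masking out unvisited children
    logits = jnp.where(visited, child_qs / jnp.maximum(q_temperature, 1e-8), -jnp.inf)
    q_weights = jax.nn.softmax(logits)
    # combine with visit counts and renormalize
    weights = q_weights * child_visits
    weights = weights / jnp.maximum(weights.sum(), 1e-8)
    return jnp.sum(weights * jnp.where(visited, child_qs, 0.0))

# toy usage: four children, one of them unvisited
qs = jnp.array([0.1, 0.5, -0.2, 0.0])
visits = jnp.array([10.0, 25.0, 3.0, 0.0])
print(weighted_backup(qs, visits, q_temperature=1.0))
```

With `q_temperature = 1.0` (the value passed to both evaluators in the notebook), this sketch interpolates between the visit-weighted average used by standard MCTS backups and a sharper aggregation that favors high-Q children.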